Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 191 additions & 4 deletions benchmark/examples/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,41 @@ def parse_args():
default=None,
help="Override GEMM_SMS for WG-specialized variant (default: auto)",
)

# WG kernel tuning sweep arguments.
parser.add_argument(
"--gemm_sms_values",
type=int,
nargs="+",
default=None,
metavar="N",
help=(
"List of GEMM_SMS values to sweep when --tune is active "
"(e.g. --gemm_sms_values 64 128 192). "
"Ignored unless --tune is set."
),
)
parser.add_argument(
"--block_m_values",
type=int,
nargs="+",
default=None,
metavar="N",
help=(
"List of BLOCK_M tile sizes to sweep when --tune is active "
"(e.g. --block_m_values 64 128 256). "
"Ignored unless --tune is set."
),
)
parser.add_argument(
"--tune",
action="store_true",
help=(
"Sweep all (gemm_sms, block_m) combinations for the WG fusion mode "
"across each bpe and report the best configuration per bpe. "
"Requires --fusion_mode wg_fused_grouped_matmul_convert_ep_to_dp."
),
)
return parser.parse_args()


Expand All @@ -170,6 +205,9 @@ def _run_dist_once(
shmem,
fusion_config,
gemm_sms=None,
block_m=None,
block_n=None,
block_k=None,
):
return mixture_of_expt_epsharded(
x_dp_local,
Expand All @@ -181,7 +219,88 @@ def _run_dist_once(
shmem,
fusion_config=fusion_config,
gemm_sms=gemm_sms,
block_m=block_m,
block_n=block_n,
block_k=block_k,
)


def _bench_dist(run_fn, shmem, heap_snapshot, n_warmup, n_repeat):
    """Benchmark a single distributed run function and return mean latency in ms.

    Temporarily replaces ``shmem.heap.refresh_peer_access`` with a no-op for
    the duration of the benchmark (presumably so the per-iteration heap reset
    does not trigger peer-access refreshes that would distort timings — TODO
    confirm against the iris heap implementation), then restores it.

    Args:
        run_fn: Zero-argument callable executing one distributed iteration.
        shmem: iris.Iris instance providing the heap, barrier, and allocator.
        heap_snapshot: Allocator heap offset to reset to between iterations.
        n_warmup: Number of warmup iterations passed to ``iris.do_bench``.
        n_repeat: Number of timed iterations passed to ``iris.do_bench``.

    Returns:
        Mean latency in milliseconds as a float.
    """
    reset_heap = _make_heap_resetter(shmem.heap.allocator, heap_snapshot)
    saved_refresh = shmem.heap.refresh_peer_access
    shmem.heap.refresh_peer_access = lambda: None
    try:
        ms = iris.do_bench(
            run_fn,
            barrier_fn=shmem.barrier,
            preamble_fn=reset_heap,
            n_warmup=n_warmup,
            n_repeat=n_repeat,
            return_mode="mean",
        )
    finally:
        # Restore the patched method and heap state even when do_bench raises:
        # _tune_wg_configs catches exceptions and keeps sweeping, so leaving
        # refresh_peer_access stubbed or the heap dirty would corrupt every
        # subsequent trial.
        shmem.heap.refresh_peer_access = saved_refresh
        reset_heap()
    return float(ms)


def _tune_wg_configs(
    x_dp_local,
    l_dp_local,
    w_ep_local,
    b_ep_local,
    expt_assignment,
    n_expts_act,
    shmem,
    fusion_config,
    heap_snapshot,
    gemm_sms_values,
    block_m_values,
    rank,
    n_warmup=5,
    n_repeat=20,
):
    """Sweep every (gemm_sms, block_m) combination and return the fastest one.

    Each candidate pair is benchmarked via ``_bench_dist``; a candidate that
    raises is recorded with infinite latency so the sweep continues.

    Returns:
        best_gemm_sms (int), best_block_m (int), tune_configs (list[dict])
    """
    tune_configs = []
    # (latency_ms, gemm_sms, block_m); seeded with the first candidate pair.
    best = (float("inf"), gemm_sms_values[0], block_m_values[0])

    for candidate_sms in gemm_sms_values:
        for candidate_bm in block_m_values:
            bench_fn = functools.partial(
                _run_dist_once,
                x_dp_local,
                l_dp_local,
                w_ep_local,
                b_ep_local,
                expt_assignment,
                n_expts_act,
                shmem,
                fusion_config,
                candidate_sms,
                candidate_bm,
            )
            try:
                latency_ms = _bench_dist(
                    bench_fn, shmem, heap_snapshot, n_warmup, n_repeat
                )
            except Exception as err:
                if rank == 0:
                    print(f"  [tune] gemm_sms={candidate_sms} block_m={candidate_bm} FAILED: {err}")
                latency_ms = float("inf")

            if rank == 0:
                print(f"  [tune] gemm_sms={candidate_sms:4d} block_m={candidate_bm:4d} ms={latency_ms:.3f}")
            tune_configs.append(
                {"gemm_sms": candidate_sms, "block_m": candidate_bm, "ms": latency_ms}
            )

            if latency_ms < best[0]:
                best = (latency_ms, candidate_sms, candidate_bm)

    return best[1], best[2], tune_configs


def _worker(rank: int, world_size: int, init_url: str, args):
Expand All @@ -205,6 +324,9 @@ def _worker(rank: int, world_size: int, init_url: str, args):
if args.n_expts_tot % ws != 0:
raise ValueError(f"n_expts_tot ({args.n_expts_tot}) must be divisible by world_size ({ws})")

if getattr(args, "tune", False) and args.fusion_mode != "wg_fused_grouped_matmul_convert_ep_to_dp":
raise ValueError("--tune requires --fusion_mode wg_fused_grouped_matmul_convert_ep_to_dp")

if args.batch_per_expt:
sweep = args.batch_per_expt
else:
Expand All @@ -213,6 +335,27 @@ def _worker(rank: int, world_size: int, init_url: str, args):
if rank == 0:
os.makedirs(args.output_dir, exist_ok=True)

# Derive default sweep values for tune mode.
if getattr(args, "tune", False):
cu_count = torch.cuda.get_device_properties(device).multi_processor_count
num_sms = int(cu_count)
if getattr(args, "gemm_sms_values", None):
gemm_sms_sweep = args.gemm_sms_values
else:
# Default: quarter, half, and three-quarter of available SMs,
# each clamped to [1, num_sms - 1].
gemm_sms_sweep = sorted(
{
max(1, min(num_sms // 4, num_sms - 1)),
max(1, min(num_sms // 2, num_sms - 1)),
max(1, min(3 * num_sms // 4, num_sms - 1)),
}
)
block_m_sweep = getattr(args, "block_m_values", None) or [32, 64, 128, 256]
else:
gemm_sms_sweep = None
block_m_sweep = None

results: list[dict] = []
sweep_heap_base = shmem.heap.allocator.heap_offset

Expand Down Expand Up @@ -247,6 +390,41 @@ def _worker(rank: int, world_size: int, init_url: str, args):
w_ep_local = w_global[expt_assignment.expt_boolmask[rank]].contiguous()
b_ep_local = b_global[expt_assignment.expt_boolmask[rank]].contiguous()

# --- Tune: sweep (gemm_sms, block_m) and pick the best config -------
tune_result = {}
active_gemm_sms = args.gemm_sms
active_block_m = None
if getattr(args, "tune", False):
heap_snapshot = shmem.heap.allocator.heap_offset
if rank == 0:
print(f"[tune bpe={bpe}] sweeping gemm_sms={gemm_sms_sweep} block_m={block_m_sweep}")
best_gs, best_bm, tune_configs = _tune_wg_configs(
x_dp_local,
l_dp_local,
w_ep_local,
b_ep_local,
expt_assignment,
args.n_expts_act,
shmem,
fusion_config,
heap_snapshot,
gemm_sms_sweep,
block_m_sweep,
rank,
n_warmup=5,
n_repeat=20,
)
active_gemm_sms = best_gs
active_block_m = best_bm
tune_result = {
"tune_best_gemm_sms": best_gs,
"tune_best_block_m": best_bm,
"tune_configs": tune_configs,
}
if rank == 0:
print(f"[tune bpe={bpe}] best: gemm_sms={best_gs} block_m={best_bm}")
shmem.heap.allocator.heap_offset = sweep_heap_base

run_dist = functools.partial(
_run_dist_once,
x_dp_local,
Expand All @@ -257,7 +435,8 @@ def _worker(rank: int, world_size: int, init_url: str, args):
args.n_expts_act,
shmem,
fusion_config,
args.gemm_sms,
active_gemm_sms,
active_block_m,
)

if args.validate or args.compare_single_gpu:
Expand All @@ -284,7 +463,8 @@ def _worker(rank: int, world_size: int, init_url: str, args):
shmem,
fusion_config=fusion_config,
timing_dict=td,
gemm_sms=args.gemm_sms,
gemm_sms=active_gemm_sms,
block_m=active_block_m,
)
if rank == 0:
for j in range(1, len(td)):
Expand All @@ -310,6 +490,7 @@ def _worker(rank: int, world_size: int, init_url: str, args):
"dtype": args.datatype,
"fusion_mode": fusion_config.mode_name(),
}
result.update(tune_result)

if args.validate:
diff = (y_ref.float() - y_tri.float()).abs()
Expand Down Expand Up @@ -364,6 +545,12 @@ def run_ref():
else ""
)
+ (f" max_diff={result.get('validate_max_diff', 0.0):.4f}" if args.validate else "")
+ (
f" best_config=(gemm_sms={tune_result['tune_best_gemm_sms']}"
f" block_m={tune_result['tune_best_block_m']})"
if tune_result
else ""
)
)
results.append(result)

Expand Down Expand Up @@ -393,8 +580,8 @@ def run_ref():

def main():
args = parse_args()
if not args.benchmark and not args.validate and not args.compare_single_gpu:
print("No mode selected. Use at least one of: --benchmark, --validate, --compare_single_gpu")
if not args.benchmark and not args.validate and not args.compare_single_gpu and not getattr(args, "tune", False):
print("No mode selected. Use at least one of: --benchmark, --validate, --compare_single_gpu, --tune")
sys.exit(1)

init_url = f"tcp://127.0.0.1:{args.init_port}"
Expand Down
65 changes: 53 additions & 12 deletions examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp_wg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
Lock granularity: one lock per (expert, N-tile, M-tile) triple.
"""

import math

import torch
import triton
import triton.language as tl
Expand Down Expand Up @@ -189,6 +187,33 @@ def _wg_fused_exp_matmul_ep_to_dp_kernel(
iris.store(dst_ptrs_2d, out, SRC_RANK, r, heap_bases, mask=store_mask, hint=(1, 16))


def _heuristic_wg_config(num_sms: int, avg_bpe: int) -> tuple[int, int]:
"""Select (gemm_sms, block_m) heuristically based on avg tokens-per-expert.

Heuristic derived from a tune sweep on MI300X (304 CUs, 8 ranks) across
bpe ∈ {64, 128, 256, 512, 1024} with block_m ∈ {32, 64, 128, 256} and
gemm_sms ∈ {¼, ½, ¾} × num_sms:

bpe ≤ 64 → gemm_sms = num_sms // 2, block_m = 128
bpe ≤ 128 → gemm_sms = 3 * num_sms // 4, block_m = 128
bpe > 128 → gemm_sms = 3 * num_sms // 4, block_m = 256

Args:
num_sms: Total CU count on the device.
avg_bpe: Average number of tokens routed per local expert
(n_slots_per_rank // n_local_experts).

Returns:
(gemm_sms, block_m) tuple.
"""
if avg_bpe <= 64:
return max(1, num_sms // 2), 128
elif avg_bpe <= 128:
return max(1, 3 * num_sms // 4), 128
else:
return max(1, 3 * num_sms // 4), 256


def wg_fused_exp_matmul_ep_to_dp(
x_ep_local: torch.Tensor,
w_ep_local: torch.Tensor,
Expand All @@ -200,6 +225,9 @@ def wg_fused_exp_matmul_ep_to_dp(
shmem,
ragged_metadata: RaggedTensorMetadata | None = None,
gemm_sms: int | None = None,
block_m: int | None = None,
block_n: int | None = None,
block_k: int | None = None,
) -> torch.Tensor:
"""WG-specialized fused expert matmul + EP->DP scatter.

Expand All @@ -216,7 +244,14 @@ def wg_fused_exp_matmul_ep_to_dp(
combine_indx: (n_total_slots,) col_sorted_indx.
shmem: iris.Iris instance.
ragged_metadata: local-expert-view ragged metadata.
gemm_sms: Number of CUs for GEMM path. Default: 2^floor(log2(cu_count)).
gemm_sms: Number of CUs for GEMM path.
Default: auto-selected by _heuristic_wg_config based on avg bpe.
block_m: GEMM tile size along the M (token) dimension.
Default: auto-selected by _heuristic_wg_config based on avg bpe.
block_n: GEMM tile size along the N (output) dimension.
Default: min(triton.next_power_of_2(N), 128).
block_k: GEMM tile size along the K (reduction) dimension.
Default: min(triton.next_power_of_2(K), 64).

Returns:
(n_slots_per_rank, d_model) DP-local combined output.
Expand All @@ -228,9 +263,21 @@ def wg_fused_exp_matmul_ep_to_dp(
K = d_model
N = d_model

BLOCK_M = 128
BLOCK_N = min(triton.next_power_of_2(N), 128)
BLOCK_K = min(triton.next_power_of_2(K), 64)
device = x_ep_local.device
num_sms = torch.cuda.get_device_properties(device).multi_processor_count

# Derive heuristic defaults for gemm_sms / block_m when not specified.
if gemm_sms is None or block_m is None:
avg_bpe = n_slots_per_rank // max(n_local_experts, 1)
h_gemm_sms, h_block_m = _heuristic_wg_config(num_sms, avg_bpe)
if gemm_sms is None:
gemm_sms = h_gemm_sms
if block_m is None:
block_m = h_block_m

BLOCK_M = block_m
BLOCK_N = block_n if block_n is not None else min(triton.next_power_of_2(N), 128)
BLOCK_K = block_k if block_k is not None else min(triton.next_power_of_2(K), 64)

max_slice_size = int(ragged_metadata.slice_sizes.max().item())
max_m_tiles = triton.cdiv(max_slice_size, BLOCK_M)
Expand All @@ -242,12 +289,6 @@ def wg_fused_exp_matmul_ep_to_dp(
shmem.barrier()
return dst_local

device = x_ep_local.device
cu_count = torch.cuda.get_device_properties(device).multi_processor_count
num_sms = cu_count
if gemm_sms is None:
gemm_sms = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

y_buf = torch.zeros((n_total_slots, N), dtype=x_ep_local.dtype, device=device)
dst_local = shmem.zeros((n_slots_per_rank, d_model), dtype=x_ep_local.dtype)
n_locks = n_n_tiles * n_local_experts * max_m_tiles
Expand Down
Loading