Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
e171679
sample kernel and test
panditsa Aug 25, 2025
07cbe57
shmem exp
panditsa Aug 26, 2025
a4c908c
turn off min. shared allocs - temp
panditsa Sep 2, 2025
b465b07
updated moe block align kernel
panditsa Sep 2, 2025
1de9d0a
fixed test and kernel
panditsa Sep 2, 2025
f3d624d
padded counts calculation
panditsa Sep 2, 2025
5a552b6
update moe to use cumsum
panditsa Sep 8, 2025
b904db6
create inclusive and exclusive cumsum buffers
panditsa Sep 10, 2025
dc61d93
cleanup
panditsa Sep 10, 2025
a62e13f
[HACK] node in custom.start, not in custom.implicit_captures
panditsa Sep 12, 2025
de66aa3
[Working] expert_ids calculated
panditsa Sep 12, 2025
6891ba5
working sorted_token_ids
panditsa Sep 15, 2025
0f042bb
examples-simple
panditsa Sep 15, 2025
57a997e
exp. with scalar write
panditsa Sep 15, 2025
7799b57
cleanup
panditsa Sep 23, 2025
bcdc247
parametrize block_sizes
panditsa Sep 23, 2025
3acb909
first draft
panditsa Sep 23, 2025
9868dd5
reference code, simplified from triton implementation
panditsa Sep 24, 2025
2b556f4
some simple gemm exps
panditsa Sep 24, 2025
d8f269c
indexed weight GEMM example
panditsa Sep 25, 2025
11c4ca2
remove unnecessary flags
panditsa Sep 25, 2025
03f7521
standalone gemm.py test that dynamically creates A
panditsa Sep 28, 2025
9ee890f
reorder A for test
panditsa Sep 30, 2025
4672e6b
WIP
panditsa Sep 30, 2025
a36bd2c
basic scatter working
panditsa Oct 1, 2025
7bb9e22
simple scatter with gemm working, gemm broken
panditsa Oct 1, 2025
cba033e
removed one condition
panditsa Oct 1, 2025
a261e01
scatter A gemm working
panditsa Oct 1, 2025
fbfa869
fixed expert based scatter gemm test
panditsa Oct 1, 2025
e392d07
reorder test based on complexity
panditsa Oct 1, 2025
30d0ec7
hackfixme
panditsa Oct 1, 2025
92bbdfc
scatter gemm with padding value
panditsa Oct 1, 2025
3a9c7bc
remove scatter_gemm.mlir file
panditsa Oct 1, 2025
6dd4a38
working scatter-gather gemm for one expert
panditsa Oct 2, 2025
2185d36
hackfixme
panditsa Oct 3, 2025
2f537bf
moe gemm example
panditsa Oct 3, 2025
b6df91a
use moe_gemm in moe.py and test
panditsa Oct 6, 2025
715d736
silu_and_mul check
panditsa Oct 7, 2025
ce8f1b1
add second gemm and comments
panditsa Oct 7, 2025
1216e4f
add a reduce_sum kernel
panditsa Oct 7, 2025
f852e4f
use the sum kernel
panditsa Oct 7, 2025
16f1860
in reference calculate topk ids
panditsa Oct 7, 2025
f87b07a
silu and mul fixes and test
panditsa Oct 7, 2025
ec7a2a9
working moe
panditsa Oct 7, 2025
b4cb085
update test.py
panditsa Oct 7, 2025
bddec3a
placeholder has no index, get_custom first
panditsa Oct 8, 2025
98a3d3f
update all gemm examples
panditsa Oct 8, 2025
77ddb43
fix block align
panditsa Oct 8, 2025
57cf100
use wave moe_align_block_size kernel
panditsa Oct 8, 2025
06d88f2
cleanup
panditsa Oct 8, 2025
7f55706
wip
panditsa Oct 8, 2025
3b0beb1
non-fast dims work? reduction
panditsa Oct 9, 2025
e8241da
passing test
panditsa Oct 9, 2025
6df684c
update the reduce kernel to add broadcast
panditsa Oct 9, 2025
524bdbc
working moe
panditsa Oct 9, 2025
99cc679
final cleanup
panditsa Oct 9, 2025
d4bfb5b
use wave topk
panditsa Oct 16, 2025
e78bd99
WIP, large histogram
panditsa Oct 20, 2025
ced6c17
set the index only for specific dimension
panditsa Oct 22, 2025
347d083
fused gemm example
panditsa Oct 22, 2025
fb102c7
silu_and_mul update MILESTONE
panditsa Oct 22, 2025
4e472fb
tensor ops
panditsa Oct 27, 2025
f8c10fd
examples large histogram
panditsa Oct 27, 2025
7f45f33
examples gemm-gemm
panditsa Oct 27, 2025
b20f689
fix type propagation issues
panditsa Nov 13, 2025
eea02d8
block scan changes still required
panditsa Nov 13, 2025
fb97962
MoE working with PyTorch code chunks for gather, scatter, and routing…
nirmie Mar 3, 2026
ef77c21
Merge branch 'MoE-clean' into MoE
nirmie Mar 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,776 changes: 1,776 additions & 0 deletions examples/gemm.py

Large diffs are not rendered by default.

189 changes: 189 additions & 0 deletions examples/python/3_atomics.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,195 @@ def wave_kernel(
print(c)


def test_histogram(is_debug=False):
    """Atomic add operation to a histogram using dynamic mapping.

    Each thread reads the expert id of one token from ``topk_ids`` (indexed
    by its thread id via a dynamic-value mapping) and atomically increments
    that expert's bucket in a shared-memory histogram, which is then copied
    to the ``experts`` output buffer. The result is verified against
    ``torch.bincount``.
    """
    NUM_EXPERTS = tkl.sym.NUM_EXPERTS

    # A single wave/workgroup covers all M tokens (dim 0) and all
    # NUM_EXPERTS buckets (dim 1).
    constraints: list[tkw.Constraint] = [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(1, 1, 1),
            vector_shapes={M: M, NUM_EXPERTS: NUM_EXPERTS},
        )
    ]
    constraints += [tkw.WorkgroupConstraint(M, M, 0)]
    constraints += [tkw.WorkgroupConstraint(NUM_EXPERTS, NUM_EXPERTS, 1)]
    constraints += [tkw.WaveConstraint(M, M)]
    constraints += [tkw.WaveConstraint(NUM_EXPERTS, NUM_EXPERTS)]

    i = tkw.IndexMapping.iterator(0)
    d0 = tkw.IndexMapping.dynamic_val(0)

    # Gather mapping: read topk_ids[d0], where d0 is supplied per-read
    # (the thread id below).
    topk_read_map = tkw.IndexMapping(
        num_iterators=1,
        inputs={M: d0},
        outputs={M: i},
        dynamic_val_mappings={M: i},
    )

    # Scatter mapping: the histogram bucket is selected by the expert id
    # read via topk_read_map.
    expert_read_map = tkw.IndexMapping(
        num_iterators=1,
        inputs={NUM_EXPERTS: d0},
        outputs={NUM_EXPERTS: i},
        dynamic_val_mappings={NUM_EXPERTS: i},
    )

    @tkw.wave(constraints)
    def histogram_atomic_add(
        topk_ids: tkl.Memory[M, ADDRESS_SPACE, tkl.i32],
        experts: tkl.Memory[NUM_EXPERTS, ADDRESS_SPACE, tkl.i32],
    ):
        one_reg = tkw.Register[NUM_EXPERTS, tkl.i32](1)
        tid = tkw.scalar(THREAD_0, tkl.i32)

        # Zero-initialize the shared-memory histogram before accumulating.
        zero_vec = tkl.Register[NUM_EXPERTS, tkl.i32](0)
        shmem = tkw.allocate(
            shape=(NUM_EXPERTS,),
            distributed_shape=(NUM_EXPERTS,),
            dtype=tkl.i32,
        )
        tkw.write(zero_vec, shmem)

        # Each thread reads the expert id of the token at its thread index.
        expert_id = tkw.read(
            topk_ids,
            mapping=topk_read_map,
            mapping_dynamic_vals=(tid,),
            elements_per_thread=1,
        )

        # Atomically bump that expert's bucket in shared memory.
        tkw.atomic_add(
            one_reg,
            shmem,
            mapping=expert_read_map,
            mapping_dynamic_vals=(expert_id,),
            elements_per_thread=1,
        )

        # Copy the finished histogram back out to global memory.
        tmp = tkw.read(shmem)
        tkw.write(tmp, experts)

    num_experts = 10
    num_tokens = 64
    hyperparams = {
        M: num_tokens,
        NUM_EXPERTS: num_experts,
    }
    options = WaveCompileOptions(
        subs=hyperparams,
        canonicalize=True,
        minimize_shared_allocs=False,
        print_ir_after="all" if is_debug else [],
    )
    histogram_atomic_add = wave_compile(options, histogram_atomic_add)
    if is_debug:
        print(histogram_atomic_add.asm)

    topk_ids = torch.randint(0, num_experts, (num_tokens,), dtype=torch.int32).cuda()
    experts = torch.zeros((num_experts,), dtype=torch.int32).cuda()
    histogram_atomic_add(topk_ids, experts)
    expected = torch.bincount(topk_ids, minlength=num_experts).to(torch.int32)
    print("topk_ids: ", topk_ids)
    print("experts: ", experts)
    print("expected experts: ", expected)
    # Previously this test only printed the results; verify them explicitly.
    assert torch.equal(
        experts, expected
    ), f"Histogram mismatch:\n{experts}\nvs expected\n{expected}"


def test_large_histogram(is_debug=False):
    """Atomic add operation to a histogram using dynamic mapping.

    Unlike ``test_histogram``, the kernel walks the token range in an
    ``iterate`` loop (tracked by the TOKEN_OFFSET symbol), with each thread
    handling one token per iteration, and accumulates expert counts with
    atomic adds directly into the global ``experts`` buffer. The result is
    verified against ``torch.bincount``.
    """
    NUM_EXPERTS = tkl.sym.NUM_EXPERTS
    TOKEN_OFFSET = tkl.sym.TOKEN_OFFSET
    constraints: list[tkw.Constraint] = []
    constraints += [tkw.WorkgroupConstraint(M, M, 0)]
    constraints += [tkw.WorkgroupConstraint(NUM_EXPERTS, NUM_EXPERTS, 1)]
    constraints += [tkw.WaveConstraint(M, M)]
    constraints += [tkw.WaveConstraint(NUM_EXPERTS, NUM_EXPERTS)]

    # TOKEN_OFFSET is the loop induction symbol advanced via set_symbol.
    constraints += [tkw.TilingConstraint(TOKEN_OFFSET)]

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(1, 1, 1),
            vector_shapes={M: M, NUM_EXPERTS: NUM_EXPERTS, TOKEN_OFFSET: 0},
        )
    ]

    i = tkw.IndexMapping.iterator(0)
    d0 = tkw.IndexMapping.dynamic_val(0)

    # Gather mapping: read topk_ids[d0] with d0 = per-iteration token index.
    topk_read_map = tkw.IndexMapping(
        num_iterators=1,
        inputs={M: d0},
        outputs={M: i},
        dynamic_val_mappings={M: i},
    )

    # Scatter mapping: histogram bucket selected by the expert id read above.
    expert_read_map = tkw.IndexMapping(
        num_iterators=1,
        inputs={NUM_EXPERTS: d0},
        outputs={NUM_EXPERTS: i},
        dynamic_val_mappings={NUM_EXPERTS: i},
    )

    @tkw.wave(constraints)
    def histogram_atomic_add(
        topk_ids: tkl.Memory[M, ADDRESS_SPACE, tkl.i32],
        experts: tkl.Memory[NUM_EXPERTS, ADDRESS_SPACE, tkl.i32],
    ):
        one_reg = tkw.Register[NUM_EXPERTS, tkl.i32](1)
        zero_reg = tkw.Register[TOKEN_OFFSET, tkl.i32](0)

        loop_condition = TOKEN_OFFSET < M

        @tkw.iterate(
            TOKEN_OFFSET, start=zero_reg, condition=loop_condition, init_args=[]
        )
        def count_tokens():
            # token index for this thread in this iteration:
            # TOKEN_OFFSET * 64 + thread id. NOTE(review): the literal 64
            # presumably must match threads_per_wave — confirm if either
            # changes.
            token_idx = tkw.self_index(TOKEN_OFFSET, tkl.i32)
            tid_reg = tkw.Register[TOKEN_OFFSET, tkl.i32](THREAD_0)
            token_idx = token_idx * tkl.Register[TOKEN_OFFSET, tkl.i32](64) + tid_reg

            expert_id = tkw.read(
                topk_ids,
                mapping=topk_read_map,
                mapping_dynamic_vals=(token_idx,),
                elements_per_thread=1,
            )

            # Accumulate directly into the global histogram.
            tkw.atomic_add(
                one_reg,
                experts,
                mapping=expert_read_map,
                mapping_dynamic_vals=(expert_id,),
                elements_per_thread=1,
            )

            # Advance the induction symbol by one chunk of 64 tokens.
            next_token_idx = token_idx + tkl.Register[TOKEN_OFFSET, tkl.i32](64)
            tkw.set_symbol(TOKEN_OFFSET, next_token_idx)

    num_experts = 10
    num_tokens = 64
    hyperparams = {
        M: num_tokens,
        NUM_EXPERTS: num_experts,
    }
    options = WaveCompileOptions(
        subs=hyperparams,
        canonicalize=True,
        minimize_shared_allocs=False,
        print_ir_after="all" if is_debug else [],
    )
    histogram_atomic_add = wave_compile(options, histogram_atomic_add)
    if is_debug:
        print(histogram_atomic_add.asm)

    topk_ids = torch.randint(0, num_experts, (num_tokens,), dtype=torch.int32).cuda()
    experts = torch.zeros((num_experts,), dtype=torch.int32).cuda()

    histogram_atomic_add(topk_ids, experts)
    expected = torch.bincount(topk_ids, minlength=num_experts).to(torch.int32)
    print("topk_ids: ", topk_ids)
    print("experts: ", experts)
    print("expected experts: ", expected)
    # Previously this test only printed the results; verify them explicitly.
    assert torch.equal(
        experts, expected
    ), f"Histogram mismatch:\n{experts}\nvs expected\n{expected}"


if __name__ == "__main__":
args = parse_args()
if args.list_tests:
Expand Down
143 changes: 142 additions & 1 deletion examples/python/5_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import torch
import argparse

import wave_lang.kernel.wave as tkw
import wave_lang.kernel.lang as tkl
from wave_lang.kernel._support.dtype import f16, f32, i32
from wave_lang.kernel._support.indexing import sym
from wave_lang.kernel.lang.global_symbols import *
Expand Down Expand Up @@ -1621,6 +1621,147 @@ def then():
print("GEMM test passed!")


def fused_gemms(is_debug=False):
    """Fused GEMM kernel where we run two GEMMs back to back.

    Computes C = (A @ W1^T) @ W2^T in a single kernel: the first GEMM's
    result is staged in a kernel-allocated buffer, cast to f16, and fed as
    the A operand of the second GEMM. Verified against two torch.matmul
    calls.
    """
    N1 = sym.N1
    N2 = sym.N2
    BLOCK_N1 = sym.BLOCK_N1
    BLOCK_N2 = sym.BLOCK_N2

    # Define constraints for the kernel
    constraints = [
        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
        tkw.WorkgroupConstraint(N2, BLOCK_N2, 1),
        tkw.WaveConstraint(M, BLOCK_M / 2),
        tkw.WaveConstraint(N2, BLOCK_N2 / 2),
        tkw.TilingConstraint(K, BLOCK_K),
        tkw.TilingConstraint(N1, BLOCK_N1),
        tkw.HardwareConstraint(
            threads_per_wave=64,
            mma_type=tkw.MMAType.F32_16x16x16_F16,
            vector_shapes={M: 16, N1: 16, N2: 16, K: 16},
        ),
    ]

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)
    k = tkw.IndexMapping.iterator(2)
    d0 = tkw.IndexMapping.dynamic_val(0)

    a_read_map = tkw.IndexMapping(
        num_iterators=2,
        inputs={M: d0, K: j},
        outputs={M: i, K: j},
        dynamic_val_mappings={M: i},
    )

    w1_read_map = tkw.IndexMapping(
        num_iterators=2,
        inputs={N1: i, K: j},
        outputs={N1: i, K: j},
    )

    w2_read_map = tkw.IndexMapping(
        num_iterators=2,
        inputs={N2: i, N1: j},
        outputs={N2: i, N1: j},
    )

    @tkw.wave(constraints)
    def gemm(
        a: Memory[M, K, ADDRESS_SPACE_A, f16],  # Input matrix A
        w1: Memory[N1, K, ADDRESS_SPACE_B, f16],  # First weight matrix W1
        w2: Memory[N2, N1, ADDRESS_SPACE_B, f16],  # Second weight matrix W2
        c: Memory[M, N2, ADDRESS_SPACE_C, f32],  # Output matrix C
    ):
        # Accumulators for each of the two GEMMs.
        c_reg1 = Register[M, N1, f32](0.0)
        c_reg2 = Register[M, N2, f32](0.0)

        # Scratch buffer holding the first GEMM's result (A @ W1^T).
        c_back1 = tkw.allocate(
            shape=(M, N1),
            distributed_shape=(M, N1),
            dtype=tkl.f32,
        )

        # First GEMM: iterate over K to compute A @ W1^T.
        @tkw.iterate(K, init_args=[c_reg1])
        def repeat1(acc: Register[M, N1, f32]) -> Register[M, N1, f32]:
            # Load elements from A and W1
            a_reg = tkw.read(a)
            w1_reg = tkw.read(w1)
            acc = tkw.mma(a_reg, w1_reg, acc)
            return acc

        # Stage the intermediate result for the second GEMM.
        tkw.write(repeat1, c_back1)

        # Second GEMM: iterate over N1 to compute (A @ W1^T) @ W2^T.
        @tkw.iterate(N1, init_args=[c_reg2])
        def repeat2(acc: Register[M, N2, f32]) -> Register[M, N2, f32]:
            # Load the staged intermediate (cast to f16 for the MMA) and W2.
            a_reg = tkw.read(c_back1)
            a_reg = tkw.cast(a_reg, f16)
            w2_reg = tkw.read(w2)
            acc = tkw.mma(a_reg, w2_reg, acc)
            return acc

        # Store the final result to C
        tkw.write(repeat2, c)

    # Create test matrices
    m, k = 64, 64  # Small dimensions for testing
    n1, n2 = 64, 64
    # Initialize input matrices with random values
    torch.manual_seed(0)
    a = torch.randn(m, k, dtype=torch.float16, device="cuda")
    w1 = torch.randn(n1, k, dtype=torch.float16, device="cuda")
    w2 = torch.randn(n2, n1, dtype=torch.float16, device="cuda")
    c = torch.zeros(m, n2, dtype=torch.float32, device="cuda")
    # NOTE: no host-side intermediate buffer is needed — the kernel
    # allocates its own staging buffer for the first GEMM's result.

    # Set hyperparameters for compilation
    hyperparams = {
        ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE,
        ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE,
        ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE,
        BLOCK_M: 64,
        BLOCK_N1: 64,
        BLOCK_N2: 64,
        BLOCK_K: 32,
        M: m,
        N1: n1,
        N2: n2,
        K: k,
    }

    # Compile the kernel
    options = WaveCompileOptions(
        subs=hyperparams,
        print_ir_after="all" if is_debug else [],
    )
    options = set_default_run_config(options)
    compiled_gemm = wave_compile(options, gemm)

    if is_debug:
        print(compiled_gemm.asm)
        with open("gemm.mlir", "w") as f:
            f.write(compiled_gemm.asm)

    # Run the GEMM kernel
    compiled_gemm(a, w1, w2, c)

    # Verify the result using PyTorch's matmul
    expected = torch.matmul(a, w1.t())
    expected = torch.matmul(expected, w2.t())

    # Check if results are close (accounting for floating-point precision)
    assert torch.allclose(
        c.to(torch.float16), expected, rtol=1e-2, atol=1e-2
    ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}"

    print("GEMM test passed!")


if __name__ == "__main__":
args = parse_args()
if args.list_tests:
Expand Down
Loading
Loading