diff --git a/examples/gemm.py b/examples/gemm.py
new file mode 100644
index 0000000000..7246be2465
--- /dev/null
+++ b/examples/gemm.py
@@ -0,0 +1,1777 @@
+import torch
+import argparse
+import sympy  # used by downcast_gemm_test (sympy.Integer); don't rely on star-imports to provide it
+
+import wave_lang.kernel.wave as tkw
+from wave_lang.kernel._support.dtype import f16, f32, i32
+from wave_lang.kernel._support.indexing import sym
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.lang.wave_types import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+# Define symbolic dimensions for our matrices
+M = sym.M  # Rows of A and C
+N = sym.N  # Rows of B and columns of C
+K = sym.K  # Columns of A and B
+
+# Define workgroup tile sizes
+BLOCK_M = sym.BLOCK_M
+BLOCK_N = sym.BLOCK_N
+BLOCK_K = sym.BLOCK_K
+
+# Define the address space for our memory buffers
+ADDRESS_SPACE_A = sym.ADDRESS_SPACE_A
+ADDRESS_SPACE_B = sym.ADDRESS_SPACE_B
+ADDRESS_SPACE_C = sym.ADDRESS_SPACE_C
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    # one of the tests or list_tests is required (not enforced here; presumably validated by the caller — TODO confirm)
+    parser.add_argument("--test", type=str, required=False)
+    parser.add_argument("--list_tests", action="store_true")
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--repeat", type=int, default=1)
+    return parser.parse_args()
+
+
+def list_tests():
+    # find all the functions in the file that end with _test
+    tests = [f for f in globals() if f.endswith("_test")]
+    print("Available tests:")
+    for test in tests:
+        print(f"  {test}")
+
+
+def simple_gemm_test():
+    # Define constraints for the kernel
+    constraints = [
+        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
+        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
+        tkw.TilingConstraint(K, BLOCK_K),
+        tkw.WaveConstraint(M, BLOCK_M / 2),
+        tkw.WaveConstraint(N, BLOCK_N / 2),
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            mma_type=tkw.MMAType.F32_16x16x16_F16,
+            vector_shapes={M: 16, N: 16, K: 16},
+        ),
+    ]
+
+    
@tkw.wave(constraints)
+    def gemm(
+        a: Memory[M, K, ADDRESS_SPACE_A, f16],  # Input matrix A
+        b: Memory[N, K, ADDRESS_SPACE_B, f16],  # Input matrix B, stored as (N, K) — the kernel computes A @ B^T
+        c: Memory[M, N, ADDRESS_SPACE_C, f32],  # Output matrix C
+    ):
+        # Initialize the accumulator register with zeros
+        c_reg = Register[M, N, f32](0.0)
+
+        # Iterate over the K dimension to compute the dot product
+        @tkw.iterate(K, init_args=[c_reg])
+        def repeat(acc: Register[M, N, f32]) -> Register[M, N, f32]:
+            # Load elements from A and B
+            a_reg = tkw.read(a)
+            b_reg = tkw.read(b)
+
+            # Compute matrix multiplication and accumulate
+            acc = tkw.mma(a_reg, b_reg, acc)
+            return acc
+
+        # Store the final result to C
+        tkw.write(repeat, c)
+
+    # Create test matrices
+    m, n, k = 64, 64, 64  # Small dimensions for testing
+
+    # Initialize input matrices with random values
+    torch.manual_seed(0)
+    a = torch.randn(m, k, dtype=torch.float16, device="cuda")
+    b = torch.randn(n, k, dtype=torch.float16, device="cuda")
+    c = torch.zeros(m, n, dtype=torch.float32, device="cuda")
+
+    # Set hyperparameters for compilation
+    hyperparams = {
+        ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE,
+        ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE,
+        ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE,
+        BLOCK_M: 64,
+        BLOCK_N: 64,
+        BLOCK_K: 32,
+        M: m,
+        N: n,
+        K: k,
+    }
+
+    # Compile the kernel
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        print_ir_after="all",
+        print_ir_before="all",
+    )
+    options = set_default_run_config(options)
+    compiled_gemm = wave_compile(options, gemm)
+
+    # Run the GEMM kernel
+    compiled_gemm(a, b, c)
+
+    with open("gemm.mlir", "w") as f:
+        f.write(compiled_gemm.asm)
+
+    # Verify the result using PyTorch's matmul
+    expected = torch.matmul(a, b.t())  # A @ B^T, matching the (N, K) layout of b
+
+    # Check if results are close (accounting for floating-point precision)
+    assert torch.allclose(
+        c.to(torch.float16), expected, rtol=1e-2, atol=1e-2
+    ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}"
+
+    print("GEMM test passed!")
+
+
+def 
downcast_gemm_test(is_debug=False): + E = sym.E + # Define constraints for the kernel + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.WaveConstraint(N, BLOCK_N / 2), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=tkw.MMAType.F32_16x16x16_F16, + vector_shapes={E: E}, + ), + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={E: sympy.Integer(1), N: i, K: j}, + outputs={N: i, K: j}, + ) + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + b: Memory[E, N, K, ADDRESS_SPACE_B, f16], # Input matrix B + c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C + ): + # Initialize the accumulator register with zeros + c_reg = Register[M, N, f32](0.0) + + # Iterate over the K dimension to compute the dot product + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: Register[M, N, f32]) -> Register[M, N, f32]: + # Load elements from A and B + a_reg = tkw.read(a) + b_reg = tkw.read(b, mapping=mapping) + + # Compute matrix multiplication and accumulate + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # Store the final result to C + tkw.write(repeat, c) + + # Create test matrices + m, n, k = 64, 64, 128 # Small dimensions for testing + e = 8 + + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + b = torch.randn(e, n, k, dtype=torch.float16, device="cuda") + c = torch.zeros(m, n, dtype=torch.float32, device="cuda") + + # Set hyperparameters for compilation + hyperparams = { + ADDRESS_SPACE_A: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_B: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: m, + N: n, + K: k, + E: e, + } + + # Compile the kernel + options = 
WaveCompileOptions( + subs=hyperparams, + ) + options = set_default_run_config(options) + compiled_gemm = wave_compile(options, gemm) + + # Run the GEMM kernel + compiled_gemm(a, b, c) + + # Verify the result using PyTorch's matmul + expected = torch.matmul(a, b[1].t()) + + # Check if results are close (accounting for floating-point precision) + assert torch.allclose( + c.to(torch.float16), expected, rtol=1e-2, atol=1e-2 + ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}" + + print(compiled_gemm.asm) + print("GEMM test passed!") + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + e = tkw.IndexMapping.iterator(2) + d0 = tkw.IndexMapping.dynamic_val(0) + + IDX = sym.IDX + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={E: IDX, N: i, K: j}, + outputs={N: i, K: j}, + ) + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + b: Memory[E, N, K, ADDRESS_SPACE_B, f16], # Input matrix B + idx: i32, + c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C + ): + # Initialize the accumulator register with zeros + c_reg = Register[M, N, f32](0.0) + tkw.set_symbol(IDX, idx) + + # Iterate over the K dimension to compute the dot product + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: Register[M, N, f32]) -> Register[M, N, f32]: + # Load elements from A and B + a_reg = tkw.read(a) + b_reg = tkw.read(b, mapping=mapping) + + # Compute matrix multiplication and accumulate + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # Store the final result to C + tkw.write(repeat, c) + + # Create test matrices + m, n, k = 64, 64, 128 # Small dimensions for testing + e = 8 + + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + b = torch.randn(e, n, k, dtype=torch.float16, device="cuda") + c = torch.zeros(m, n, dtype=torch.float32, device="cuda") + + # Set hyperparameters for compilation + 
hyperparams = { + ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: m, + N: n, + K: k, + E: e, + } + + # Compile the kernel + options = WaveCompileOptions( + subs=hyperparams, + ) + options = set_default_run_config(options) + compiled_gemm = wave_compile(options, gemm) + + # Run the GEMM kernel + compiled_gemm(a, b, 1, c) + + # Verify the result using PyTorch's matmul + expected = torch.matmul(a, b[1].t()) + + # Check if results are close (accounting for floating-point precision) + assert torch.allclose( + c.to(torch.float16), expected, rtol=1e-2, atol=1e-2 + ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}" + + print("GEMM test passed!") + + +def dyn_downcast_gemm_test(is_debug=False): + E = sym.E + # Define constraints for the kernel + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.WorkgroupConstraint(E, E, 2), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.WaveConstraint(N, BLOCK_N / 2), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=tkw.MMAType.F32_16x16x16_F16, + vector_shapes={E: E}, + ), + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + e = tkw.IndexMapping.iterator(2) + d0 = tkw.IndexMapping.dynamic_val(0) + + IDX = sym.IDX + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={E: IDX, N: i, K: j}, + outputs={N: i, K: j}, + ) + + a_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, K: j}, + outputs={M: i, K: j}, + ) + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + b: Memory[E, N, K, ADDRESS_SPACE_B, f16], # Input matrix B + idx: i32, + c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C + ): + # Initialize the accumulator register with zeros + c_reg = Register[M, N, f32](0.0) + 
tkw.set_symbol(IDX, idx) + + # Iterate over the K dimension to compute the dot product + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: Register[M, N, f32]) -> Register[M, N, f32]: + # Load elements from A and B + a_reg = tkw.read(a, mapping=a_read_map) + b_reg = tkw.read(b, mapping=mapping) + + # Compute matrix multiplication and accumulate + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # Store the final result to C + tkw.write(repeat, c) + + # Create test matrices + m, n, k = 64, 64, 128 # Small dimensions for testing + e = 8 + + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + b = torch.randn(e, n, k, dtype=torch.float16, device="cuda") + c = torch.zeros(m, n, dtype=torch.float32, device="cuda") + + # Set hyperparameters for compilation + hyperparams = { + ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: m, + N: n, + K: k, + E: e, + } + + # Compile the kernel + options = WaveCompileOptions( + subs=hyperparams, + print_ir_after="all" if is_debug else [], + print_ir_before="all" if is_debug else [], + ) + options = set_default_run_config(options) + compiled_gemm = wave_compile(options, gemm) + + # Run the GEMM kernel + compiled_gemm(a, b, 1, c) + if is_debug: + print(compiled_gemm.asm) + + # Verify the result using PyTorch's matmul + expected = torch.matmul(a, b[1].t()) + + # Check if results are close (accounting for floating-point precision) + assert torch.allclose( + c.to(torch.float16), expected, rtol=1e-2, atol=1e-2 + ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}" + + print("GEMM test passed!") + + +def reorder_a_gemm_test(is_debug=False): + E = sym.E + # Define constraints for the kernel + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + 
tkw.WorkgroupConstraint(E, E, 2), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.WaveConstraint(N, BLOCK_N / 2), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=tkw.MMAType.F32_16x16x16_F16, + vector_shapes={E: E, M: 16, K: 16}, + ), + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + e = tkw.IndexMapping.iterator(2) + d0 = tkw.IndexMapping.dynamic_val(0) + + IDX = sym.IDX + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={E: IDX, N: i, K: j}, + outputs={N: i, K: j}, + ) + + a_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: d0, K: j}, + outputs={M: i, K: j}, + dynamic_val_mappings={M: i}, + ) + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + b: Memory[E, N, K, ADDRESS_SPACE_B, f16], # Input matrix B + reorder_a: Memory[M, ADDRESS_SPACE_A, i32], # Input matrix A + a_back: Memory[M, K, ADDRESS_SPACE_A, f16], # Output matrix A + c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C + idx: i32, + ): + # Initialize the accumulator register with zeros + c_reg = Register[M, N, f32](0.0) + a_reg = Register[M, K, f16](0.0) + tkw.set_symbol(IDX, idx) + + # Iterate over the K dimension to compute the dot product + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: Register[M, N, f32]) -> Register[M, N, f32]: + # Load elements from A and B + reordered_idx = tkw.read(reorder_a, elements_per_thread=1) + a_reg = tkw.read( + a, mapping=a_read_map, mapping_dynamic_vals=(reordered_idx,) + ) + b_reg = tkw.read(b, mapping=mapping) + + # Compute matrix multiplication and accumulate + acc = tkw.mma(a_reg, b_reg, acc) + + tkw.write(a_reg, a_back) + return acc + + # Store the final result to C + tkw.write(repeat, c) + + # Create test matrices + m, n, k = 64, 64, 128 # Small dimensions for testing + e = 8 + + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + a_back 
= torch.zeros(m, k, dtype=torch.float16, device="cuda") + b = torch.randn(e, n, k, dtype=torch.float16, device="cuda") + c = torch.zeros(m, n, dtype=torch.float32, device="cuda") + + # create reorder_a such that it is a permutation of the rows of a + reorder_a = torch.randperm(m).to(torch.int32).to(device="cuda") + # Set hyperparameters for compilation + hyperparams = { + ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: m, + N: n, + K: k, + E: e, + } + + # Compile the kernel + options = WaveCompileOptions( + subs=hyperparams, + print_ir_after="all" if is_debug else [], + print_ir_before="all" if is_debug else [], + ) + options = set_default_run_config(options) + compiled_gemm = wave_compile(options, gemm) + print(compiled_gemm.asm) + + # Run the GEMM kernel + compiled_gemm(a, b, reorder_a, a_back, c, 1) + reordered_a = a[reorder_a] + + print("Reorder idx: ", reorder_a) + print("A back: ", a_back[0]) + print("A: ", a[reorder_a[0]]) + print("Reordered A: ", reordered_a[0]) + + assert torch.allclose( + a_back, reordered_a, rtol=1e-2, atol=1e-2 + ), f"A back doesn't match expected output\nMax difference: {(a_back - reordered_a).abs().max()}" + + # Verify the result using PyTorch's matmul + expected = torch.matmul(reordered_a, b[1].t()) + + # Check if results are close (accounting for floating-point precision) + assert torch.allclose( + c.to(torch.float16), expected, rtol=1e-2, atol=1e-2 + ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}" + + print("GEMM test passed!") + + +def scatter_a_test(is_debug=False): + M_DIV_2 = sym.M_DIV_2 + I = sym.I + # Define constraints for the kernel + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.HardwareConstraint( + threads_per_wave=64, + vector_shapes={M_DIV_2: M_DIV_2, M: M, K: 
BLOCK_K}, + ), + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + d0 = tkw.IndexMapping.dynamic_val(0) + + a_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: d0, K: j}, + outputs={M: i, K: j}, + dynamic_val_mappings={M: i}, + ) + + a_write_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, K: j}, + outputs={M: i, K: j}, + ) + + dyn_reorder_a_read_map = tkw.IndexMapping( + num_iterators=1, + inputs={M_DIV_2: d0}, + outputs={M_DIV_2: i}, + dynamic_val_mappings={M_DIV_2: i}, + ) + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + reorder_a: Memory[M_DIV_2, ADDRESS_SPACE_A, i32], # Input matrix A + a_back: Memory[M, K, ADDRESS_SPACE_A, f16], # Output matrix A + ): + # Initialize the accumulator register with zeros + @tkw.conditional(THREAD_0 < M_DIV_2) + def scatter_op(): + tid = tkw.scalar(THREAD_0, i32) + reordered_idx = tkw.read( + reorder_a, + mapping=dyn_reorder_a_read_map, + mapping_dynamic_vals=(tid,), + ) + + @tkw.iterate(K, init_args=[]) + def copy_row(): + a_row_data = tkw.read( + a, + mapping=a_read_map, + mapping_dynamic_vals=(reordered_idx,), + elements_per_thread=BLOCK_K, + ) + tkw.write( + a_row_data, + a_back, + mapping=a_write_map, + elements_per_thread=BLOCK_K, + ) + + # Create test matrices + m, k = 64, 128 # Small dimensions for testing + + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + a_back = torch.zeros(m, k, dtype=torch.float16, device="cuda") + + # Set hyperparameters for compilation + hyperparams = { + ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, + BLOCK_M: 64, + BLOCK_K: 32, + M: m, + K: k, + M_DIV_2: m // 2, + } + + # Compile the kernel + if is_debug: + options = WaveCompileOptions( + subs=hyperparams, + print_ir_after="all", + print_ir_before="all", + ) + else: + options = WaveCompileOptions( + subs=hyperparams, + ) + + options = set_default_run_config(options) + 
compiled_gemm = wave_compile(options, gemm)
+
+    if is_debug:
+        with open("scatter_a.mlir", "w") as f:
+            f.write(compiled_gemm.asm)
+
+    # create reorder_a such that it is a permutation of the rows of a
+    reorder_a = torch.randperm(m // 2).to(torch.int32).to(device="cuda")
+    reorder_a_clone = reorder_a.clone().to(device="cuda")
+    compiled_gemm(a, reorder_a_clone, a_back)
+    reordered_a = torch.zeros((m, k), dtype=torch.float16).to(device="cuda")
+
+    # read rows of a in reorder_a order
+    for i in range(m // 2):
+        reordered_a[i] = a[reorder_a[i]]
+
+    print("Reorder idx: ", reorder_a)
+    print("A back: ", a_back[0])
+    print("A: ", a[reorder_a[0]])
+    print("Reordered A: ", reordered_a[0])
+
+    print("scatter_a test passed!")  # NOTE(review): nothing is asserted above — a_back is computed but never compared against reordered_a, so this "passes" unconditionally
+
+
+def scatter_a_simple_gemm_test(is_debug=False):
+    M_DIV_2 = sym.M_DIV_2
+    SHARED_MEM = sym.SHARED_MEM
+    I = sym.I
+    # Define constraints for the kernel
+    constraints = [
+        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
+        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
+        tkw.TilingConstraint(K, BLOCK_K),
+        tkw.TilingConstraint(I),
+        tkw.WaveConstraint(M, BLOCK_M / 2),
+        tkw.WaveConstraint(N, BLOCK_N / 2),
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            mma_type=tkw.MMAType.F32_16x16x16_F16,
+            vector_shapes={M_DIV_2: M_DIV_2, M: 16, N: 16, K: 16, I: 0},
+        ),
+    ]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    k = tkw.IndexMapping.iterator(2)
+    d0 = tkw.IndexMapping.dynamic_val(0)
+    d1 = tkw.IndexMapping.dynamic_val(1)
+
+    a_read_map = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={M: d0, K: j},
+        outputs={M: i, K: j},
+        dynamic_val_mappings={M: i},
+    )
+
+    a_write_map = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={M: i, K: j},
+        outputs={M: d0, K: j},
+        dynamic_val_mappings={M: i},
+    )
+
+    dyn_reorder_a_read_map = tkw.IndexMapping(
+        num_iterators=1,
+        inputs={M_DIV_2: d0},
+        outputs={M_DIV_2: i},
+        dynamic_val_mappings={M_DIV_2: i},
+    )
+
+    a_simple_read_map = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={M: i, K: j},
+        outputs={M: 
i, K: j}, + ) + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + b: Memory[N, K, ADDRESS_SPACE_B, f16], # Input matrix B + reorder_a: Memory[M_DIV_2, ADDRESS_SPACE_A, i32], # Input matrix A + a_back: Memory[M, K, ADDRESS_SPACE_A, f16], # Output matrix A + c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C + ): + + zero_reg = tkw.Register[M, K, f16](0.0) + tkw.write(zero_reg, a_back) + + valid_threads = THREAD_0 < M_DIV_2 + + @tkw.conditional(valid_threads) + def scatter_op(): + tid = tkw.Register[M_DIV_2, i32](THREAD_0) + reordered_idx = tkw.read( + reorder_a, + mapping=dyn_reorder_a_read_map, + mapping_dynamic_vals=(tid,), + ) + + @tkw.iterate(K, init_args=[]) + def copy_row(): + a_row_data = tkw.read( + a, + mapping=a_read_map, + mapping_dynamic_vals=(reordered_idx,), + elements_per_thread=16, + ) + + tkw.write( + a_row_data, + a_back, + mapping=a_write_map, + mapping_dynamic_vals=(tid,), + elements_per_thread=16, + ) + + tkw.workgroup_barrier() + c_reg = Register[M, N, f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: Register[M, N, f32]) -> Register[M, N, f32]: + # Load elements from A and B + a_reg = tkw.read(a_back) + b_reg = tkw.read(b) + + # Compute matrix multiplication and accumulate + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # Store the final result to C + tkw.write(repeat, c) + + # Create test matrices + m, n, k = 128, 128, 128 + + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + a_back = torch.zeros(m, k, dtype=torch.float16, device="cuda") + b = torch.randn(n, k, dtype=torch.float16, device="cuda") + c = torch.zeros(m, n, dtype=torch.float32, device="cuda") + + # Set hyperparameters for compilation + hyperparams = { + ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, + SHARED_MEM: SHARED_ADDRESS_SPACE, + BLOCK_M: 64, + 
BLOCK_N: 64, + BLOCK_K: 32, + M: m, + N: n, + K: k, + M_DIV_2: 4, + } + + # Compile the kernel + if is_debug: + options = WaveCompileOptions( + subs=hyperparams, + print_ir_after="all", + print_ir_before="all", + ) + else: + options = WaveCompileOptions( + subs=hyperparams, + ) + options = set_default_run_config(options) + compiled_gemm = wave_compile(options, gemm) + + if is_debug: + with open("scatter_a_simple_gemm.mlir", "w") as f: + f.write(compiled_gemm.asm) + + # create reorder_a such that it is a permutation of the rows of a + reorder_a = torch.randperm(4).to(torch.int32).to(device="cuda") + reorder_a_clone = reorder_a.clone().to(device="cuda") + compiled_gemm(a, b, reorder_a_clone, a_back, c) + reordered_a = torch.zeros((m, k), dtype=torch.float16).to(device="cuda") + + # read rows of a in reorder_a order + for i in range(4): + reordered_a[i] = a[reorder_a[i]] + + # print("Reorder idx: ", reorder_a) + # print("A back: ", a_back[0]) + # print("A: ", a[reorder_a[0]]) + # print("Reordered A: ", reordered_a[0]) + + assert torch.allclose( + a_back, reordered_a, rtol=1e-2, atol=1e-2 + ), f"A back doesn't match expected output\nMax difference: {(a_back - reordered_a).abs().max()}" + + expected = torch.matmul(reordered_a, b.t()) + + print("Expected: ", expected[0]) + print("C: ", c[0]) + + assert torch.allclose( + c.to(torch.float16), expected, rtol=1e-2, atol=1e-2 + ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}" + + print("scatter_a_simple_gemm test passed!") + + +def scatter_gemm_test(is_debug=False): + E = sym.E + M_DIV_2 = sym.M_DIV_2 + # Define constraints for the kernel + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.WorkgroupConstraint(E, E, 2), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.WaveConstraint(N, BLOCK_N / 2), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=tkw.MMAType.F32_16x16x16_F16, + 
vector_shapes={E: E, M_DIV_2: M_DIV_2, M: 16, N: 16, K: 16}, + ), + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + e = tkw.IndexMapping.iterator(2) + d0 = tkw.IndexMapping.dynamic_val(0) + + IDX = sym.IDX + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={E: IDX, N: i, K: j}, + outputs={N: i, K: j}, + ) + + a_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: d0, K: j}, + outputs={M: i, K: j}, + dynamic_val_mappings={M: i}, + ) + + a_write_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, K: j}, + outputs={M: d0, K: j}, + dynamic_val_mappings={M: i}, + ) + + dyn_reorder_a_read_map = tkw.IndexMapping( + num_iterators=1, + inputs={M_DIV_2: d0}, + outputs={M_DIV_2: i}, + dynamic_val_mappings={M_DIV_2: i}, + ) + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + b: Memory[E, N, K, ADDRESS_SPACE_B, f16], # Input matrix B + reorder_a: Memory[M_DIV_2, ADDRESS_SPACE_A, i32], # Input matrix A + a_back: Memory[M, K, ADDRESS_SPACE_A, f16], # Output matrix A + c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C + idx: i32, + ): + # Initialize the accumulator register with zeros + zero_reg = Register[M, K, f16](0.0) + tkw.write(zero_reg, a_back) + + tkw.set_symbol(IDX, idx) + + condition = THREAD_0 < M_DIV_2 + + @tkw.conditional(condition) + def scatter_op(): + tid = tkw.Register[M_DIV_2, i32](THREAD_0) + reordered_idx = tkw.read( + reorder_a, + mapping=dyn_reorder_a_read_map, + mapping_dynamic_vals=(tid,), + ) + + @tkw.iterate(K, init_args=[]) + def copy_row(): + a_row_data = tkw.read( + a, + mapping=a_read_map, + mapping_dynamic_vals=(reordered_idx,), + elements_per_thread=16, + ) + tkw.write( + a_row_data, + a_back, + mapping=a_write_map, + mapping_dynamic_vals=(tid,), + elements_per_thread=16, + ) + + tkw.workgroup_barrier() + c_reg = Register[M, N, f32](0.0) + + # Iterate over the K dimension to compute the dot product + @tkw.iterate(K, init_args=[c_reg]) + def 
gemm_compute(acc: Register[M, N, f32]) -> Register[M, N, f32]: + # Load elements from A and B + a_reg = tkw.read(a_back) + b_reg = tkw.read(b, mapping=mapping) + + # Compute matrix multiplication and accumulate + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # Store the final result to C + tkw.write(gemm_compute, c) + + # Create test matrices + m, n, k = 64, 64, 128 # Small dimensions for testing + e = 8 + + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + a_back = torch.zeros(m, k, dtype=torch.float16, device="cuda") + b = torch.randn(e, n, k, dtype=torch.float16, device="cuda") + c = torch.zeros(m, n, dtype=torch.float32, device="cuda") + + # Set hyperparameters for compilation + hyperparams = { + ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: m, + N: n, + K: k, + E: e, + M_DIV_2: m // 2, + } + + # Compile the kernel + if is_debug: + options = WaveCompileOptions( + subs=hyperparams, + print_ir_after="all", + print_ir_before="all", + ) + else: + options = WaveCompileOptions( + subs=hyperparams, + ) + + options = set_default_run_config(options) + compiled_gemm = wave_compile(options, gemm) + + if is_debug: + with open("scatter_gemm.mlir", "w") as f: + f.write(compiled_gemm.asm) + + # create reorder_a such that it is a permutation of the rows of a + reorder_a = torch.randperm(m // 2).to(torch.int32).to(device="cuda") + compiled_gemm(a, b, reorder_a, a_back, c, 1) + reordered_a = torch.zeros((m, k), dtype=torch.float16).to(device="cuda") + + # read rows of a in reorder_a order + for i in range(m // 2): + reordered_a[i] = a[reorder_a[i]] + + print("Reorder idx: ", reorder_a) + print("A back: ", a_back[0]) + print("A: ", a[reorder_a[0]]) + print("Reordered A: ", reordered_a[0]) + + assert torch.allclose( + a_back, reordered_a, rtol=1e-2, atol=1e-2 + ), f"A back 
doesn't match expected output\nMax difference: {(a_back - reordered_a).abs().max()}" + + # Verify the result using PyTorch's matmul + expected = torch.matmul(reordered_a, b[1].t()) + + # Check if results are close (accounting for floating-point precision) + assert torch.allclose( + c.to(torch.float16), expected, rtol=1e-2, atol=1e-2 + ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}" + + print("GEMM test passed!") + + +def scatter_gemm_w_padding_test(is_debug=False): + E = sym.E + BLOCK_SHAPE = sym.BLOCK_SHAPE + PAD_VALUE = sym.PAD_VALUE + + IDX = sym.IDX + SCATTER_IDX = sym.SCATTER_IDX + # Define constraints for the kernel + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.WorkgroupConstraint(E, E, 2), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.WaveConstraint(N, BLOCK_N / 2), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=tkw.MMAType.F32_16x16x16_F16, + vector_shapes={E: E, BLOCK_SHAPE: BLOCK_SHAPE, M: 16, N: 16, K: 16}, + ), + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + e = tkw.IndexMapping.iterator(2) + d0 = tkw.IndexMapping.dynamic_val(0) + + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={E: IDX, N: i, K: j}, + outputs={N: i, K: j}, + ) + + a_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: d0, K: j}, + outputs={M: i, K: j}, + dynamic_val_mappings={M: i}, + ) + + a_write_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, K: j}, + outputs={M: d0, K: j}, + dynamic_val_mappings={M: i}, + ) + + c_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: d0, N: j}, + outputs={M: i, N: j}, + dynamic_val_mappings={M: i}, + ) + + c_write_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, N: j}, + outputs={M: d0, N: j}, + dynamic_val_mappings={M: i}, + ) + + dyn_reorder_a_read_map = tkw.IndexMapping( + num_iterators=1, + inputs={BLOCK_SHAPE: d0}, + 
outputs={BLOCK_SHAPE: i}, + dynamic_val_mappings={BLOCK_SHAPE: i}, + ) + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + b: Memory[E, N, K, ADDRESS_SPACE_B, f16], # Input matrix B + reorder_a: Memory[BLOCK_SHAPE, ADDRESS_SPACE_A, i32], # Input matrix A + a_back: Memory[M, K, ADDRESS_SPACE_A, f16], # Output matrix A + c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C + c_back: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C + idx: i32, + ): + # Initialize the accumulator register with zeros + zero_reg = Register[M, K, f16](0.0) + zero_reg_mn = Register[M, N, f32](0.0) + tkw.write(zero_reg, a_back) + tkw.write(zero_reg_mn, c_back) + mock_reg = tkw.read(reorder_a) + + tkw.set_symbol(IDX, idx) + + condition = THREAD_0 < BLOCK_SHAPE + + @tkw.conditional(condition) + def scatter_op(): + tid = tkw.Register[BLOCK_SHAPE, i32](THREAD_0) + reordered_idx = tkw.read( + reorder_a, + mapping=dyn_reorder_a_read_map, + mapping_dynamic_vals=(tid,), + ) + + tkw.set_symbol(SCATTER_IDX, reordered_idx) + is_not_padding = reordered_idx < tkw.scalar(PAD_VALUE, i32) + + @tkw.conditional(is_not_padding) + def then(): + @tkw.iterate(K, init_args=[]) + def copy_row(): + a_row_data = tkw.read( + a, + mapping=a_read_map, + mapping_dynamic_vals=(reordered_idx,), + elements_per_thread=16, + ) + tkw.write( + a_row_data, + a_back, + mapping=a_write_map, + mapping_dynamic_vals=(tid,), + elements_per_thread=16, + ) + + tkw.workgroup_barrier() + c_reg = Register[M, N, f32](0.0) + + # Iterate over the K dimension to compute the dot product + @tkw.iterate(K, init_args=[c_reg]) + def gemm_compute(acc: Register[M, N, f32]) -> Register[M, N, f32]: + # Load elements from A and B + a_reg = tkw.read(a_back) + b_reg = tkw.read(b, mapping=mapping) + + # Compute matrix multiplication and accumulate + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # Store the final result to C + tkw.write(gemm_compute, c_back) + + @tkw.conditional(condition) + def 
scatter_op(): + tid = tkw.Register[BLOCK_SHAPE, i32](THREAD_0) + reordered_idx = tkw.read( + reorder_a, + mapping=dyn_reorder_a_read_map, + mapping_dynamic_vals=(tid,), + ) + + tkw.set_symbol(SCATTER_IDX, reordered_idx) + is_not_padding = SCATTER_IDX < PAD_VALUE + + @tkw.conditional(is_not_padding) + def then(): + c_row_data = tkw.read( + c_back, + mapping=c_read_map, + mapping_dynamic_vals=(tid,), + elements_per_thread=16, + ) + tkw.write( + c_row_data, + c, + mapping=c_write_map, + mapping_dynamic_vals=(reordered_idx,), + elements_per_thread=16, + ) + + # Create test matrices + m, n, k = 64, 64, 128 # Small dimensions for testing + e = 8 + block_shape = 16 + + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + a_back = torch.zeros(m, k, dtype=torch.float16, device="cuda") + b = torch.randn(e, n, k, dtype=torch.float16, device="cuda") + c = torch.zeros(m, n, dtype=torch.float32, device="cuda") + c_back = torch.zeros(m, n, dtype=torch.float32, device="cuda") + + # Set hyperparameters for compilation + hyperparams = { + ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: m, + N: n, + K: k, + E: e, + BLOCK_SHAPE: block_shape, + PAD_VALUE: m, + } + + # Compile the kernel + if is_debug: + options = WaveCompileOptions( + subs=hyperparams, + print_ir_after="all", + print_ir_before="all", + ) + else: + options = WaveCompileOptions( + subs=hyperparams, + ) + + options = set_default_run_config(options) + compiled_gemm = wave_compile(options, gemm) + + if is_debug: + with open("scatter_gemm_w_padding.mlir", "w") as f: + f.write(compiled_gemm.asm) + + # create reorder_a such that it is a permutation of the rows of a + reorder_a = torch.randperm(m).to(torch.int32).to(device="cuda") + reorder_a = reorder_a[:block_shape] + # make last two values of reorder_a m + reorder_a[-2] = m + 
reorder_a[-1] = m + + compiled_gemm(a, b, reorder_a, a_back, c, c_back, 1) + reordered_a = torch.zeros((m, k), dtype=torch.float16).to(device="cuda") + + # read rows of a in reorder_a order + for i in range(block_shape): + if reorder_a[i] < m: + reordered_a[i] = a[reorder_a[i]] + + print("Reorder idx: ", reorder_a) + print("A back: ", a_back[0]) + print("A: ", a[reorder_a[0]]) + print("Reordered A: ", reordered_a[0]) + + assert torch.allclose( + a_back, reordered_a, rtol=1e-2, atol=1e-2 + ), f"A back doesn't match expected output\nMax difference: {(a_back - reordered_a).abs().max()}" + + # Verify the result using PyTorch's matmul + expected_int = torch.matmul(reordered_a, b[1].t()) + expected = torch.zeros((m, n), dtype=torch.float32).to(device="cuda") + + for i in range(block_shape): + if reorder_a[i] < m: + expected[reorder_a[i]] = expected_int[i] + + assert torch.allclose( + c_back.to(torch.float16), expected_int, rtol=1e-2, atol=1e-2 + ), f"C back doesn't match expected output\nMax difference: {(c_back.to(torch.float16) - expected_int).abs().max()}" + + assert torch.allclose( + c, expected, rtol=1e-2, atol=1e-2 + ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}" + + print("GEMM test passed!") + + +def scatter_gemm_fused_test(is_debug=False): + E = sym.E + TOTAL_ELEMS = sym.TOTAL_ELEMS + NUM_BLOCKS = sym.NUM_BLOCKS + BLOCK_SHAPE = sym.BLOCK_SHAPE + PAD_VALUE = sym.PAD_VALUE + + IDX = sym.IDX + SCATTER_IDX = sym.SCATTER_IDX + # Define constraints for the kernel + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.WorkgroupConstraint(TOTAL_ELEMS, BLOCK_SHAPE, 2), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.WaveConstraint(N, BLOCK_N / 2), + tkw.WaveConstraint(TOTAL_ELEMS, BLOCK_SHAPE), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=tkw.MMAType.F32_16x16x16_F16, + vector_shapes={ + E: E, + TOTAL_ELEMS: TOTAL_ELEMS, + 
BLOCK_SHAPE: BLOCK_SHAPE, + M: 16, + N: 16, + K: 16, + NUM_BLOCKS: NUM_BLOCKS, + }, + ), + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + e = tkw.IndexMapping.iterator(2) + d0 = tkw.IndexMapping.dynamic_val(0) + + b_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={E: IDX, N: i, K: j}, + outputs={N: i, K: j}, + ) + + a_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: d0, K: j}, + outputs={M: i, K: j}, + dynamic_val_mappings={M: i}, + ) + + a_back_write_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, K: j}, + outputs={NUM_BLOCKS: WORKGROUP_2, M: d0, K: j}, + dynamic_val_mappings={M: i}, + ) + + a_back_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={NUM_BLOCKS: WORKGROUP_2, M: i, K: j}, + outputs={M: i, K: j}, + ) + + c_back_write_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, N: j}, + outputs={NUM_BLOCKS: WORKGROUP_2, M: i, N: j}, + ) + + c_back_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={NUM_BLOCKS: WORKGROUP_2, M: d0, N: j}, + outputs={M: i, N: j}, + dynamic_val_mappings={M: i}, + ) + + c_write_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, N: j}, + outputs={M: d0, N: j}, + dynamic_val_mappings={M: i}, + ) + + dyn_reorder_a_read_map = tkw.IndexMapping( + num_iterators=1, + inputs={TOTAL_ELEMS: d0}, + outputs={TOTAL_ELEMS: i}, + dynamic_val_mappings={TOTAL_ELEMS: i}, + ) + + expert_id_read_map = tkw.IndexMapping( + num_iterators=1, + inputs={NUM_BLOCKS: d0}, + outputs={NUM_BLOCKS: i}, + dynamic_val_mappings={NUM_BLOCKS: i}, + ) + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + b: Memory[E, N, K, ADDRESS_SPACE_B, f16], # Input matrix B + reorder_a: Memory[TOTAL_ELEMS, ADDRESS_SPACE_A, i32], # Input matrix A + expert_ids: Memory[NUM_BLOCKS, ADDRESS_SPACE_A, i32], # Input matrix A + a_back: Memory[NUM_BLOCKS, M, K, ADDRESS_SPACE_A, f16], # Output matrix A + c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C 
+ c_back: Memory[NUM_BLOCKS, M, N, ADDRESS_SPACE_C, f32], # Output matrix C + ): + # Initialize the accumulator register with zeros + zero_reg = Register[M, K, f16](0.0) + zero_reg_mn = Register[M, N, f32](0.0) + tkw.write(zero_reg, a_back) + tkw.write(zero_reg_mn, c_back) + mock_reg = tkw.read(reorder_a) + + wid = tkw.scalar(WORKGROUP_2, i32) + expert_id = tkw.read( + expert_ids, mapping=expert_id_read_map, mapping_dynamic_vals=(wid,) + ) + tkw.set_symbol(IDX, expert_id) + condition = THREAD_0 < BLOCK_SHAPE + + @tkw.conditional(condition) + def scatter_op(): + tid = tkw.Register[TOTAL_ELEMS, i32](THREAD_0) + wid = tkw.Register[TOTAL_ELEMS, i32](WORKGROUP_2) + tid_offset = tkw.Register[TOTAL_ELEMS, i32](BLOCK_SHAPE) * wid + tid + reordered_idx = tkw.read( + reorder_a, + mapping=dyn_reorder_a_read_map, + mapping_dynamic_vals=(tid_offset,), + ) + + tkw.set_symbol(SCATTER_IDX, reordered_idx) + is_not_padding = SCATTER_IDX < PAD_VALUE + + @tkw.conditional(is_not_padding) + def then(): + @tkw.iterate(K, init_args=[]) + def copy_row(): + a_row_data = tkw.read( + a, + mapping=a_read_map, + mapping_dynamic_vals=(reordered_idx,), + elements_per_thread=16, + ) + + tkw.write( + a_row_data, + a_back, + mapping=a_back_write_map, + mapping_dynamic_vals=(tid,), + elements_per_thread=16, + ) + + tkw.workgroup_barrier() + c_reg = Register[M, N, f32](0.0) + + # Iterate over the K dimension to compute the dot product + @tkw.iterate(K, init_args=[c_reg]) + def gemm_compute(acc: Register[M, N, f32]) -> Register[M, N, f32]: + # Load elements from A and B + a_reg = tkw.read(a_back, mapping=a_back_read_map) + b_reg = tkw.read(b, mapping=b_read_map) + + # Compute matrix multiplication and accumulate + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(gemm_compute, c_back, mapping=c_back_write_map) + + @tkw.conditional(condition) + def scatter_op(): + tid = tkw.Register[TOTAL_ELEMS, i32](THREAD_0) + wid = tkw.Register[TOTAL_ELEMS, i32](WORKGROUP_2) + tid_offset = 
tkw.Register[TOTAL_ELEMS, i32](BLOCK_SHAPE) * wid + tid + reordered_idx = tkw.read( + reorder_a, + mapping=dyn_reorder_a_read_map, + mapping_dynamic_vals=(tid_offset,), + ) + + tkw.set_symbol(SCATTER_IDX, reordered_idx) + is_not_padding = SCATTER_IDX < PAD_VALUE + + @tkw.conditional(is_not_padding) + def then(): + c_row_data = tkw.read( + c_back, + mapping=c_back_read_map, + mapping_dynamic_vals=(tid,), + elements_per_thread=16, + ) + tkw.write( + c_row_data, + c, + mapping=c_write_map, + mapping_dynamic_vals=(reordered_idx,), + elements_per_thread=16, + ) + + # Create test matrices + m, n, k = 64, 64, 128 # Small dimensions for testing + block_shape = 6 + total_elems = 30 + num_blocks = total_elems // block_shape + num_experts = 4 + + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + a_back = torch.zeros(num_blocks, m, k, dtype=torch.float16, device="cuda") + b = torch.randn(num_experts, n, k, dtype=torch.float16, device="cuda") + c = torch.zeros(m, n, dtype=torch.float32, device="cuda") + c_back = torch.zeros(num_blocks, m, n, dtype=torch.float32, device="cuda") + + # create an expert_id list which is num_blocks long, each element is a random integer between 0 and num_experts - 1 + expert_ids = torch.randint( + 0, num_experts, (num_blocks,), dtype=torch.int32, device="cuda" + ) + + # Set hyperparameters for compilation + hyperparams = { + ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: m, + N: n, + K: k, + E: num_experts, + BLOCK_SHAPE: block_shape, + TOTAL_ELEMS: total_elems, + NUM_BLOCKS: num_blocks, + PAD_VALUE: m, + } + + # Compile the kernel + if is_debug: + options = WaveCompileOptions( + subs=hyperparams, + print_ir_after="all", + print_ir_before="all", + ) + else: + options = WaveCompileOptions( + subs=hyperparams, + ) + + options = 
set_default_run_config(options) + compiled_gemm = wave_compile(options, gemm) + + if is_debug: + with open("scatter_gemm_fused.mlir", "w") as f: + f.write(compiled_gemm.asm) + + # create reorder_a such that it is a permutation of the rows of a + reorder_a = torch.randperm(total_elems).to(torch.int32).to(device="cuda") + + # make reorder 2d where each row is total_elemns/block_shape, block_shape times + reorder_a = reorder_a.view(total_elems // block_shape, block_shape) + + # for each row, make last 0, 1, or 2 elements m randomly + for i in range(total_elems // block_shape): + reorder_a[i, -2] = m + reorder_a[i, -1] = m + + reorder_a = reorder_a.view(-1) + + compiled_gemm(a, b, reorder_a, expert_ids, a_back, c, c_back) + reordered_a = torch.zeros((num_blocks, m, k), dtype=torch.float16).to(device="cuda") + + # Verify the result using PyTorch's matmul + expected = torch.zeros((m, n), dtype=torch.float32).to(device="cuda") + expected_int = torch.zeros((num_blocks, m, n), dtype=torch.float32).to( + device="cuda" + ) + + for block_idx in range(num_blocks): + expert_id = expert_ids[block_idx].item() + for i in range(block_shape): + idx = block_idx * block_shape + i + if reorder_a[idx] < m: + reordered_a[block_idx][i] = a[reorder_a[idx]] + + expected_int[block_idx] = torch.matmul(reordered_a[block_idx], b[expert_id].t()) + + for i in range(block_shape): + idx = block_idx * block_shape + i + if reorder_a[idx] < m: + expected[reorder_a[idx]] = expected_int[block_idx][i] + + # print exactly which indices are not matching + # for i in range(num_blocks): + # for j in range(m): + # for k in range(k): + # if not torch.allclose(a_back[i][j][k], reordered_a[i][j][k], rtol=1e-2, atol=1e-2): + # print(f"A back doesn't match expected output at index {i}, {j}, {k}") + + # for i in range(m): + # for j in range(n): + # if not torch.allclose(expected[i][j], c[i][j], rtol=1e-2, atol=1e-2): + # print(f"C doesn't match expected output at index {i}, {j}") + + # for i in range(num_blocks): + 
# print("EXPERT ID: ", i) + # try: + # assert torch.allclose( + # a_back[i], reordered_a[i], rtol=1e-2, atol=1e-2 + # ), f"A back doesn't match expected output\nMax difference: {(a_back[i] - reordered_a[i]).abs().max()}" + # except Exception as e: + # breakpoint() + # assert torch.allclose( + # a_back, reordered_a, rtol=1e-2, atol=1e-2 + # ), f"A back doesn't match expected output\nMax difference: {(a_back - reordered_a).abs().max()}" + + # assert torch.allclose( + # c_back.to(torch.float16), expected_int, rtol=1e-2, atol=1e-2 + # ), f"C back doesn't match expected output\nMax difference: {(c_back.to(torch.float16) - expected_int).abs().max()}" + + assert torch.allclose( + c, expected, rtol=1e-2, atol=1e-2 + ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}" + + print("GEMM test passed!") + + +def conditional_weight_gemm_test(is_debug=False): + """ + Test GEMM with conditional topk_weight multiplication. + + This demonstrates how to conditionally multiply GEMM output by weights based on an i32 flag. + Use case: In MoE, GEMM1 doesn't need weight multiplication, but GEMM2 does. 
+ + - When apply_weights=0: output = A @ B.T + - When apply_weights=1: output = (A @ B.T) * weights + """ + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.WaveConstraint(N, BLOCK_N / 2), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=tkw.MMAType.F32_16x16x16_F16, + vector_shapes={M: 16, N: 16, K: 16}, + ), + ] + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + b: Memory[N, K, ADDRESS_SPACE_B, f16], # Input matrix B + weights: Memory[M, ADDRESS_SPACE_A, f32], # TopK weights per row + c: Memory[M, N, ADDRESS_SPACE_C, f32], # Output matrix C + apply_weights: i32, # Flag: 0 = no multiply, 1 = multiply by weights + ): + # Initialize the accumulator register with zeros + c_reg = Register[M, N, f32](0.0) + + # Iterate over the K dimension to compute the dot product + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: Register[M, N, f32]) -> Register[M, N, f32]: + # Load elements from A and B + a_reg = tkw.read(a) + b_reg = tkw.read(b) + + # Compute matrix multiplication and accumulate + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # Store GEMM result in register (don't write to memory yet) + tkw.write(repeat, c) + + # Conditionally multiply by weights if flag is set + condition = apply_weights == tkw.scalar(1, i32) + + @tkw.conditional(condition) + def apply_topk_weights(): + weights_reg = tkw.read(weights) + weights_broadcast = tkw.broadcast(weights_reg, target_shape=[M, N]) + c_reg = tkw.read(c) + result = c_reg * weights_broadcast + tkw.write(result, c) + + # Create test matrices + m, n, k = 64, 64, 64 + + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + b = torch.randn(n, k, dtype=torch.float16, device="cuda") + + # TopK weights (simulating router weights in MoE) + weights = 
torch.full((m,), 2.0, dtype=torch.float32, device="cuda") + + # Test 1: Without weight multiplication (apply_weights=0) + c_no_weights = torch.zeros(m, n, dtype=torch.float32, device="cuda") + + hyperparams = { + ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: m, + N: n, + K: k, + } + + options = WaveCompileOptions( + subs=hyperparams, + print_ir_after="all" if is_debug else [], + print_ir_before="all" if is_debug else [], + ) + options = set_default_run_config(options) + compiled_gemm = wave_compile(options, gemm) + + # Run with apply_weights=0 + compiled_gemm(a, b, weights, c_no_weights, 0) + + if is_debug: + with open("conditional_weight_gemm.mlir", "w") as f: + f.write(compiled_gemm.asm) + + # Test 2: With weight multiplication (apply_weights=1) + c_with_weights = torch.zeros(m, n, dtype=torch.float32, device="cuda") + # (removed duplicate c_with_weights initialization) + compiled_gemm(a, b, weights, c_with_weights, 1) + + # Verify results + expected_no_weights = torch.matmul(a, b.t()).to(torch.float32) + expected_with_weights = expected_no_weights * weights.view(-1, 1) + + print("\n=== Conditional Weight GEMM Test ===") + print(f"Matrix dimensions: M={m}, N={n}, K={k}") + print(f"Weights shape: {weights.shape}") + print(f"Weights mean: {weights.mean().item():.4f}") + + # Test without weights + torch.testing.assert_close( + c_no_weights.to(torch.float16), + expected_no_weights.to(torch.float16), + rtol=1e-2, + atol=1e-2, + ) + print("Test 1 passed: apply_weights=0 (no multiplication)") + + # (removed stray debugger breakpoint() left over from development) + # Test with weights + torch.testing.assert_close( + c_with_weights.to(torch.float16), + expected_with_weights.to(torch.float16), + rtol=1e-2, + atol=1e-2, + ) + print("Test 2 passed: apply_weights=1 (with multiplication)") + + # Verify that results are different + assert not torch.allclose( + c_no_weights, c_with_weights, 
rtol=1e-3, atol=1e-3 + ), "Outputs should differ when weights are applied" + + print( + f"\nOutput difference when weights applied: {((c_with_weights - c_no_weights).abs().mean().item()):.4f}" + ) + print("Conditional weight GEMM test passed!") + + +if __name__ == "__main__": + args = parse_args() + if args.list_tests: + list_tests() + else: + # run the test 10 times and collect how many times it passes + for i in range(args.repeat): + try: + globals()[args.test](args.debug) + print(f"Test {i} passed") + except Exception as e: + print(f"Error: {e}") + print(f"Test {i} failed") + exit(1) diff --git a/examples/python/3_atomics.py b/examples/python/3_atomics.py index 667ecf6576..16d0e31d86 100644 --- a/examples/python/3_atomics.py +++ b/examples/python/3_atomics.py @@ -172,6 +172,195 @@ def wave_kernel( print(c) +def test_histogram(is_debug=False): + NUM_EXPERTS = tkl.sym.NUM_EXPERTS + + """Atomic add operation to a histogram using dynamic mapping.""" + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: M, NUM_EXPERTS: NUM_EXPERTS}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, M, 0)] + constraints += [tkw.WorkgroupConstraint(NUM_EXPERTS, NUM_EXPERTS, 1)] + constraints += [tkw.WaveConstraint(M, M)] + constraints += [tkw.WaveConstraint(NUM_EXPERTS, NUM_EXPERTS)] + + i = tkw.IndexMapping.iterator(0) + d0 = tkw.IndexMapping.dynamic_val(0) + + topk_read_map = tkw.IndexMapping( + num_iterators=1, + inputs={M: d0}, + outputs={M: i}, + dynamic_val_mappings={M: i}, + ) + + expert_read_map = tkw.IndexMapping( + num_iterators=1, + inputs={NUM_EXPERTS: d0}, + outputs={NUM_EXPERTS: i}, + dynamic_val_mappings={NUM_EXPERTS: i}, + ) + + @tkw.wave(constraints) + def histogram_atomic_add( + topk_ids: tkl.Memory[M, ADDRESS_SPACE, tkl.i32], + experts: tkl.Memory[NUM_EXPERTS, ADDRESS_SPACE, tkl.i32], + ): + one_reg = tkw.Register[NUM_EXPERTS, tkl.i32](1) + tid = tkw.scalar(THREAD_0, tkl.i32) + + 
zero_vec = tkl.Register[NUM_EXPERTS, tkl.i32](0) + shmem = tkw.allocate( + shape=(NUM_EXPERTS,), + distributed_shape=(NUM_EXPERTS,), + dtype=tkl.i32, + ) + tkw.write(zero_vec, shmem) + + expert_id = tkw.read( + topk_ids, + mapping=topk_read_map, + mapping_dynamic_vals=(tid,), + elements_per_thread=1, + ) + + tkw.atomic_add( + one_reg, + shmem, + mapping=expert_read_map, + mapping_dynamic_vals=(expert_id,), + elements_per_thread=1, + ) + + tmp = tkw.read(shmem) + tkw.write(tmp, experts) + + num_experts = 10 + num_tokens = 64 + hyperparams = { + M: num_tokens, + NUM_EXPERTS: num_experts, + } + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + minimize_shared_allocs=False, + print_ir_after="all" if is_debug else [], + ) + histogram_atomic_add = wave_compile(options, histogram_atomic_add) + if is_debug: + print(histogram_atomic_add.asm) + + topk_ids = torch.randint(0, num_experts, (num_tokens,), dtype=torch.int32).cuda() + experts = torch.zeros((num_experts,), dtype=torch.int32).cuda() + histogram_atomic_add(topk_ids, experts) + print("topk_ids: ", topk_ids) + print("experts: ", experts) + print("expected experts: ", torch.bincount(topk_ids, minlength=num_experts)) + + +def test_large_histogram(is_debug=False): + NUM_EXPERTS = tkl.sym.NUM_EXPERTS + TOKEN_OFFSET = tkl.sym.TOKEN_OFFSET + """Atomic add operation to a histogram using dynamic mapping.""" + constraints: list[tkw.Constraint] = [] + constraints += [tkw.WorkgroupConstraint(M, M, 0)] + constraints += [tkw.WorkgroupConstraint(NUM_EXPERTS, NUM_EXPERTS, 1)] + constraints += [tkw.WaveConstraint(M, M)] + constraints += [tkw.WaveConstraint(NUM_EXPERTS, NUM_EXPERTS)] + + constraints += [tkw.TilingConstraint(TOKEN_OFFSET)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: M, NUM_EXPERTS: NUM_EXPERTS, TOKEN_OFFSET: 0}, + ) + ] + + i = tkw.IndexMapping.iterator(0) + d0 = tkw.IndexMapping.dynamic_val(0) + + topk_read_map = 
tkw.IndexMapping( + num_iterators=1, + inputs={M: d0}, + outputs={M: i}, + dynamic_val_mappings={M: i}, + ) + + expert_read_map = tkw.IndexMapping( + num_iterators=1, + inputs={NUM_EXPERTS: d0}, + outputs={NUM_EXPERTS: i}, + dynamic_val_mappings={NUM_EXPERTS: i}, + ) + + @tkw.wave(constraints) + def histogram_atomic_add( + topk_ids: tkl.Memory[M, ADDRESS_SPACE, tkl.i32], + experts: tkl.Memory[NUM_EXPERTS, ADDRESS_SPACE, tkl.i32], + ): + one_reg = tkw.Register[NUM_EXPERTS, tkl.i32](1) + zero_reg = tkw.Register[TOKEN_OFFSET, tkl.i32](0) + + loop_condition = TOKEN_OFFSET < M + + @tkw.iterate( + TOKEN_OFFSET, start=zero_reg, condition=loop_condition, init_args=[] + ) + def count_tokens(): + token_idx = tkw.self_index(TOKEN_OFFSET, tkl.i32) + tid_reg = tkw.Register[TOKEN_OFFSET, tkl.i32](THREAD_0) + token_idx = token_idx * tkl.Register[TOKEN_OFFSET, tkl.i32](64) + tid_reg + + expert_id = tkw.read( + topk_ids, + mapping=topk_read_map, + mapping_dynamic_vals=(token_idx,), + elements_per_thread=1, + ) + + tkw.atomic_add( + one_reg, + experts, + mapping=expert_read_map, + mapping_dynamic_vals=(expert_id,), + elements_per_thread=1, + ) + + next_token_idx = token_idx + tkl.Register[TOKEN_OFFSET, tkl.i32](64) + tkw.set_symbol(TOKEN_OFFSET, next_token_idx) + + num_experts = 10 + num_tokens = 64 + hyperparams = { + M: num_tokens, + NUM_EXPERTS: num_experts, + } + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + minimize_shared_allocs=False, + print_ir_after="all" if is_debug else [], + ) + histogram_atomic_add = wave_compile(options, histogram_atomic_add) + if is_debug: + print(histogram_atomic_add.asm) + + topk_ids = torch.randint(0, num_experts, (num_tokens,), dtype=torch.int32).cuda() + experts = torch.zeros((num_experts,), dtype=torch.int32).cuda() + + histogram_atomic_add(topk_ids, experts) + print("topk_ids: ", topk_ids) + print("experts: ", experts) + print("expected experts: ", torch.bincount(topk_ids, minlength=num_experts)) + + if __name__ == 
"__main__": args = parse_args() if args.list_tests: diff --git a/examples/python/5_gemm.py b/examples/python/5_gemm.py index d8a9d06634..f1620802f9 100644 --- a/examples/python/5_gemm.py +++ b/examples/python/5_gemm.py @@ -7,8 +7,8 @@ import torch import argparse - import wave_lang.kernel.wave as tkw +import wave_lang.kernel.lang as tkl from wave_lang.kernel._support.dtype import f16, f32, i32 from wave_lang.kernel._support.indexing import sym from wave_lang.kernel.lang.global_symbols import * @@ -1621,6 +1621,147 @@ def then(): print("GEMM test passed!") +def fused_gemms(is_debug=False): + """Fused GEMM kernel where we run two GEMMs back to back.""" + N1 = sym.N1 + N2 = sym.N2 + BLOCK_N1 = sym.BLOCK_N1 + BLOCK_N2 = sym.BLOCK_N2 + + # Define constraints for the kernel + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N2, BLOCK_N2, 1), + tkw.WaveConstraint(M, BLOCK_M / 2), + tkw.WaveConstraint(N2, BLOCK_N2 / 2), + tkw.TilingConstraint(K, BLOCK_K), + tkw.TilingConstraint(N1, BLOCK_N1), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=tkw.MMAType.F32_16x16x16_F16, + vector_shapes={M: 16, N1: 16, N2: 16, K: 16}, + ), + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + d0 = tkw.IndexMapping.dynamic_val(0) + + a_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: d0, K: j}, + outputs={M: i, K: j}, + dynamic_val_mappings={M: i}, + ) + + w1_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={N1: i, K: j}, + outputs={N1: i, K: j}, + ) + + w2_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={N2: i, N1: j}, + outputs={N2: i, N1: j}, + ) + + @tkw.wave(constraints) + def gemm( + a: Memory[M, K, ADDRESS_SPACE_A, f16], # Input matrix A + w1: Memory[N1, K, ADDRESS_SPACE_B, f16], # Input matrix B + w2: Memory[N2, N1, ADDRESS_SPACE_B, f16], # Input matrix D + c: Memory[M, N2, ADDRESS_SPACE_C, f32], # Output matrix C + ): + # Initialize the accumulator 
register with zeros + c_reg1 = Register[M, N1, f32](0.0) + c_reg2 = Register[M, N2, f32](0.0) + + c_back1 = tkw.allocate( + shape=(M, N1), + distributed_shape=(M, N1), + dtype=tkl.f32, + ) + + # Iterate over the K dimension to compute the dot product + @tkw.iterate(K, init_args=[c_reg1]) + def repeat1(acc: Register[M, N1, f32]) -> Register[M, N1, f32]: + # Load elements from A and B + a_reg = tkw.read(a) + w1_reg = tkw.read(w1) + acc = tkw.mma(a_reg, w1_reg, acc) + return acc + + # Store the final result to C + tkw.write(repeat1, c_back1) + + @tkw.iterate(N1, init_args=[c_reg2]) + def repeat2(acc: Register[M, N2, f32]) -> Register[M, N2, f32]: + # Load elements from A and B + a_reg = tkw.read(c_back1) + a_reg = tkw.cast(a_reg, f16) + w2_reg = tkw.read(w2) + acc = tkw.mma(a_reg, w2_reg, acc) + return acc + + # Store the final result to C + tkw.write(repeat2, c) + + # Create test matrices + m, k = 64, 64 # Small dimensions for testing + n1, n2 = 64, 64 + # Initialize input matrices with random values + torch.manual_seed(0) + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + w1 = torch.randn(n1, k, dtype=torch.float16, device="cuda") + w2 = torch.randn(n2, n1, dtype=torch.float16, device="cuda") + c = torch.zeros(m, n2, dtype=torch.float32, device="cuda") + c_back1 = torch.zeros(m, n1, dtype=torch.float32, device="cuda") + + # Set hyperparameters for compilation + hyperparams = { + ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, + BLOCK_M: 64, + BLOCK_N1: 64, + BLOCK_N2: 64, + BLOCK_K: 32, + M: m, + N1: n1, + N2: n2, + K: k, + } + + # Compile the kernel + options = WaveCompileOptions( + subs=hyperparams, + print_ir_after="all" if is_debug else [], + ) + options = set_default_run_config(options) + compiled_gemm = wave_compile(options, gemm) + + if is_debug: + print(compiled_gemm.asm) + with open("gemm.mlir", "w") as f: + f.write(compiled_gemm.asm) + + # Run the GEMM kernel + 
compiled_gemm(a, w1, w2, c) + + # Verify the result using PyTorch's matmul + expected = torch.matmul(a, w1.t()) + expected = torch.matmul(expected, w2.t()) + + # Check if results are close (accounting for floating-point precision) + assert torch.allclose( + c.to(torch.float16), expected, rtol=1e-2, atol=1e-2 + ), f"GEMM result doesn't match expected output\nMax difference: {(c - expected).abs().max()}" + + print("GEMM test passed!") + + if __name__ == "__main__": args = parse_args() if args.list_tests: diff --git a/examples/python/6_tensor_ops.py b/examples/python/6_tensor_ops.py new file mode 100644 index 0000000000..47976f1354 --- /dev/null +++ b/examples/python/6_tensor_ops.py @@ -0,0 +1,97 @@ +""" +Transformation of tensors, such as transpose, broadcast, split, concatenate, etc. +""" + +""" +GEMM Examples + +Demonstrates matrix multiplication patterns including basic GEMM, dynamic expert selection, +input reordering, scatter operations, and conditional weight application. +""" + +import torch +import wave_lang.kernel.wave as tkw +import wave_lang.kernel.lang as tkl +from wave_lang.kernel._support.indexing import sym +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.lang.wave_types import * +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + + +def split_tensor_test(): + + M = sym.M + N = sym.N + TWO_N = sym.TWO_N + BLOCK_M = sym.BLOCK_M + BLOCK_N = sym.BLOCK_N + + wave_size = 64 + + datatype = tkl.i32 + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.dynamic_val(0) + + x1_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, TWO_N: k}, + outputs={M: i, N: j}, + dynamic_val_mappings={TWO_N: j}, + ) + x2_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, TWO_N: k + N}, + outputs={M: i, N: j}, + dynamic_val_mappings={TWO_N: j}, + ) + + constraints: list[tkw.Constraint] = [ + 
tkw.HardwareConstraint( + threads_per_wave=64, + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def split_tensor( + tensor: tkl.Memory[M, TWO_N, GLOBAL_ADDRESS_SPACE, datatype], + out1: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, datatype], + out2: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, datatype], + ): + # Compute global thread ID accounting for workgroup offset + tid = tkw.scalar(THREAD_0 + WORKGROUP_0 * wave_size, tkl.i32) + x1_reg = tkw.read(tensor, mapping=x1_read_map, mapping_dynamic_vals=(tid,)) + x2_reg = tkw.read(tensor, mapping=x2_read_map, mapping_dynamic_vals=(tid,)) + tkw.write(x1_reg, out1) + tkw.write(x2_reg, out2) + + hyperparams = { + M: 64, + N: 64, + TWO_N: 128, + BLOCK_M: 64, + BLOCK_N: 64, + } + + options = WaveCompileOptions(subs=hyperparams) + options = set_default_run_config(options) + split_tensor = wave_compile(options, split_tensor) + + tensor = torch.arange(64 * 128, dtype=torch.int32, device="cuda") + tensor = tensor.view(64, 128) + out1 = torch.zeros(64, 64, dtype=torch.int32, device="cuda") + out2 = torch.zeros(64, 64, dtype=torch.int32, device="cuda") + split_tensor(tensor, out1, out2) + print("Out1: ", out1) + print("Out2: ", out2) + + +if __name__ == "__main__": + split_tensor_test() diff --git a/examples/test.py b/examples/test.py new file mode 100644 index 0000000000..1f7b9e4317 --- /dev/null +++ b/examples/test.py @@ -0,0 +1,657 @@ +import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw +from wave_lang.kernel.lang.global_symbols import * +import torch +import sympy + +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile + + +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K +B = tkl.sym.B +BLOCK_M = tkl.sym.BLOCK_M +BLOCK_N = tkl.sym.BLOCK_N +BLOCK_K = 
tkl.sym.BLOCK_K +BLOCK_B = tkl.sym.BLOCK_B +LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD +STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 + + +def get_wave_compile_options( + canonicalize: bool = False, dynamic_symbols=[], additional_symbols={} +): + bindings = { + M: 16, + N: 16, + K: 16, + BLOCK_M: 16, + BLOCK_N: 16, + BLOCK_K: 16, + ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, + } + bindings.update(additional_symbols) + + # Remove dynamic symbols from the bindings. + for sym in dynamic_symbols: + if sym in bindings: + del bindings[sym] + + return WaveCompileOptions( + subs=bindings, + canonicalize=canonicalize, + dynamic_symbols=dynamic_symbols, + ) + + +def test_read_write_dynamic_mapping_broadcast(): + ONE = tkl.sym.ONE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + vector_shapes={M: 16, N: 16, ONE: 1}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.dynamic_val(0) + mapping = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, N: k + j % 16}, + outputs={M: i, N: j}, + dynamic_val_mappings={M: i, ONE: j // 16}, + ) + + @tkw.wave(constraints) + def read_write_dynamic_mapping_broadcast( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + off: tkl.Memory[M, ONE, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + ): + offset = tkw.read(off) + res = tkw.read( + a, + mapping=mapping, + mapping_dynamic_vals=(offset,), + ) + tkw.write(res, b) + + read_write_dynamic_mapping_broadcast = wave_compile( + get_wave_compile_options(canonicalize=True, additional_symbols={ONE: 1}), + read_write_dynamic_mapping_broadcast, + ) + 
print(read_write_dynamic_mapping_broadcast.asm) + + # create input tensors + a = torch.arange(0, 256, dtype=torch.int32).reshape(16, 16).cuda() + off = torch.arange(0, 16, dtype=torch.int32).reshape(16, 1).cuda() + b = torch.zeros((16, 16), dtype=torch.int32).cuda() + + read_write_dynamic_mapping_broadcast(a, off, b) + print(a) + print(off) + print(b) + + +def test_one_read_write_dynamic_mapping_broadcast(): + ONE = tkl.sym.ONE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + vector_shapes={N: 16, ONE: 1}, + ) + ] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + j = tkw.IndexMapping.iterator(0) + k = tkw.IndexMapping.dynamic_val(0) + mapping = tkw.IndexMapping( + num_iterators=1, + inputs={N: k + j % 16}, + outputs={N: j}, + dynamic_val_mappings={ONE: j // 16}, + ) + + @tkw.wave(constraints) + def read_write_dynamic_mapping_broadcast( + a: tkl.Memory[N, ADDRESS_SPACE, tkl.i32], + off: tkl.Memory[ONE, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[N, ADDRESS_SPACE, tkl.i32], + ): + offset = tkw.read(off) + res = tkw.read( + a, + mapping=mapping, + mapping_dynamic_vals=(offset,), + ) + tkw.write(res, b) + + read_write_dynamic_mapping_broadcast = wave_compile( + get_wave_compile_options(canonicalize=True, additional_symbols={ONE: 1}), + read_write_dynamic_mapping_broadcast, + ) + print(read_write_dynamic_mapping_broadcast.asm) + + # create input tensors + a = ( + torch.arange(0, 16, dtype=torch.int32) + .reshape( + 16, + ) + .cuda() + ) + off = torch.ones((1,), dtype=torch.int32).cuda() + b = torch.zeros((16,), dtype=torch.int32).cuda() + + read_write_dynamic_mapping_broadcast(a, off, b) + print(a) + print(off) + print(b) + + +def test_one_nooffset_dynamic_mapping_broadcast(): + ONE = tkl.sym.ONE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + vector_shapes={N: 16, ONE: 1}, + ) + ] + constraints += 
[tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + j = tkw.IndexMapping.iterator(0) + k = tkw.IndexMapping.dynamic_val(0) + mapping = tkw.IndexMapping( + num_iterators=1, + inputs={N: k + j % 16}, + outputs={N: j}, + dynamic_val_mappings={ONE: j // 16}, + ) + + seq_len_mapping = tkw.IndexMapping( + num_iterators=1, + inputs={N: j}, + outputs={N: j}, + ) + + seq_len_mapping_w = tkw.IndexMapping( + num_iterators=1, + inputs={N: j}, + outputs={N: j + 1}, + ) + + @tkw.wave(constraints) + def read_write_dynamic_mapping_broadcast( + a: tkl.Memory[N, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[N, ADDRESS_SPACE, tkl.i32], + ): + offset = tkw.Register[ONE, tkl.i32](1) + # offset = tkw.scalar(1, tkl.i32) + # res = tkw.read( + # a, + # mapping=mapping, + # mapping_dynamic_vals=(offset,), + # ) + temp = tkw.Register[N, tkl.i32](0) + temp = tkw.read( + a, + mapping=seq_len_mapping, + ) + tkw.write(temp, b, mapping=seq_len_mapping_w) + + read_write_dynamic_mapping_broadcast = wave_compile( + get_wave_compile_options(canonicalize=True, additional_symbols={ONE: 1}), + read_write_dynamic_mapping_broadcast, + ) + print(read_write_dynamic_mapping_broadcast.asm) + + # create input tensors + a = ( + torch.arange(0, 16, dtype=torch.int32) + .reshape( + 16, + ) + .cuda() + ) + b = torch.zeros((16,), dtype=torch.int32).cuda() + + read_write_dynamic_mapping_broadcast(a, b) + print(a) + print(b) + + +def test_iteration_with_condition(): + LIMIT_VAL = tkl.sym.LIMIT_VAL + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={B: 0, M: 64, LIMIT_VAL: 0}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + # B is iterated over and so we define a tiling constraint on it. + # However, there is no notion of tile size for the iteration as + # it is an unstructured loop. 
+ constraints += [tkw.TilingConstraint(B)] + + i = tkw.IndexMapping.iterator(0) + d0 = tkw.IndexMapping.dynamic_val(0) + + limit_val_map = tkw.IndexMapping( + num_iterators=1, + inputs={M: d0}, + outputs={M: i}, + dynamic_val_mappings={M: i}, + ) + + @tkw.wave(constraints) + def iterated_gemm( + a: tkl.Memory[M, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[M, ADDRESS_SPACE, tkl.i32], + c: tkl.Memory[M, ADDRESS_SPACE_0, tkl.i32], + init_value: tkl.i32, # type: ignore + ): + + tid = tkw.scalar(tkw.THREAD_0, tkl.i32) + limit_val = tkw.read( + a, mapping=limit_val_map, mapping_dynamic_vals=(tid,), elements_per_thread=1 + ) + tkw.set_symbol(LIMIT_VAL, limit_val) + condition = B < LIMIT_VAL + + init_val = tkw.read( + b, mapping=limit_val_map, mapping_dynamic_vals=(tid,), elements_per_thread=1 + ) + ones_b = tkw.Register[B, tkl.i32](1) + + # init_val = tkw.scalar(0, tkl.i32) + @tkw.iterate(B, start=init_val, condition=condition, init_args=[]) + def body(): + c_reg = tkw.read(c) + b_reg = tkw.read(b) + + # c_reg = c_reg + b_reg + c_reg = tkw.Register[M, tkl.i32](tkw.THREAD_0) + tkw.write(c_reg, c) + + # Set the next value for the iteration. + # In this case, we are using a simple increment operation, + # but this can be replaced with any other operation. 
+ index_b = tkw.self_index(B, tkl.i32) + next_value = tkw.apply_expr(index_b, lambda x: x + 1) + # next_value = index_b + ones_b + tkw.set_symbol(B, next_value) + + options = WaveCompileOptions( + subs={ + M: 64, + B: 10, + BLOCK_M: 64, + BLOCK_B: 1, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + print_ir_after="all", + print_ir_before="all", + ) + iterated_gemm = wave_compile(options, iterated_gemm) + print(iterated_gemm.asm) + + # generate random input tensors between -1 and 1 + a = torch.randint(0, 4, (64,), dtype=torch.int32).cuda() + b = torch.randint(1, 2, (64,), dtype=torch.int32).cuda() + c = torch.zeros((64,), dtype=torch.int32).cuda() + + iterated_gemm(a, b, c, 0) + print(a) + print(b) + print(c) + + +def test_atomic_add_return_value(): + LIMIT_VAL = tkl.sym.LIMIT_VAL + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={B: 0, M: 64, LIMIT_VAL: 0}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + # B is iterated over and so we define a tiling constraint on it. + # However, there is no notion of tile size for the iteration as + # it is an unstructured loop. 
+ constraints += [tkw.TilingConstraint(B)] + + i = tkw.IndexMapping.iterator(0) + d0 = tkw.IndexMapping.dynamic_val(0) + + simple_read_mapping = tkw.IndexMapping( + num_iterators=1, + inputs={M: i}, + outputs={M: i}, + ) + + @tkw.wave(constraints) + def iterated_gemm( + a: tkl.Memory[M, ADDRESS_SPACE, tkl.i32], + c: tkl.Memory[M, ADDRESS_SPACE_0, tkl.i32], + ): + + one_reg = tkw.Register[M, tkl.i32](1) + res = tkw.atomic_add(one_reg, a, mapping=simple_read_mapping) + tkw.write(res, c) + + options = WaveCompileOptions( + subs={ + M: 64, + B: 10, + BLOCK_M: 64, + BLOCK_B: 1, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + print_ir_after="all", + print_ir_before="all", + ) + iterated_gemm = wave_compile(options, iterated_gemm) + print(iterated_gemm.asm) + + # generate random input tensors between -1 and 1 + a = torch.randint(1, 2, (64,), dtype=torch.int32).cuda() + c = torch.zeros((64,), dtype=torch.int32).cuda() + + iterated_gemm(a, c) + print(a) + print(c) + + +def test_read_back_scalar(): + ONE = tkl.sym.ONE + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={B: 0, M: 64, ONE: 1}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(ONE, ONE, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(ONE, ONE)] + + i = tkw.IndexMapping.iterator(0) + d0 = tkw.IndexMapping.dynamic_val(0) + + simple_read_mapping = tkw.IndexMapping( + num_iterators=1, + inputs={M: i}, + outputs={M: i}, + ) + + dynamic_read_mapping = tkw.IndexMapping( + num_iterators=1, + inputs={M: d0}, + outputs={M: i}, + dynamic_val_mappings={M: i}, + ) + + @tkw.wave(constraints) + def iterated_gemm( + a: tkl.Memory[M, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[M, ADDRESS_SPACE, tkl.i32], + c: tkl.Memory[ONE, 
ADDRESS_SPACE, tkl.i32], + ): + + tid = tkw.scalar(THREAD_0, tkl.i32) + one_reg = tkw.Register[M, tkl.i32](1) + res = tkw.atomic_add( + one_reg, a, mapping=dynamic_read_mapping, mapping_dynamic_vals=(tid,) + ) + val = tkw.read( + res, + mapping=dynamic_read_mapping, + mapping_dynamic_vals=(tid,), + elements_per_thread=1, + ) + tkw.write(val, c) + + options = WaveCompileOptions( + subs={ + M: 64, + ONE: 1, + BLOCK_M: 64, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + print_ir_after="all", + print_ir_before="all", + minimize_shared_allocs=False, + ) + iterated_gemm = wave_compile(options, iterated_gemm) + print(iterated_gemm.asm) + + # generate random input tensors between -1 and 1 + a = torch.randint(1, 2, (64,), dtype=torch.int32).cuda() + b = torch.zeros((64,), dtype=torch.int32).cuda() + c = torch.zeros((1,), dtype=torch.int32).cuda() + + iterated_gemm(a, b, c) + print(a) + print(b) + print(c) + + +def test_reduce_sum(): + shape = (64, 128) + M = tkl.sym.M + N = tkl.sym.N + wave_size = 64 + BLOCK_M = 1 + BLOCK_N = sympy.ceiling(N / wave_size) * wave_size + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + ): + res = tkw.read(a) + res = tkw.sum(res, dim=N) + tkw.write(res, c) + + torch.manual_seed(1) + a = torch.randn(shape, dtype=torch.float16, device="cuda") + c = torch.zeros((shape[0],), dtype=torch.float16, device="cuda") + ref = torch.sum(a, dim=-1) + options = WaveCompileOptions( 
+ subs={ + M: shape[0], + N: shape[1], + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + ) + test = wave_compile(options, test) + + test(a, c) + torch.testing.assert_close(ref, c, atol=0.1, rtol=1e-05) + print("Test passed") + + +def test_broadcast_reduce_sum(): + shape = (64, 128) + M = tkl.sym.M + N = tkl.sym.N + wave_size = 64 + BLOCK_M = 1 + BLOCK_N = sympy.ceiling(N / wave_size) * wave_size + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + c_temp: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + lhs = tkw.read(a) + rhs = tkw.read(b) + rhs = tkw.broadcast(rhs, (M, N)) + res = lhs * rhs + tkw.write(res, c_temp) + res = tkw.sum(res, dim=N) + tkw.write(res, c) + + a = torch.randn(shape, dtype=torch.float16, device="cuda") + b = torch.randn(shape[0], dtype=torch.float16, device="cuda") + c = torch.zeros((shape[0],), dtype=torch.float16, device="cuda") + c_temp = torch.zeros(shape, dtype=torch.float16, device="cuda") + + ref_temp = a * b.view(-1, 1) + ref = torch.sum(ref_temp, dim=-1) + options = WaveCompileOptions( + subs={ + M: shape[0], + N: shape[1], + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + ) + test = wave_compile(options, test) + + test(a, b, c, c_temp) + torch.testing.assert_close(ref_temp, c_temp, atol=0.1, rtol=1e-05) + torch.testing.assert_close(ref, c, atol=0.1, rtol=1e-05) + print("Test passed") + + +def test_moe_weighted_sum(): + """Test 3D weighted sum matching MOE 
pattern: + final_res = (out.view(B, -1, w2.shape[1]) * topk_weights.view(B, -1, 1)).sum(dim=1) + """ + shape = (64, 64, 128) # (B, K, D) + B = tkl.sym.B + K = tkl.sym.K + D = tkl.sym.D + wave_size = 64 + BLOCK_B = 1 + BLOCK_K = sympy.ceiling(shape[1] / wave_size) * wave_size + BLOCK_D = 1 + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + vector_shapes={B: BLOCK_B, K: BLOCK_K, D: BLOCK_D}, + ) + ] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.WorkgroupConstraint(K, BLOCK_K, 0)] + constraints += [tkw.WorkgroupConstraint(D, BLOCK_D, 1)] + constraints += [tkw.WaveConstraint(B, BLOCK_B)] + constraints += [tkw.WaveConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(D, BLOCK_D)] + + @tkw.wave(constraints) + def test( + out: tkl.Memory[B, K, D, ADDRESS_SPACE, tkl.f16], + topk_weights: tkl.Memory[B, K, ADDRESS_SPACE, tkl.f16], + result: tkl.Memory[B, D, ADDRESS_SPACE, tkl.f16], + temp_output: tkl.Memory[B, K, D, ADDRESS_SPACE, tkl.f16], + ): + # Read 3D tensor: (B, K, D) + out_vals = tkw.read(out) + # Read 2D weights: (B, K) + weights = tkw.read(topk_weights) + # Broadcast weights to (B, K, D) + weights_broadcast = tkw.broadcast(weights, [B, K, D]) + # Multiply element-wise: (B, K, D) * (B, K, D) + weighted = out_vals * weights_broadcast + # Write intermediate result for verification + tkw.write(weighted, temp_output) + # Sum along K dimension: (B, K, D) -> (B, D) + res = tkw.sum(weighted, dim=K) + # Write final result + tkw.write(res, result) + + # Create input tensors + out = torch.randn(shape, dtype=torch.float16, device="cuda") + topk_weights = torch.randn(shape[0], shape[1], dtype=torch.float16, device="cuda") + result = torch.zeros((shape[0], shape[2]), dtype=torch.float16, device="cuda") + temp_output = torch.zeros(shape, dtype=torch.float16, device="cuda") + + # Reference computation matching MOE pattern + ref_temp = out * 
topk_weights.view(shape[0], shape[1], 1) + ref = torch.sum(ref_temp, dim=1) + + options = WaveCompileOptions( + subs={ + B: shape[0], + K: shape[1], + D: shape[2], + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + ) + test = wave_compile(options, test) + + test(out, topk_weights, result, temp_output) + torch.testing.assert_close(ref_temp, temp_output, atol=0.1, rtol=1e-05) + torch.testing.assert_close(ref, result, atol=0.1, rtol=1e-05) + print("Test passed") + + +if __name__ == "__main__": + import sys + + globals()[sys.argv[1]]() diff --git a/tests/kernel/moe/moe_align_block_size_test.py b/tests/kernel/moe/moe_align_block_size_test.py index f0bc0c800c..8be206d81c 100644 --- a/tests/kernel/moe/moe_align_block_size_test.py +++ b/tests/kernel/moe/moe_align_block_size_test.py @@ -4,18 +4,25 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Third-party imports import pytest import torch +import math + +# Local imports +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.templates.moe import get_moe_align_block_size_kernel + from .torch_kernels import moe_align_block_size_pytorch -import torch.nn.functional as F + torch.manual_seed(0) -num_tokens_values = [1, 33, 256] +num_tokens_values = [32] topk_values = [2] block_size_values = [16, 32, 64] -num_experts_values = [4, 8, 64] +num_experts_values = [4] def verify_moe_align_block_size_results( @@ -104,10 +111,11 @@ def test_moe_align_block_size( """ device = "cuda" - scores = torch.rand(num_tokens, num_experts) + scores = torch.rand(num_tokens, num_experts, device=device) # Get topk expert indices for each token _, topk_ids = torch.topk(scores, k=topk, dim=1) + topk_ids = topk_ids.to(device) # Conservative upper bound that accounts for both the number of tokens and # the maximum possible padding needed per expert @@ -136,6 +144,101 @@ def test_moe_align_block_size( 
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad ) + moe_align_block_size, hyperparams, dynamic_symbols = ( + get_moe_align_block_size_kernel( + num_tokens, + num_experts, + block_size, + topk_ids.numel(), + max_num_m_blocks, + max_num_tokens_padded, + topk, + ) + ) + + options = WaveCompileOptions( + subs=hyperparams, + minimize_shared_allocs=False, + ) + + kernel = wave_compile( + options, + moe_align_block_size, + ) + + expert_counts_buffer = torch.randint( + size=(num_experts,), dtype=torch.int32, device="cuda", low=0, high=1 + ) + padded_counts_buffer = torch.randint( + size=(num_experts,), dtype=torch.int32, device="cuda", low=0, high=1 + ) + cumsum_buffer = torch.randint( + size=(num_experts,), dtype=torch.int32, device="cuda", low=0, high=1 + ) + cumsum_exclusive = torch.randint( + size=(num_experts,), dtype=torch.int32, device="cuda", low=0, high=1 + ) + num_blocks_buffer = torch.randint( + size=(num_experts,), dtype=torch.int32, device="cuda", low=0, high=1 + ) + + wave_expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + + wave_sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) + + wave_num_tokens_post_pad = torch.empty( + (1), dtype=torch.int32, device=topk_ids.device + ) + flat_topk = topk_ids.view(-1).to(torch.int32) + kernel( + flat_topk, + wave_expert_ids, + expert_counts_buffer, + padded_counts_buffer, + cumsum_buffer, + cumsum_exclusive, + num_blocks_buffer, + wave_sorted_ids, + ) + + # print("Block size:", block_size) + # print("\n\n============Wave outputs================") + # print("Histogram:", expert_counts_buffer) + # print("Padded:", padded_counts_buffer) + # print("Cumsum (i):", cumsum_buffer) + # print("Cumsum (e):", cumsum_exclusive) + # print("Num blocks:", num_blocks_buffer) + # print("Expert IDs:", wave_expert_ids) + + # print("Sorted IDs:") + # for i in range(math.ceil(max_num_tokens_padded / block_size)): + # 
for j in range(block_size): + # if i * block_size + j >= max_num_tokens_padded: + # break + # print(wave_sorted_ids[i * block_size + j].item(), end=" ") + # print() + + # print("\n\n============Reference outputs================") + # print("Sorted IDs:") + # for i in range(math.ceil(max_num_tokens_padded / block_size)): + # for j in range(block_size): + # if i * block_size + j >= max_num_tokens_padded: + # break + # print(sorted_ids[i * block_size + j].item(), end=" ") + # print() + # print("Expert IDs:", expert_ids) + + # print("Num tokens post pad:", num_tokens_post_pad.item()) + verify_moe_align_block_size_results( - topk_ids, sorted_ids, expert_ids, num_tokens_post_pad, block_size, num_experts + topk_ids, + wave_sorted_ids, + wave_expert_ids, + cumsum_buffer[-1], + block_size, + num_experts, ) diff --git a/tests/kernel/moe/silu_and_mul_test.py b/tests/kernel/moe/silu_and_mul_test.py new file mode 100644 index 0000000000..d4ee07817e --- /dev/null +++ b/tests/kernel/moe/silu_and_mul_test.py @@ -0,0 +1,86 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest +import torch +import wave_lang.kernel as tk +import wave_lang.kernel.lang as tkl +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.utils.run_utils import ( + set_default_run_config, +) +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.utils.general_utils import ( + get_default_scheduling_params, +) +from wave_lang.kernel.wave.templates.moe import ( + get_silu_and_mul_kernel, +) +from wave_lang.kernel.lang import DataType +import torch.nn.functional as F + +torch.manual_seed(0) + + +def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + +def silu_and_mul_kernel( + m: int = 32, + n: int = 64, + dtype: torch.dtype = torch.float32, +): + """Test the SiLU and Mul kernel against PyTorch reference""" + device = "cuda" + + # Create test inputs + x = torch.randn(m, 2 * n, dtype=dtype, device=device) + + # Reference implementation + ref_output = silu_and_mul_ref(x) + + # Kernel implementation + output = torch.zeros(m, n, dtype=dtype, device=device) + + # Get and compile the kernel + silu_and_mul, symbols = get_silu_and_mul_kernel(m, n, tkl.f32) + symbols.update(get_default_scheduling_params()) + options = WaveCompileOptions(subs=symbols) + options = set_default_run_config(options) + silu_and_mul = wave_compile(options, silu_and_mul) + + # Run the kernel + silu_and_mul(x, output) + + # Compare results + rtol, atol = 1e-4, 1e-4 + torch.testing.assert_close( + output, ref_output, rtol=rtol, atol=atol, msg="SiLU and Mul output mismatch" + ) + + print(f"SiLU and Mul test passed for shape [{m}, {n}] with dtype {dtype}") + + +# Test parameters +m_values = [64] +n_values = [128] +dtypes = [torch.float32] + + +@pytest.mark.parametrize("m", m_values) +@pytest.mark.parametrize("n", n_values) +@pytest.mark.parametrize("dtype", dtypes) +def 
test_silu_and_mul_parametrized(m: int, n: int, dtype: torch.dtype): + """Parametrized test for SiLU and Mul kernel""" + silu_and_mul_kernel(m, n, dtype) + + +if __name__ == "__main__": + # Run a simple test when script is executed directly + silu_and_mul_kernel() + print("All SiLU and Mul tests passed!") diff --git a/wave_lang/kernel/ops/wave_ops.py b/wave_lang/kernel/ops/wave_ops.py index d729b7fb27..cd74dc8bec 100644 --- a/wave_lang/kernel/ops/wave_ops.py +++ b/wave_lang/kernel/ops/wave_ops.py @@ -2500,6 +2500,8 @@ def indexing_dims(self) -> list[IndexSymbol] | list[list[IndexSymbol]]: expand_dims: list[IndexSymbol] = [] subgraph = self.get_root_graph().subgraphs[self.subgraph_name] return_node = get_custom(subgraph.output_node()) + if return_node.return_vals[0] is None: + return [] assert isinstance(return_node, Output) for return_val in return_node.yielded_values: return_dims = get_custom(return_val).indexing_dims diff --git a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py index 4c72b29663..edffd40cb7 100644 --- a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py +++ b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py @@ -492,10 +492,17 @@ def set_thread_independent_index( continue # If the constraint is a tiling constraint, and the node - # is outside a reduction, we don't apply the constraint. + # is outside a reduction for this specific dimension, we don't apply the constraint. 
if isinstance(constraint, TilingConstraint): if not hasattr(custom.graph, "parent_op"): continue + # Check if we're inside an iterate for this specific tiled dimension + parent_iterate = get_custom(custom.graph.parent_op) + if ( + not isinstance(parent_iterate, Iterate) + or parent_iterate.axis != constraint.dim + ): + continue if isinstance(constraint, WorkgroupConstraint) and has_grid_constraint: continue diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 3ffe0078aa..36d06600fc 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -210,6 +210,7 @@ def check_contiguous_index(): simplified_index[dim].start.subs({GPR_NUM: 0}) + offset[j], 1, 1 ) for j, dim in enumerate(symbolic_shape) + if dim in simplified_index } extract.index = write.index write.vector_shapes = vector_shapes diff --git a/wave_lang/kernel/wave/decompose_reduce_ops.py b/wave_lang/kernel/wave/decompose_reduce_ops.py index 17ed85147e..27858d9732 100644 --- a/wave_lang/kernel/wave/decompose_reduce_ops.py +++ b/wave_lang/kernel/wave/decompose_reduce_ops.py @@ -369,11 +369,11 @@ def decompose_reduce_ops( raise NotImplementedError( "NYI: Expect all reduce_src to have same fastest dim." ) - if reduction_dim is not src_fastest_dims[0]: - raise NotImplementedError( - f"Only implemented reduction on fastest dimension. Got {reduction_dim} and {src_fastest_dims}." - f"\n{custom}" - ) + # if reduction_dim is not src_fastest_dims[0]: + # raise NotImplementedError( + # f"Only implemented reduction on fastest dimension. Got {reduction_dim} and {src_fastest_dims}." 
+ # f"\n{custom}" + # ) get_thread_shape = lambda index: max( subs_idxc(x.size) for x in index.values() diff --git a/wave_lang/kernel/wave/decompose_scan_ops.py b/wave_lang/kernel/wave/decompose_scan_ops.py index 7ca57ee86d..c229daef20 100644 --- a/wave_lang/kernel/wave/decompose_scan_ops.py +++ b/wave_lang/kernel/wave/decompose_scan_ops.py @@ -36,7 +36,7 @@ from .constraints import HardwareConstraint, WaveConstraint, WorkgroupConstraint from .utils.classes import ShuffleMode from .utils.general_utils import all_equal, delinearize_index -from .utils.graph_utils import DCE, get_outer_node +from .utils.graph_utils import get_outer_node def get_graph_node( @@ -563,4 +563,5 @@ def decompose_scan_ops( custom.fx_node, final_scan[user.expanded_dims[scan_dim]] ) - DCE(trace) + custom.graph.erase_node(custom.fx_node) + # DCE(trace) diff --git a/wave_lang/kernel/wave/templates/moe.py b/wave_lang/kernel/wave/templates/moe.py index 4dfd441dfd..a6b1b28e95 100644 --- a/wave_lang/kernel/wave/templates/moe.py +++ b/wave_lang/kernel/wave/templates/moe.py @@ -8,7 +8,10 @@ import wave_lang.kernel.lang as tkl import wave_lang.kernel.wave as tkw +from wave_lang.kernel._support.dtype import f16, f32, i32 +from wave_lang.kernel._support.indexing import sym from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.lang.wave_types import * from wave_lang.kernel.wave.constraints import MMAType from wave_lang.kernel._support.dtype import DataType import sympy @@ -694,6 +697,8 @@ def get_silu_and_mul_kernel( # Input sizes M = tkl.sym.M N = tkl.sym.N + TWO_N = tkl.sym.TWO_N + # Each workgroup works on single row of input data, and rows are further # split into blocks of size up to 256. 
We have single wave per WG, # and with default wave size of 32, each thread is operating on up to 8 @@ -709,7 +714,6 @@ def get_silu_and_mul_kernel( constraints: list[tkw.Constraint] = [ tkw.HardwareConstraint( threads_per_wave=wave_size, - waves_per_block=(1, 1, 1), vector_shapes={M: BLOCK_M, N: BLOCK_N}, ) ] @@ -718,28 +722,47 @@ def get_silu_and_mul_kernel( constraints += [tkw.WaveConstraint(M, BLOCK_M)] constraints += [tkw.WaveConstraint(N, BLOCK_N)] + # Create index mappings to read gate (first half) and up (second half) + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.dynamic_val(0) + x1_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, TWO_N: k}, + outputs={M: i, N: j}, + dynamic_val_mappings={TWO_N: j}, + ) + + x2_read_map = tkw.IndexMapping( + num_iterators=2, + inputs={M: i, TWO_N: k + N}, + outputs={M: i, N: j}, + dynamic_val_mappings={TWO_N: j}, + ) + @tkw.wave(constraints) def silu_and_mul( - x1: tkl.Memory[M, N, ADDRESS_SPACE, datatype], - x2: tkl.Memory[M, N, ADDRESS_SPACE, datatype], + gemm1_out: tkl.Memory[M, TWO_N, ADDRESS_SPACE, datatype], out: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, datatype], ): - x1_reg = tkw.read(x1) - cst_m1 = tkl.Register[M, N, datatype](-1.0) - cst_1 = tkl.Register[M, N, datatype](-1.0) - exp_out = tkw.exp2(x1_reg * cst_m1) + # Read x1 (first half: columns 0 to N-1) + # Compute global thread ID accounting for workgroup offset + tid = tkw.scalar(THREAD_0 + WORKGROUP_0 * wave_size, tkl.i32) + x1_reg = tkw.read(gemm1_out, mapping=x1_read_map, mapping_dynamic_vals=(tid,)) + cst_m1 = tkw.Register[M, N, datatype](-1.0) + cst_1 = tkw.Register[M, N, datatype](1.0) + exp_out = tkw.exp(x1_reg * cst_m1) sigmoid = cst_1 / (cst_1 + exp_out) silu = sigmoid * x1_reg - - x2_reg = tkw.read(x2) + x2_reg = tkw.read(gemm1_out, mapping=x2_read_map, mapping_dynamic_vals=(tid,)) res = silu * x2_reg - tkw.write(res, out) hyperparams = { - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + 
ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, M: m, N: n, + TWO_N: 2 * n, } return silu_and_mul, hyperparams diff --git a/wave_lang/kernel/wave/utils/mapping_utils.py b/wave_lang/kernel/wave/utils/mapping_utils.py index 24ed506463..4133eab37e 100644 --- a/wave_lang/kernel/wave/utils/mapping_utils.py +++ b/wave_lang/kernel/wave/utils/mapping_utils.py @@ -8,6 +8,7 @@ import sympy import torch.fx as fx +from ...ops.wave_ops import get_custom from ..._support.indexing import IndexingContext from ...lang.wave_types import IndexMapping from .general_utils import infer_dim, get_fastest_index @@ -231,8 +232,10 @@ def check_is_dynamic_vals_broadcasted(nodes: list[fx.Node]) -> bool: This function checks all nodes in the list and returns True only if all dynamic values are broadcasted (size 1 in all dims). """ + for node in nodes: - index = node.index + custom = get_custom(node) + index = custom.index assert index is not None, f"Node {node} has no index" if any(subs_idxc(i.size) > 1 for i in index.values()): return False