diff --git a/examples/xegpu/enumerate_matmul_schedules.py b/examples/xegpu/enumerate_matmul_schedules.py index 32bdbdd0..24ea2524 100644 --- a/examples/xegpu/enumerate_matmul_schedules.py +++ b/examples/xegpu/enumerate_matmul_schedules.py @@ -44,6 +44,8 @@ "prefetch_b_n": 16, "prefetch_a_nb": 1, "prefetch_b_nb": 1, + "transpose_a": False, + "transpose_b": False, } # Check that at least one constraint was reified into the schedule. diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index c73dfa4a..05db6612 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -1,5 +1,7 @@ # RUN: %PYTHON %s --sizes 512 1024 128 --dump-kernel=xegpu-wg | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# RUN: %PYTHON %s --dump-kernel=xegpu-wg --transpose-a | FileCheck %s +# RUN: %PYTHON %s --dump-kernel=xegpu-wg --transpose-b | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --bias | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --relu | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --bias --relu | FileCheck %s @@ -30,6 +32,7 @@ from lighthouse.schedule.xegpu import mlp_schedule, xegpu_to_binary from lighthouse.utils.numpy import mlir_to_numpy_dtype from lighthouse.ingress.mlir_gen import generate_gpu_matmul_payload, get_mlir_elem_type +from lighthouse.schedule.xegpu import XeGPUParameterSelector def matmul_complexity( @@ -77,6 +80,8 @@ class XeGPUMatMul: K: int = 1024 ab_type: ir.Type | str | None = None c_type: ir.Type | str | None = None + transpose_a: bool = False + transpose_b: bool = False has_bias: bool = False has_relu: bool = False accumulate_c: bool = True @@ -96,8 +101,8 @@ def __post_init__(self): assert isinstance(self.c_type, ir.F32Type), "Only f32 type is supported for C" self.ab_dtype = mlir_to_numpy_dtype(self.ab_type) self.c_dtype = mlir_to_numpy_dtype(self.c_type) - self.a_shape = (self.M, self.K) - self.b_shape = (self.K, self.N) + self.a_shape = (self.M, self.K) if not self.transpose_a else (self.K, self.M) + self.b_shape = (self.K, self.N) if not self.transpose_b else (self.N, self.K) self.c_shape = (self.M, self.N) self.bias_shape = (self.N,) @@ -142,6 +147,8 @@ def payload_module(self) -> ir.Module: K=self.K, ab_type=self.ab_type, c_type=self.c_type, + transpose_a=self.transpose_a, + transpose_b=self.transpose_b, has_bias=self.has_bias, has_relu=self.has_relu, accumulate_c=self.accumulate_c, @@ -190,6 +197,11 @@ def check_results( C, A, B = host_inputs[:3] bias = host_inputs[3] if mmul.has_bias else None + if mmul.transpose_a: + A = A.T + if mmul.transpose_b: + B = B.T + # use float32 data type for efficiency f32 = np.float32 D_ref = A.astype(f32) @ B.astype(f32) @@ -233,6 +245,16 @@ def cli_parser(description): default=[4096, 4096, 4096], help="M,N,K matrix sizes (A=MxK, B=KxN, C=MxN).", ) + parser.add_argument( + "--transpose-a", + action="store_true", + help="Transpose matrix A (i.e., A is KxM) before multiplication.", + ) + parser.add_argument( + "--transpose-b", + action="store_true", + help="Transpose matrix B (i.e., B is NxK) before multiplication.", + ) parser.add_argument( "--bias", action="store_true", @@ -375,46 +397,54 @@ def parse_cli_args(description): # Problem size m, n, k = args.sizes if args.sizes else (4096, 4096, 4096) + transpose_a = args.transpose_a + transpose_b = args.transpose_b + # Set required parameters params = { "m": m, "n": n, "k": k, + "transpose_a": transpose_a, + "transpose_b": transpose_b, } - if args.target: - params["device"] = args.target - - if args.json: - # Override parameters with values from JSON file if provided - with open(args.json, "r") as f: - json_params = json.load(f) - params.update(json_params) - # Override parameters with CLI args if provided + # Collect parameters from CLI arguments + cli_params = {} if args.wg_tile: - params["wg_m"], params["wg_n"] = args.wg_tile + cli_params["wg_m"], cli_params["wg_n"] = args.wg_tile if args.sg_tile: - params["sg_m"], params["sg_n"] = args.sg_tile + cli_params["sg_m"], cli_params["sg_n"] = args.sg_tile if args.k_tile: - params["k_tile"] = args.k_tile + cli_params["k_tile"] = args.k_tile if args.load_tile_a: - params["load_a_m"], params["load_a_k"] = args.load_tile_a + cli_params["load_a_m"], cli_params["load_a_k"] = args.load_tile_a if args.load_tile_b: - params["load_b_k"], params["load_b_n"] = args.load_tile_b + cli_params["load_b_k"], cli_params["load_b_n"] = args.load_tile_b if args.prefetch_tile_a: - params["prefetch_a_m"], params["prefetch_a_k"] = args.prefetch_tile_a + cli_params["prefetch_a_m"], cli_params["prefetch_a_k"] = args.prefetch_tile_a if args.prefetch_tile_b: - params["prefetch_b_k"], params["prefetch_b_n"] = args.prefetch_tile_b + cli_params["prefetch_b_k"], cli_params["prefetch_b_n"] = args.prefetch_tile_b if args.prefetch_a_nb is not None: - params["prefetch_a_nb"] = args.prefetch_a_nb + cli_params["prefetch_a_nb"] = args.prefetch_a_nb if args.prefetch_b_nb is not None: - params["prefetch_b_nb"] = args.prefetch_b_nb + cli_params["prefetch_b_nb"] = args.prefetch_b_nb - for k, v in params.items(): - if v is None: - raise ValueError( - f"Parameter {k} is not set. Please provide it via CLI or JSON file." - ) + # By default the tile size parameters are left undefined + if args.json: + # Override parameters with values from JSON file if provided + with open(args.json, "r") as f: + json_params = json.load(f) + params.update(json_params) + # Override with CLI params + params.update(cli_params) + elif cli_params: + # Get default parameters from selector + param_selector = XeGPUParameterSelector(device=args.target) + def_params = param_selector.get_parameters((m, n, k), transpose_a, transpose_b) + params.update(def_params) + # Override with CLI params + params.update(cli_params) with ir.Context(), ir.Location.unknown(): lh_dialects.register_and_load() @@ -423,6 +453,8 @@ def parse_cli_args(description): M=params["m"], N=params["n"], K=params["k"], + transpose_a=params["transpose_a"], + transpose_b=params["transpose_b"], has_bias=args.bias, has_relu=args.relu, accumulate_c=not args.no_accumulate_c, @@ -481,14 +513,18 @@ def list2str(a): c_type = str(wload.c_type) print( f"sizes={list2str([params['m'], params['n'], params['k']])} " + f"ta={int(params['transpose_a'])} " + f"tb={int(params['transpose_b'])} " f"dt={ab_type},{c_type} " - f"wg-tile={list2str([params['wg_m'], params['wg_n']])} " - f"sg-tile={list2str([params['sg_m'], params['sg_n']])} " - f"k-tile={params['k_tile']} " - f"load-a-tile={list2str([params['load_a_m'], params['load_a_k']])} " - f"load-b-tile={list2str([params['load_b_k'], params['load_b_n']])} " - f"pf-a-tile={list2str([params['prefetch_a_m'], params['prefetch_a_k']])} " - f"pf-b-tile={list2str([params['prefetch_b_k'], params['prefetch_b_n']])} " + f"wg={list2str([params['wg_m'], params['wg_n']])} " + f"sg={list2str([params['sg_m'], params['sg_n']])} " + f"k={params['k_tile']} " + f"ld-a={list2str([params['load_a_m'], params['load_a_k']])} " + f"ld-b={list2str([params['load_b_k'], params['load_b_n']])} " + f"pf-a={list2str([params['prefetch_a_m'], params['prefetch_a_k']])} " + f"pf-b={list2str([params['prefetch_b_k'], params['prefetch_b_n']])} " + f"pf-a-nb={params['prefetch_a_nb']} " + f"pf-b-nb={params['prefetch_b_nb']} " f"time(us): {elapsed:.2f} " f"GFLOPS: {gflops:.2f}" ) diff --git a/examples/xegpu/mlp.py b/examples/xegpu/mlp.py index d19441b5..1e60c787 100644 --- a/examples/xegpu/mlp.py +++ b/examples/xegpu/mlp.py @@ -1,5 +1,7 @@ # RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 | FileCheck %s +# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --transpose-a | FileCheck %s +# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --transpose-b | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --relu | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --bias | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --accumulate-c | FileCheck %s @@ -46,6 +48,8 @@ def check_correctness( ab_dtype: np.dtype, has_bias: bool = False, has_relu: bool = False, + transpose_a: bool = False, + transpose_b: bool = False, verbose: int = 0, ) -> bool: output_array, input_array, *rest = initial_host_arrays @@ -63,7 +67,11 @@ def check_correctness( biases = [b.astype(np.float32) for b in biases] a_array = input_array + if transpose_a: + a_array = a_array.T for i, W in enumerate(weights): + if transpose_b: + W = W.T D_ref = a_array @ W if has_bias: D_ref += biases[i] @@ -109,6 +117,8 @@ class XeGPUMLP: hidden_layer_sizes: Optional[list[int]] = None ab_type: ir.Type | str | None = None acc_type: ir.Type | str | None = None + transpose_a: bool = False + transpose_b: bool = False has_bias: bool = False has_relu: bool = False accumulate_c: bool = False @@ -135,9 +145,13 @@ def __post_init__(self): if self.hidden_layer_sizes is None: self.hidden_layer_sizes = [] self.input_shape = (self.batch_size, self.input_size) + if self.transpose_a: + self.input_shape = self.input_shape[::-1] self.output_shape = (self.batch_size, self.output_size) layer_sizes = [self.input_size] + self.hidden_layer_sizes + [self.output_size] self.weight_shapes = list(zip(layer_sizes[:-1], layer_sizes[1:])) + if self.transpose_b: + self.weight_shapes = [shape[::-1] for shape in self.weight_shapes] self.matmul_layers = [(self.batch_size, o, i) for i, o in self.weight_shapes] self.bias_shapes = [(o,) for o in layer_sizes[1:]] if self.has_bias else [] @@ -154,10 +168,16 @@ def gen_random(shape, dtype): return np.random.rand(*shape).astype(dtype) def gen_identity(shape, dtype): - # identity matrix, if cols > rows wrap to fill all columns + # identity matrix, a = np.zeros(shape, dtype=dtype) np.fill_diagonal(a, 1) - if shape[1] > shape[0]: + if self.transpose_b: + if shape[0] > shape[1]: + # if rows > cols wrap to fill all rows + second_block = a[shape[1] :, :] + np.fill_diagonal(second_block, 1) + elif shape[1] > shape[0]: + # if cols > rows wrap to fill all columns second_block = a[:, shape[0] :] np.fill_diagonal(second_block, 1) return a @@ -208,6 +228,8 @@ def payload_module(self) -> ir.Module: acc_type=self.acc_type, bias_type=self.ab_type, result_type=self.ab_type, + transpose_a=self.transpose_a, + transpose_b=self.transpose_b, has_bias=self.has_bias, has_relu=self.has_relu, accumulate_c=self.accumulate_c, @@ -299,6 +321,16 @@ def parse_cli(): action="store_true", help="Add ReLU activation function to each layer except the output layer.", ) + parser.add_argument( + "--transpose-a", + action="store_true", + help="Transpose the input matrix A in the first matmul layer.", + ) + parser.add_argument( + "--transpose-b", + action="store_true", + help="Transpose the weight matrices B in all matmul layers.", + ) parser.add_argument( "--accumulate-c", action="store_true", @@ -358,6 +390,8 @@ def parse_cli(): with ir.Context(), ir.Location.unknown(): lh_dialects.register_and_load() + tr_a = args.transpose_a + tr_b = args.transpose_b wload = XeGPUMLP( batch_size=args.batch_size, input_size=args.input_size, @@ -365,6 +399,8 @@ def parse_cli(): hidden_layer_sizes=args.hidden_sizes, has_bias=args.bias, has_relu=args.relu, + transpose_a=tr_a, + transpose_b=tr_b, accumulate_c=args.accumulate_c, identity_weights=identity_weights, ) @@ -376,7 +412,16 @@ def parse_cli(): acc_type = wload.acc_type # Initialize layer parameters - params = [{"m": M, "n": N, "k": K} for M, N, K in matmuls] + params = [] + for i, (M, N, K) in enumerate(matmuls): + layer_params = { + "m": M, + "n": N, + "k": K, + "transpose_a": tr_a if i == 0 else False, + "transpose_b": tr_b, + } + params.append(layer_params) if args.target: for layer_params in params: layer_params["device"] = args.target @@ -422,6 +467,8 @@ def parse_cli(): wload.ab_dtype, has_bias=wload.has_bias, has_relu=wload.has_relu, + transpose_a=tr_a, + transpose_b=tr_b, verbose=args.verbose, ) if not success: @@ -447,6 +494,8 @@ def list2str(a): f"i={args.input_size} " f"o={args.output_size} " f"hs={list2str(hidden_sizes)} " + f"ta={int(tr_a)} " + f"tb={int(tr_b)} " f"dt={ab_type},{acc_type} " f"time(us): {elapsed:.2f} " f"GFLOPS: {gflops:.2f}" diff --git a/lighthouse/dialects/transform/transform_ext/ops/get_tileable_consumers.py b/lighthouse/dialects/transform/transform_ext/ops/get_tileable_consumers.py index 947e0ef9..2e40e4a2 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/get_tileable_consumers.py +++ b/lighthouse/dialects/transform/transform_ext/ops/get_tileable_consumers.py @@ -50,7 +50,6 @@ def is_tileable_op(op: ir.Operation) -> bool: linalg.MaxOp, linalg.MinOp, linalg.FillOp, - linalg.MatmulOp, linalg.GenericOp, ] return isinstance(op.opview, tuple(linalg_ops)) diff --git a/lighthouse/ingress/mlir_gen/gpu_matmul_payload.py b/lighthouse/ingress/mlir_gen/gpu_matmul_payload.py index 356638b5..f90d909c 100644 --- a/lighthouse/ingress/mlir_gen/gpu_matmul_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_matmul_payload.py @@ -9,6 +9,8 @@ def generate_gpu_matmul_payload( K: int, ab_type: ir.Type, c_type: ir.Type, + transpose_a: bool, + transpose_b: bool, has_bias: bool, has_relu: bool, accumulate_c: bool, @@ -24,6 +26,8 @@ def generate_gpu_matmul_payload( acc_type=c_type, bias_type=c_type, result_type=c_type, + transpose_a=transpose_a, + transpose_b=transpose_b, has_bias=has_bias, has_relu=has_relu, accumulate_c=accumulate_c, diff --git a/lighthouse/ingress/mlir_gen/gpu_mlp_payload.py b/lighthouse/ingress/mlir_gen/gpu_mlp_payload.py index 6fe91849..547ec666 100644 --- a/lighthouse/ingress/mlir_gen/gpu_mlp_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_mlp_payload.py @@ -1,5 +1,5 @@ from mlir import ir -from mlir.dialects import linalg, gpu, bufferization, arith, tensor +from mlir.dialects import linalg, bufferization, arith, tensor from .utils import emit_buf_to_tensor from .named import add_bias, relu, times_weights @@ -17,6 +17,8 @@ def generate_gpu_mlp_payload( acc_type: ir.Type, bias_type: ir.Type, result_type: ir.Type, + transpose_a: bool, + transpose_b: bool, has_bias: bool, has_relu: bool, accumulate_c: bool, @@ -24,14 +26,16 @@ def generate_gpu_mlp_payload( ) -> ir.Module: """Generate payload function module for an MLP kernel.""" mod = ir.Module.create() - memref_in_t = ir.MemRefType.get((batch_size, input_size), ab_type) + a_shape = (batch_size, input_size) if not transpose_a else (input_size, batch_size) + memref_in_t = ir.MemRefType.get(a_shape, ab_type) memref_out_t = ir.MemRefType.get((batch_size, output_size), result_type) layer_sizes = [input_size] + hidden_layer_sizes + [output_size] feature_sizes = list(zip(layer_sizes[:-1], layer_sizes[1:])) weight_memref_types = [] bias_memref_types = [] for in_size, out_size in feature_sizes: - memref_t = ir.MemRefType.get((in_size, out_size), ab_type) + shape = (in_size, out_size) if not transpose_b else (out_size, in_size) + memref_t = ir.MemRefType.get(shape, ab_type) weight_memref_types.append(memref_t) if has_bias: memref_t = ir.MemRefType.get((out_size,), bias_type) @@ -59,24 +63,28 @@ def payload(*args): ] layer_input_tensor = input_tensor - to_dealloc = None for i, (weight_tensor, bias_tensor) in enumerate( zip(weight_tensors, bias_tensors) ): - M, K = layer_input_tensor.type.shape - K, N = weight_tensor.type.shape - if i == nlayers - 1: - c_tensor = output_tensor - c_memref = output - else: - # allocate intermediate buffer - memref_type = ir.MemRefType.get((M, N), ab_type) - c_memref = gpu.alloc(memref_type, None, [], [], []) - gpu.memset(None, [], c_memref, arith.constant(ab_type, 0.0)) - if accumulate_c: - c_tensor = emit_buf_to_tensor( - c_memref, restrict=True, writable=True - ) + layer_transpose_a = ( + transpose_a and i == 0 + ) # transpose A only for the first layer + M, K = ( + layer_input_tensor.type.shape[::-1] + if layer_transpose_a + else layer_input_tensor.type.shape + ) + K, N = ( + weight_tensor.type.shape[::-1] + if transpose_b + else weight_tensor.type.shape + ) + c_tensor = None + if accumulate_c: + if i == nlayers - 1: + c_tensor = output_tensor + else: + c_tensor = tensor.empty((M, N), ab_type) # skip relu for final layer hidden_layer = i < nlayers - 1 layer_output = emit_mlp_layer( @@ -84,35 +92,34 @@ def payload(*args): weight_tensor, acc_type=acc_type, result_type=ab_type if hidden_layer else result_type, - acc_tensor=c_tensor if accumulate_c else None, + acc_tensor=c_tensor, bias_tensor=bias_tensor, + transpose_a=layer_transpose_a, + transpose_b=transpose_b, has_relu=(hidden_layer or relu_on_final_layer) and has_relu, ) - bufferization.materialize_in_destination( - None, layer_output, c_memref, restrict=True, writable=True - ) - if to_dealloc is not None: - gpu.dealloc(None, [], to_dealloc) - to_dealloc = None - if i != nlayers - 1: - # deallocate after next layer - to_dealloc = c_memref + if i == nlayers - 1: + bufferization.materialize_in_destination( + None, layer_output, output, restrict=True, writable=True + ) layer_input_tensor = layer_output return mod def emit_mlp_layer( - a_tensor, - b_tensor, - acc_type, - result_type, - acc_tensor=None, - bias_tensor=None, - has_relu=False, + a_tensor: ir.Value, + b_tensor: ir.Value, + acc_type: ir.Type, + result_type: ir.Type, + acc_tensor: ir.Value | None = None, + bias_tensor: ir.Value | None = None, + transpose_a: bool = False, + transpose_b: bool = False, + has_relu: bool = False, ) -> ir.Value: - M, K = a_tensor.type.shape - K, N = b_tensor.type.shape + M, K = a_tensor.type.shape[::-1] if transpose_a else a_tensor.type.shape + K, N = b_tensor.type.shape[::-1] if transpose_b else b_tensor.type.shape convert_result = acc_type != result_type if acc_tensor is not None: if acc_tensor.type.element_type != acc_type: @@ -124,6 +131,16 @@ def emit_mlp_layer( empty = tensor.empty((M, N), acc_type) zero_tensor = linalg.fill(zero, outs=[empty]) acc_tensor = zero_tensor + if transpose_a: + empty = tensor.empty((M, K), a_tensor.type.element_type) + a_tensor = linalg.transpose( + a_tensor, outs=(empty,), permutation=[1, 0] + ).results[0] + if transpose_b: + empty = tensor.empty((K, N), b_tensor.type.element_type) + b_tensor = linalg.transpose( + b_tensor, outs=(empty,), permutation=[1, 0] + ).results[0] terminal = times_weights(a_tensor, b_tensor, acc_tensor) if bias_tensor is not None: if bias_tensor.type.element_type != acc_type: diff --git a/lighthouse/schedule/xegpu/matmul_constraints.py b/lighthouse/schedule/xegpu/matmul_constraints.py index 3e2bf8c1..9a03bd1f 100644 --- a/lighthouse/schedule/xegpu/matmul_constraints.py +++ b/lighthouse/schedule/xegpu/matmul_constraints.py @@ -14,10 +14,16 @@ PFETCH_MAX_ROWS = 32 PFETCH_MIN_COLS = 16 PFETCH_MAX_COLS = 32 +TRANSPOSE_LOAD = [16, 16] # heuristics: skip likely suboptimal configurations MIN_NB_THREADS = 16 +def print_header(title: str, char: str = "=", width: int = 80): + header = f" {title} " + print(f"{header:{char}^{width}}") + + def check_wg_tile(M: int, N: int, wg_tile: tuple[int, int]) -> tuple[int, int]: if M % wg_tile[0] != 0: raise ValueError("wg_tile_m does not divide M") @@ -63,6 +69,7 @@ def check_load_tile( parent_shape: tuple[int, int], child_shape: tuple[int, int], name: str = "A", + transpose: bool = False, ): if parent_shape[0] % tile[0] != 0 or parent_shape[1] % tile[1] != 0: raise ValueError( @@ -80,26 +87,32 @@ def check_load_tile( raise ValueError(f"Load tile {name} {tile} has too many rows.") if tile[1] > LOAD_MAX_COLS: raise ValueError(f"Load tile {name} {tile} has too many cols.") + if transpose and (tile[0] != TRANSPOSE_LOAD[1] or tile[1] != TRANSPOSE_LOAD[0]): + raise ValueError( + f"If {name} is transposed, load tile must be {TRANSPOSE_LOAD}." + ) def check_load_tile_a( tile: tuple[int, int], sg_tile: tuple[int, int], k_tile: int, + transpose: bool = False, ): data_shape = (sg_tile[0], k_tile) child_shape = DPAS.A_TILE - check_load_tile(tile, data_shape, child_shape, name="A") + check_load_tile(tile, data_shape, child_shape, name="A", transpose=transpose) def check_load_tile_b( tile: tuple[int, int], sg_tile: tuple[int, int], k_tile: int, + transpose: bool = False, ): data_shape = (k_tile, sg_tile[1]) child_shape = DPAS.B_TILE - check_load_tile(tile, data_shape, child_shape, name="B") + check_load_tile(tile, data_shape, child_shape, name="B", transpose=transpose) def check_prefetch_tile( @@ -107,9 +120,12 @@ def check_prefetch_tile( data_shape: tuple[int, int], gpu_specs: XeGPUSpecs, name: str = "A", + transpose: bool = False, min_nb_threads: int | None = None, verbose: bool = False, ) -> tuple[int, int]: + if transpose: + data_shape = data_shape[::-1] if tile[0] < PFETCH_MIN_ROWS: raise ValueError( f"Prefetch tile {name} {tile} has too few rows (min {PFETCH_MIN_ROWS})." @@ -134,7 +150,8 @@ def check_prefetch_tile( cols = int(data_shape[1] / tile[1]) nb_threads = int(rows * cols) if verbose: - print(f"=== Prefetch {name} ===") + print_header(f"Prefetch {name}", char="-", width=50) + print(f"data shape: {data_shape}, transpose: {transpose}") print(f"tile size {tile}, grid size ({rows}, {cols}), {nb_threads} threads") if nb_threads > gpu_specs.max_nb_threads: raise ValueError( @@ -152,6 +169,7 @@ def check_prefetch_tile_a( wg_tile: tuple[int, int], k_tile: int, gpu_specs: XeGPUSpecs, + transpose: bool = False, min_nb_threads: int | None = None, verbose: bool = False, ) -> tuple[int, int]: @@ -161,6 +179,7 @@ def check_prefetch_tile_a( data_shape, gpu_specs, name="A", + transpose=transpose, min_nb_threads=min_nb_threads, verbose=verbose, ) @@ -171,6 +190,7 @@ def check_prefetch_tile_b( wg_tile: tuple[int, int], k_tile: int, gpu_specs: XeGPUSpecs, + transpose: bool = False, min_nb_threads: int | None = None, verbose: bool = False, ) -> tuple[int, int]: @@ -180,6 +200,7 @@ def check_prefetch_tile_b( data_shape, gpu_specs, name="B", + transpose=transpose, min_nb_threads=min_nb_threads, verbose=verbose, ) @@ -208,6 +229,8 @@ def check_constraints( prefetch_tile_b_k = params["prefetch_b_k"] prefetch_tile_b_n = params["prefetch_b_n"] k_tile = params["k_tile"] + transpose_a = params.get("transpose_a", False) + transpose_b = params.get("transpose_b", False) wg_tile = (wg_tile_m, wg_tile_n) sg_tile = (sg_tile_m, sg_tile_n) @@ -220,13 +243,14 @@ def check_constraints( check_wg_tile(M, N, wg_tile) check_sg_tile(wg_tile, sg_tile, gpu_specs, min_nb_threads=MIN_NB_THREADS) check_k_tile(K, k_tile) - check_load_tile_a(load_tile_a, sg_tile, k_tile) - check_load_tile_b(load_tile_b, sg_tile, k_tile) + check_load_tile_a(load_tile_a, sg_tile, k_tile, transpose=transpose_a) + check_load_tile_b(load_tile_b, sg_tile, k_tile, transpose=transpose_b) check_prefetch_tile_a( prefetch_tile_a, wg_tile, k_tile, gpu_specs, + transpose=transpose_a, min_nb_threads=MIN_NB_THREADS, verbose=verbose, ) @@ -235,6 +259,7 @@ def check_constraints( wg_tile, k_tile, gpu_specs, + transpose=transpose_b, min_nb_threads=MIN_NB_THREADS, verbose=verbose, ) diff --git a/lighthouse/schedule/xegpu/matmul_costmodel.py b/lighthouse/schedule/xegpu/matmul_costmodel.py index e5a60a80..7d260b01 100644 --- a/lighthouse/schedule/xegpu/matmul_costmodel.py +++ b/lighthouse/schedule/xegpu/matmul_costmodel.py @@ -21,17 +21,51 @@ PFETCH_MAX_ROWS, PFETCH_MIN_COLS, PFETCH_MAX_COLS, + MIN_NB_THREADS, + TRANSPOSE_LOAD, + print_header, ) +def summarize_config(params: dict, gpu_specs: XeGPUSpecs): + """Prints a summary of the given configuration.""" + M = params["m"] + N = params["n"] + K = params["k"] + wg_tile = (params["wg_m"], params["wg_n"]) + sg_tile = (params["sg_m"], params["sg_n"]) + k_tile = params["k_tile"] + ld_a = (params["load_a_m"], params["load_a_k"]) + ld_b = (params["load_b_k"], params["load_b_n"]) + pf_a = (params["prefetch_a_m"], params["prefetch_a_k"]) + pf_b = (params["prefetch_b_k"], params["prefetch_b_n"]) + transpose_a = params.get("transpose_a", False) + transpose_b = params.get("transpose_b", False) + estimate_performance(M, N, K, wg_tile, sg_tile, k_tile, gpu_specs, verbose=True) + print_header("Instruction level", char="-", width=50) + print(f"load size A: {ld_a}") + print(f"inst size A: {DPAS.A_TILE}") + print(f"load size B: {ld_b}") + print(f"inst size B: {DPAS.B_TILE}") + check_prefetch_tile_a( + pf_a, wg_tile, k_tile, gpu_specs, transpose=transpose_a, verbose=True + ) + check_prefetch_tile_b( + pf_b, wg_tile, k_tile, gpu_specs, transpose=transpose_b, verbose=True + ) + + def generate_configs( M: int, N: int, K: int, gpu_specs: XeGPUSpecs, + transpose_a: bool = False, + transpose_b: bool = False, perf_threshold: float | None = None, pf_strategy: str = "first", max_nb_configs: int | None = None, + verbose: bool = False, ) -> list[tuple[float, dict[str, int]]]: """Generate valid tile size configurations for (M, N, K) matrix multiplication. @@ -46,6 +80,12 @@ def generate_configs( Load tile sizes are currently fixed to DPAS tile sizes for A and B. + The `transpose_a` and `transpose_b` arguments indicate whether the A and B + matrices are transposed. The returned tile sizes are always in + non-transposed form, i.e. applicable to the payload op (e.g. DPAS), + _except_ for the prefetch tiles which are in the orientation applicable to + the prefetch op. + Returns: A list of (perf_estimate, params_dict) tuples sorted by perf_estimate (descending). """ @@ -63,9 +103,11 @@ def tuple_to_param_dict( tuple[int, int], tuple[int, int], tuple[int, int], + bool, + bool, ], ) -> dict[str, int]: - wg_tile, sg_tile, k_tile, ld_a, ld_b, pf_a, pf_b = config + wg_tile, sg_tile, k_tile, ld_a, ld_b, pf_a, pf_b, tr_a, tr_b = config return { "m": M, "n": N, @@ -85,6 +127,8 @@ def tuple_to_param_dict( "prefetch_b_n": pf_b[1], "prefetch_a_nb": 1, "prefetch_b_nb": 1, + "transpose_a": tr_a, + "transpose_b": tr_b, } # define search space @@ -105,17 +149,26 @@ def tuple_to_param_dict( ) n_prefetch = 1 if pf_strategy == "first" else None pf_a_list, pf_b_list = generate_prefetch_tiles( - wg_tile, k_tile, gpu_specs, n=n_prefetch + wg_tile, + k_tile, + gpu_specs, + n=n_prefetch, + transpose_a=transpose_a, + transpose_b=transpose_b, + verbose=False, ) - load_a_list = [DPAS.A_TILE] - load_b_list = [DPAS.B_TILE] + load_a_list = [DPAS.A_TILE if not transpose_a else TRANSPOSE_LOAD] + load_b_list = [DPAS.B_TILE if not transpose_b else TRANSPOSE_LOAD] for la, lb, pa, pb in product( load_a_list, load_b_list, pf_a_list, pf_b_list ): - c = (wg_tile, sg_tile, k_tile, la, lb, pa, pb) + c = (wg_tile, sg_tile, k_tile, la, lb, pa, pb, transpose_a, transpose_b) params = tuple_to_param_dict(M, N, K, c) - if check_constraints(params, gpu_specs, verbose=False): - valid_configs.append((perf, params)) + check_constraints(params, gpu_specs, verbose=False) + valid_configs.append((perf, params)) + if verbose: + print_header("Valid configuration found") + summarize_config(params, gpu_specs) except ValueError: pass @@ -130,6 +183,10 @@ def tuple_to_param_dict( if max_nb_configs is not None: valid_configs = valid_configs[:max_nb_configs] + if verbose and max_nb_configs == 1: + print_header("Selected configuration", char="=", width=50) + summarize_config(valid_configs[0][1], gpu_specs) + return valid_configs @@ -138,6 +195,9 @@ def generate_prefetch_tiles( k_tile: int, gpu_specs: XeGPUSpecs, n: int | None = None, + transpose_a: bool = False, + transpose_b: bool = False, + verbose: bool = False, ) -> tuple[ list[tuple[int, int]], list[tuple[int, int]], @@ -150,16 +210,26 @@ def generate_prefetch_tiles( def gridsearch( check_fn: Callable[ - [tuple[int, int], tuple[int, int], int, XeGPUSpecs], + [tuple[int, int], tuple[int, int], int, XeGPUSpecs, bool, bool], tuple[int, int], ], + transpose: bool = False, + verbose: bool = False, ) -> list[tuple[int, int]]: tiles = [] for rows in range(PFETCH_MIN_ROWS, PFETCH_MAX_ROWS + 1): for cols in range(PFETCH_MIN_COLS, PFETCH_MAX_COLS + 1): tile = (rows, cols) try: - grid = check_fn(tile, wg_tile, k_tile, gpu_specs) + grid = check_fn( + tile, + wg_tile, + k_tile, + gpu_specs, + transpose=transpose, + min_nb_threads=MIN_NB_THREADS, + verbose=verbose, + ) nb_threads = int(grid[0] * grid[1]) tiles.append((tile, nb_threads, grid)) except ValueError: @@ -169,8 +239,12 @@ def gridsearch( tiles = [t[0] for t in tiles] return tiles - prefetch_tiles_a = gridsearch(check_prefetch_tile_a) - prefetch_tiles_b = gridsearch(check_prefetch_tile_b) + prefetch_tiles_a = gridsearch( + check_prefetch_tile_a, transpose=transpose_a, verbose=verbose + ) + prefetch_tiles_b = gridsearch( + check_prefetch_tile_b, transpose=transpose_b, verbose=verbose + ) if n is not None: prefetch_tiles_a = prefetch_tiles_a[:n] prefetch_tiles_b = prefetch_tiles_b[:n] @@ -204,7 +278,8 @@ def estimate_performance( Raises ValueError if the given configuration is invalid. """ if verbose: - print("=== Global Level ===") + print_header("Global Level", char="-", width=50) + print(f"Matrix sizes: M={M}, N={N}, K={K}") # TODO generalize @@ -213,7 +288,7 @@ def estimate_performance( # WG if verbose: - print("=== Workgroup Level ===") + print_header("Workgroup Level", char="-", width=50) roofline_threshold = gpu_specs.peak_flops / gpu_specs.bw_global_mem # in FLOPs/Byte wg_grid = check_wg_tile(M, N, wg_tile) @@ -268,7 +343,7 @@ def estimate_performance( # SG if verbose: - print("=== Subgroup Level ===") + print_header("Subgroup Level", char="-", width=50) sg_grid = check_sg_tile(wg_tile, sg_tile, gpu_specs) nb_sgs = sg_grid[0] * sg_grid[1] @@ -316,16 +391,4 @@ def estimate_performance( f"Number of registers ({nb_reg}) exceeds hardware register file size ({gpu_specs.nb_registers})." ) - if prefetch_tile_a: - # check that prefetch tile is suitable for WG-k tile - check_prefetch_tile_a( - prefetch_tile_a, wg_tile, k_tile, gpu_specs, verbose=verbose - ) - - if prefetch_tile_b: - # check that prefetch tile is suitable for WG-k tile - check_prefetch_tile_b( - prefetch_tile_b, wg_tile, k_tile, gpu_specs, verbose=verbose - ) - return predicted_throughput diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index 66a9f70d..759e504c 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -155,7 +155,12 @@ def mlp_schedule( if not all(p in layer_params for p in required_params): # Some parameters are missing, use the parameter selector to fill # NOTE None values are interpreted as knobs in the constraint function - generated_params = param_selector.get_parameters(m, n, k) + shape = (m, n, k) + transpose_a = layer_params.get("transpose_a", False) + transpose_b = layer_params.get("transpose_b", False) + generated_params = param_selector.get_parameters( + shape, transpose_a, transpose_b + ) # Overwrite original params to ensure consistent configuration layer_params.update(generated_params) @@ -192,6 +197,9 @@ def bundle_xegpu_mlp_schedule( anytype = transform.AnyOpType.get() + # fuse all elementwise ops first + mod = apply_registered_pass(mod, "linalg-fuse-elementwise-ops") + matmul_ops = match_and_split(mod, ops={"linalg.matmul"}, nhandles=nlayers) # tile each layer separately @@ -216,6 +224,12 @@ def bundle_xegpu_mlp_schedule( # k loop tiling wg_matmul = match(wg_loop, ops={"linalg.matmul"}) _, [k_loop], _ = lh_transform.tile(wg_matmul, tile_sizes=[0, 0, k_tile]) + lh_transform.cleanup(wg_loop) + # if there's a transpose op fuse it into the k loop + transpose_op = match(wg_loop, ops={"linalg.transpose"}) + structured.structured_fuse_into_containing_op( + anytype, anytype, transpose_op, k_loop + ) func = transform.get_parent_op( anytype, @@ -256,6 +270,18 @@ def bundle_xegpu_mlp_schedule( ).result # fold memref.subviews into vector.transfer_read/write ops mod = apply_registered_pass(mod, "fold-memref-alias-ops") + # match payload function + wg_loops = match(mod, ops={"scf.forall"}) + func = transform.get_parent_op( + anytype, wg_loops, op_name="func.func", deduplicate=True + ) + # insert dealloc ops + func = apply_registered_pass(func, "buffer-deallocation-pipeline") + # convert to gpu.alloc and gpu.dealloc ops + alloc_ops = match(func, ops={"memref.alloc"}) + transform_ext.replace(alloc_ops, "gpu.alloc") + alloc_ops = match(func, ops={"memref.dealloc"}) + transform_ext.replace(alloc_ops, "gpu.dealloc") transform.apply_cse(mod) canonicalize(mod) @@ -368,6 +394,8 @@ def xegpu_wg_annotation_for_mlp_layer( prefetch_b_n: int | KnobValue, prefetch_a_nb: int | KnobValue, prefetch_b_nb: int | KnobValue, + transpose_a: bool, + transpose_b: bool, **_catch_all, ): """ @@ -407,6 +435,8 @@ def calc_sg_layout(WG_M, WG_N, SG_M, SG_N): prefetch_a_k, prefetch_b_k, prefetch_b_n, + transpose_a, + transpose_b, ) def constrain_and_calculate_load_and_prefetch_params( WG_M, @@ -422,6 +452,8 @@ def constrain_and_calculate_load_and_prefetch_params( PFA_K, PFB_K, PFB_N, + TR_A, + TR_B, ): # NB: normal asserts in case of concrete values, SMT assert ops for symbolic values smt_ext.assert_(SG_M % LDA_M == 0) @@ -434,17 +466,22 @@ def constrain_and_calculate_load_and_prefetch_params( smt_ext.assert_(LDB_K <= LOAD_MAX_ROWS) smt_ext.assert_(LDB_N <= LOAD_MAX_COLS) - smt_ext.assert_(SG_M % PFA_M == 0) - smt_ext.assert_(K_TILE % PFA_K == 0) - smt_ext.assert_(K_TILE % PFB_K == 0) - smt_ext.assert_(SG_N % PFB_N == 0) + # prefetch tile shape depends on transpose flag + pf_shape_a = (K_TILE, WG_M) if TR_A else (WG_M, K_TILE) + pf_shape_b = (WG_N, K_TILE) if TR_B else (K_TILE, WG_N) + + smt_ext.assert_(pf_shape_a[0] % PFA_M == 0) + smt_ext.assert_(pf_shape_a[1] % PFA_K == 0) + smt_ext.assert_(pf_shape_b[0] % PFB_K == 0) + smt_ext.assert_(pf_shape_b[1] % PFB_N == 0) smt_ext.assert_(PFA_M <= PFETCH_MAX_ROWS) smt_ext.assert_(PFA_K <= PFETCH_MAX_COLS) - smt_ext.assert_(PFB_K <= PFETCH_MAX_ROWS) - smt_ext.assert_(PFB_N <= PFETCH_MAX_COLS) smt_ext.assert_(PFA_M >= PFETCH_MIN_ROWS) smt_ext.assert_(PFA_K >= PFETCH_MIN_COLS) + + smt_ext.assert_(PFB_K <= PFETCH_MAX_ROWS) + smt_ext.assert_(PFB_N <= PFETCH_MAX_COLS) smt_ext.assert_(PFB_K >= PFETCH_MIN_ROWS) smt_ext.assert_(PFB_N >= PFETCH_MIN_COLS) @@ -454,15 +491,16 @@ def constrain_and_calculate_load_and_prefetch_params( smt_ext.assert_(LDB_N % DPAS.N == 0) # prefetch A thread layout - prefetch_th_a_m = WG_M // PFA_M - prefetch_th_a_k = K_TILE // PFA_K + prefetch_th_a_m = pf_shape_a[0] // PFA_M + prefetch_th_a_k = pf_shape_a[1] // PFA_K + prefetch_th_a = prefetch_th_a_m * prefetch_th_a_k smt_ext.assert_(prefetch_th_a <= gpu_specs.max_nb_threads) smt_ext.assert_(prefetch_th_a_m * prefetch_th_a_k >= MIN_NB_THREADS) # prefetch B thread layout - prefetch_th_b_k = K_TILE // PFB_K - prefetch_th_b_n = WG_N // PFB_N + prefetch_th_b_k = pf_shape_b[0] // PFB_K + prefetch_th_b_n = pf_shape_b[1] // PFB_N prefetch_th_b = prefetch_th_b_k * prefetch_th_b_n smt_ext.assert_(prefetch_th_b <= gpu_specs.max_nb_threads) if isinstance(prefetch_th_b, smt_ext.SMTIntValue): @@ -484,7 +522,6 @@ def constrain_and_calculate_load_and_prefetch_params( load_op_a = xegpu.get_load_op(transform.get_operand(anyvalue, dpas_op, [0])) load_op_b = xegpu.get_load_op(transform.get_operand(anyvalue, dpas_op, [1])) - # insert prefetch ops for DPAS A and B tiles def add_prefetch(load_op, prefetch_nb, **layout): desc_op = xegpu.insert_prefetch( load_op, @@ -493,33 +530,40 @@ def add_prefetch(load_op, prefetch_nb, **layout): pf_ops = transform.get_consumers_of_result(anytype, desc_op, 0) xegpu.set_anchor_layout(pf_ops, **layout) - add_prefetch( - load_op_a, - prefetch_a_nb, - sg_layout=prefetch_layout_a, - sg_data=prefetch_tile_a, - inst_data=PREFETCH_INST_DATA, - ) - add_prefetch( - load_op_b, - prefetch_b_nb, - sg_layout=prefetch_layout_b, - sg_data=prefetch_tile_b, - inst_data=PREFETCH_INST_DATA, - ) + def annotate_ab_load( + dpas_op, index, load_op, layout_load, layout_dpas, layout_prefetch, prefetch_nb + ): + """Annotate A/B tile load op and dpas operand and insert prefetch ops.""" + user = transform.get_consumers_of_result(anytype, load_op, 0) + # FIXME use transform.alternatives instead of select and foreach + # check_transpose = transform.AlternativesOp([], 2) + + # transposed case + transpose_consumer_op = transform.select(anytype, user, "vector.transpose") + with lh_transform.foreach(transpose_consumer_op): + # Load op loads the transposed tile and thus sg_layout and sg_data + # dimensions must be transposed. Keep inst_data which has been + # validated in its current orientation. + tr_load = layout_load.copy() + tr_load["sg_layout"] = layout_load["sg_layout"][::-1] + tr_load["sg_data"] = layout_load["sg_data"][::-1] + tr_load["order"] = [0, 1] + # annotate dpas op operand + layout_dpas_order = layout_dpas.copy() + layout_dpas_order["order"] = [1, 0] + xegpu.set_anchor_layout(dpas_op, index=index, **layout_dpas_order) + xegpu.set_anchor_layout(load_op, **tr_load) + add_prefetch(load_op, prefetch_nb, **layout_prefetch) + transform.yield_() - def annotate_ab_load(load_op, layout_load, layout_dpas): - xegpu.set_anchor_layout(load_op, **layout_load) - result_tile = transform.get_result(anyvalue, load_op, [0]) - xegpu.convert_layout( - result_tile, - input_sg_layout=layout_load["sg_layout"], - input_sg_data=layout_load["sg_data"], - input_inst_data=layout_load["inst_data"], - target_sg_layout=layout_dpas["sg_layout"], - target_sg_data=layout_dpas["sg_data"], - target_inst_data=layout_dpas["inst_data"], - ) + # no transpose case + dpas_consumer_op = transform.select(anytype, user, "xegpu.dpas") + with lh_transform.foreach(dpas_consumer_op): + # annotate dpas op operand + xegpu.set_anchor_layout(dpas_op, index=index, **layout_dpas) + xegpu.set_anchor_layout(load_op, **layout_load) + add_prefetch(load_op, prefetch_nb, **layout_prefetch) + transform.yield_() # A tile load layout layout_load_a = { @@ -530,7 +574,21 @@ def annotate_ab_load(load_op, layout_load, layout_dpas): # A tile dpas layout layout_dpas_a = layout_load_a.copy() layout_dpas_a["inst_data"] = DPAS.A_TILE - annotate_ab_load(load_op_a, layout_load_a, layout_dpas_a) + # A tile prefetch layout + layout_prefetch_a = { + "sg_layout": prefetch_layout_a, + "sg_data": prefetch_tile_a, + "inst_data": PREFETCH_INST_DATA, + } + annotate_ab_load( + dpas_op, + 0, + load_op_a, + layout_load_a, + layout_dpas_a, + layout_prefetch_a, + prefetch_a_nb, + ) # B tile load layout layout_load_b = { @@ -541,7 +599,21 @@ def annotate_ab_load(load_op, layout_load, layout_dpas): # B tile dpas layout layout_dpas_b = layout_load_b.copy() layout_dpas_b["inst_data"] = DPAS.B_TILE - annotate_ab_load(load_op_b, layout_load_b, layout_dpas_b) + # B tile prefetch layout + layout_prefetch_b = { + "sg_layout": prefetch_layout_b, + "sg_data": prefetch_tile_b, + "inst_data": PREFETCH_INST_DATA, + } + annotate_ab_load( + dpas_op, + 1, + load_op_b, + layout_load_b, + layout_dpas_b, + layout_prefetch_b, + prefetch_b_nb, + ) # C tile layout output_layout = { @@ -550,20 +622,18 @@ def annotate_ab_load(load_op, layout_load, layout_dpas): "inst_data": DPAS.C_TILE, } # C tile dpas anchor layout - xegpu.set_anchor_layout(dpas_op, index=0, **layout_dpas_a) - xegpu.set_anchor_layout(dpas_op, index=1, **layout_dpas_b) xegpu.set_anchor_layout(dpas_op, index=2, **output_layout) # annotate store op store_op_c = match(gpu_func, ops={"xegpu.store_nd"}) xegpu.set_anchor_layout(store_op_c, **output_layout) # annotate the 1d load of the broadcast op with a slice layout - # FIXME assert that we only match one add op - add_ops = match(gpu_func, ops={"arith.addf"}) - with lh_transform.foreach(add_ops) as bias_add_op: - bcast_load = xegpu.get_load_op( - transform.get_operand(anyvalue, bias_add_op, [0]) - ) + # NOTE assumes that xegpu.load is followed by vector.broadcast + maybe_bcast_load = match(gpu_func, ops={"xegpu.load"}) + load_user = transform.get_consumers_of_result(anytype, maybe_bcast_load, 0) + bcast_ops = transform.select(anytype, load_user, "vector.broadcast") + with lh_transform.foreach(bcast_ops) as bcast_op: + bcast_load = xegpu.get_load_op(transform.get_operand(anyvalue, bcast_op, [0])) xegpu.set_anchor_layout(bcast_load, index=0, **output_layout, slice_dims=[0]) transform.yield_() diff --git a/lighthouse/schedule/xegpu/xegpu_parameter_selector.py b/lighthouse/schedule/xegpu/xegpu_parameter_selector.py index 256fa6c9..3963c540 100644 --- a/lighthouse/schedule/xegpu/xegpu_parameter_selector.py +++ b/lighthouse/schedule/xegpu/xegpu_parameter_selector.py @@ -30,12 +30,27 @@ def __init__(self, device: str | None = None, json_file: str | None = None): self.gpu_specs = XeGPUSpecs.get(self.device) self.matmul_param_db = load_param_database(json_file) - def get_parameters(self, m: int, n: int, k: int) -> dict: - shape = (m, n, k) - if shape not in self.matmul_param_db: + def get_parameters( + self, + shape: tuple[int, int, int], + transpose_a: bool = False, + transpose_b: bool = False, + **kwargs, + ) -> dict: + m, n, k = shape + if shape not in self.matmul_param_db or transpose_a or transpose_b: try: # Use cost model to generate tile sizes and take first config - configs = generate_configs(m, n, k, self.gpu_specs, max_nb_configs=1) + configs = generate_configs( + m, + n, + k, + self.gpu_specs, + transpose_a=transpose_a, + transpose_b=transpose_b, + max_nb_configs=1, + verbose=False, + ) if not configs: raise ValueError( f"Cost model did not return any valid configurations for matmul {shape}." @@ -45,7 +60,11 @@ def get_parameters(self, m: int, n: int, k: int) -> dict: except Exception as e: msg = f"Error generating parameters for shape {shape} using cost model: {e}" raise ValueError(msg) from e - return self.matmul_param_db[shape] + params = self.matmul_param_db[shape] + # ensure transpose flags are set + params.setdefault("transpose_a", False) + params.setdefault("transpose_b", False) + return params - def get_parameters_for_layers(self, shapes: list[tuple[int, int, int]]) -> list: - return [self.get_parameters(*shape) for shape in shapes] + def get_parameters_for_layers(self, param_list: list[dict]) -> list: + return [self.get_parameters(**params) for params in param_list] diff --git a/lighthouse/utils/mlir.py b/lighthouse/utils/mlir.py index 220be2f1..bf95b5be 100644 --- a/lighthouse/utils/mlir.py +++ b/lighthouse/utils/mlir.py @@ -6,6 +6,7 @@ from mlir.dialects import func, linalg import os from pathlib import Path +from collections import defaultdict def get_mlir_library_path(): @@ -55,37 +56,112 @@ def inspect_payload(payload_module: ir.Module) -> dict: function_name: { "inputs": [input types], "results": [result types], - "matmuls": [(m, n, k), ...] # list of matmul shapes + "layers": { + "matmul": { + "m": m, + "n": n, + "k": k, + "transpose_a": bool, + "transpose_b": bool, + } + ... + } }, ... } """ + def has_producer(value: ir.Value, kind: type) -> bool: + if value is None or isinstance(value, ir.BlockArgument): + # stop trace + return False + if isinstance(value, ir.OpResult): + parent_op = value.owner + if isinstance(parent_op, kind): + return True + # recursively check producers + for operand in parent_op.operands: + if has_producer(operand, kind): + return True + return False + functions = {} def match_funcs(op: ir.Operation) -> ir.WalkResult: op = op.opview match op: case func.FuncOp(): - matmuls = [] + layers = defaultdict(list) def match_linalg(op: ir.Operation) -> ir.WalkResult: op = op.opview match op: + case linalg.GenericOp(): + # TODO support ElementwiseOp and MapOp + iter_parallel = "#linalg.iterator_type" + parallel = all( + str(it) == iter_parallel for it in op.iterator_types + ) + assert parallel, ( + "Only parallel iterators are supported in linalg.generic" + ) + outputs = op.outputs + assert len(outputs) == 1, "Expected only one output" + out_shape = outputs[0].type.shape + layers["elemwise"].append({"shape": out_shape}) case linalg.MatmulOp(): inputs = op.inputs outputs = op.outputs assert len(inputs) == 2 and len(outputs) == 1 - m, k = inputs[0].type.shape - _, n = inputs[1].type.shape - matmuls.append((m, n, k)) + input_is_transpose = [ + has_producer(o, linalg.TransposeOp) for o in inputs + ] + a_shape, b_shape = [d.type.shape for d in inputs] + c_shape = outputs[0].type.shape + assert len(c_shape) == 2 + assert len(a_shape) == 2 or len(b_shape) == 2 + m, n = c_shape + try: + _, k = a_shape + except Exception: + k, _ = b_shape + layers["matmul"].append( + { + "shape": (m, n, k), + "transpose_a": input_is_transpose[0], + "transpose_b": input_is_transpose[1], + } + ) + case linalg.BatchMatmulOp(): + inputs = op.inputs + outputs = op.outputs + assert len(inputs) == 2 and len(outputs) == 1 + input_is_transpose = [ + has_producer(o, linalg.TransposeOp) for o in inputs + ] + a_shape, b_shape = [d.type.shape for d in inputs] + c_shape = outputs[0].type.shape + assert len(c_shape) == 3 + assert len(a_shape) == 3 or len(b_shape) == 3 + b, m, n = c_shape + try: + _, _, k = a_shape + except Exception: + _, k, _ = b_shape + layers["batch_matmul"].append( + { + "shape": (b, m, n, k), + "transpose_a": input_is_transpose[0], + "transpose_b": input_is_transpose[1], + } + ) return ir.WalkResult.ADVANCE op.walk(match_linalg, ir.WalkOrder.PRE_ORDER) functions[op.sym_name.value] = { "inputs": op.type.inputs, "results": op.type.results, - "matmuls": matmuls, + "layers": layers, } return ir.WalkResult.ADVANCE