Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/xegpu/enumerate_matmul_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
"prefetch_b_n": 16,
"prefetch_a_nb": 1,
"prefetch_b_nb": 1,
"transpose_a": False,
"transpose_b": False,
}

# Check that at least one constraint was reified into the schedule.
Expand Down
100 changes: 68 additions & 32 deletions examples/xegpu/matmul.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# RUN: %PYTHON %s --sizes 512 1024 128 --dump-kernel=xegpu-wg | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --transpose-a | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --transpose-b | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --bias | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --relu | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --bias --relu | FileCheck %s
Expand Down Expand Up @@ -30,6 +32,7 @@
from lighthouse.schedule.xegpu import mlp_schedule, xegpu_to_binary
from lighthouse.utils.numpy import mlir_to_numpy_dtype
from lighthouse.ingress.mlir_gen import generate_gpu_matmul_payload, get_mlir_elem_type
from lighthouse.schedule.xegpu import XeGPUParameterSelector


def matmul_complexity(
Expand Down Expand Up @@ -77,6 +80,8 @@ class XeGPUMatMul:
K: int = 1024
ab_type: ir.Type | str | None = None
c_type: ir.Type | str | None = None
transpose_a: bool = False
transpose_b: bool = False
has_bias: bool = False
has_relu: bool = False
accumulate_c: bool = True
Expand All @@ -96,8 +101,8 @@ def __post_init__(self):
assert isinstance(self.c_type, ir.F32Type), "Only f32 type is supported for C"
self.ab_dtype = mlir_to_numpy_dtype(self.ab_type)
self.c_dtype = mlir_to_numpy_dtype(self.c_type)
self.a_shape = (self.M, self.K)
self.b_shape = (self.K, self.N)
self.a_shape = (self.M, self.K) if not self.transpose_a else (self.K, self.M)
self.b_shape = (self.K, self.N) if not self.transpose_b else (self.N, self.K)
self.c_shape = (self.M, self.N)
self.bias_shape = (self.N,)

Expand Down Expand Up @@ -142,6 +147,8 @@ def payload_module(self) -> ir.Module:
K=self.K,
ab_type=self.ab_type,
c_type=self.c_type,
transpose_a=self.transpose_a,
transpose_b=self.transpose_b,
has_bias=self.has_bias,
has_relu=self.has_relu,
accumulate_c=self.accumulate_c,
Expand Down Expand Up @@ -190,6 +197,11 @@ def check_results(
C, A, B = host_inputs[:3]
bias = host_inputs[3] if mmul.has_bias else None

if mmul.transpose_a:
A = A.T
if mmul.transpose_b:
B = B.T

# use float32 data type for efficiency
f32 = np.float32
D_ref = A.astype(f32) @ B.astype(f32)
Expand Down Expand Up @@ -233,6 +245,16 @@ def cli_parser(description):
default=[4096, 4096, 4096],
help="M,N,K matrix sizes (A=MxK, B=KxN, C=MxN).",
)
parser.add_argument(
"--transpose-a",
action="store_true",
help="Transpose matrix A (i.e., A is KxM) before multiplication.",
)
parser.add_argument(
"--transpose-b",
action="store_true",
help="Transpose matrix B (i.e., B is NxK) before multiplication.",
)
parser.add_argument(
"--bias",
action="store_true",
Expand Down Expand Up @@ -375,46 +397,54 @@ def parse_cli_args(description):

# Problem size
m, n, k = args.sizes if args.sizes else (4096, 4096, 4096)
transpose_a = args.transpose_a
transpose_b = args.transpose_b

# Set required parameters
params = {
"m": m,
"n": n,
"k": k,
"transpose_a": transpose_a,
"transpose_b": transpose_b,
}
if args.target:
params["device"] = args.target

if args.json:
# Override parameters with values from JSON file if provided
with open(args.json, "r") as f:
json_params = json.load(f)
params.update(json_params)

# Override parameters with CLI args if provided
# Collect parameters from CLI arguments
cli_params = {}
if args.wg_tile:
params["wg_m"], params["wg_n"] = args.wg_tile
cli_params["wg_m"], cli_params["wg_n"] = args.wg_tile
if args.sg_tile:
params["sg_m"], params["sg_n"] = args.sg_tile
cli_params["sg_m"], cli_params["sg_n"] = args.sg_tile
if args.k_tile:
params["k_tile"] = args.k_tile
cli_params["k_tile"] = args.k_tile
if args.load_tile_a:
params["load_a_m"], params["load_a_k"] = args.load_tile_a
cli_params["load_a_m"], cli_params["load_a_k"] = args.load_tile_a
if args.load_tile_b:
params["load_b_k"], params["load_b_n"] = args.load_tile_b
cli_params["load_b_k"], cli_params["load_b_n"] = args.load_tile_b
if args.prefetch_tile_a:
params["prefetch_a_m"], params["prefetch_a_k"] = args.prefetch_tile_a
cli_params["prefetch_a_m"], cli_params["prefetch_a_k"] = args.prefetch_tile_a
if args.prefetch_tile_b:
params["prefetch_b_k"], params["prefetch_b_n"] = args.prefetch_tile_b
cli_params["prefetch_b_k"], cli_params["prefetch_b_n"] = args.prefetch_tile_b
if args.prefetch_a_nb is not None:
params["prefetch_a_nb"] = args.prefetch_a_nb
cli_params["prefetch_a_nb"] = args.prefetch_a_nb
if args.prefetch_b_nb is not None:
params["prefetch_b_nb"] = args.prefetch_b_nb
cli_params["prefetch_b_nb"] = args.prefetch_b_nb

for k, v in params.items():
if v is None:
raise ValueError(
f"Parameter {k} is not set. Please provide it via CLI or JSON file."
)
# By default the tile size parameters are left undefined
if args.json:
# Override parameters with values from JSON file if provided
with open(args.json, "r") as f:
json_params = json.load(f)
params.update(json_params)
# Override with CLI params
params.update(cli_params)
elif cli_params:
# Get default parameters from selector
param_selector = XeGPUParameterSelector(device=args.target)
def_params = param_selector.get_parameters((m, n, k), transpose_a, transpose_b)
params.update(def_params)
# Override with CLI params
params.update(cli_params)

with ir.Context(), ir.Location.unknown():
lh_dialects.register_and_load()
Expand All @@ -423,6 +453,8 @@ def parse_cli_args(description):
M=params["m"],
N=params["n"],
K=params["k"],
transpose_a=params["transpose_a"],
transpose_b=params["transpose_b"],
has_bias=args.bias,
has_relu=args.relu,
accumulate_c=not args.no_accumulate_c,
Expand Down Expand Up @@ -481,14 +513,18 @@ def list2str(a):
c_type = str(wload.c_type)
print(
f"sizes={list2str([params['m'], params['n'], params['k']])} "
f"ta={int(params['transpose_a'])} "
f"tb={int(params['transpose_b'])} "
f"dt={ab_type},{c_type} "
f"wg-tile={list2str([params['wg_m'], params['wg_n']])} "
f"sg-tile={list2str([params['sg_m'], params['sg_n']])} "
f"k-tile={params['k_tile']} "
f"load-a-tile={list2str([params['load_a_m'], params['load_a_k']])} "
f"load-b-tile={list2str([params['load_b_k'], params['load_b_n']])} "
f"pf-a-tile={list2str([params['prefetch_a_m'], params['prefetch_a_k']])} "
f"pf-b-tile={list2str([params['prefetch_b_k'], params['prefetch_b_n']])} "
f"wg={list2str([params['wg_m'], params['wg_n']])} "
f"sg={list2str([params['sg_m'], params['sg_n']])} "
f"k={params['k_tile']} "
f"ld-a={list2str([params['load_a_m'], params['load_a_k']])} "
f"ld-b={list2str([params['load_b_k'], params['load_b_n']])} "
f"pf-a={list2str([params['prefetch_a_m'], params['prefetch_a_k']])} "
f"pf-b={list2str([params['prefetch_b_k'], params['prefetch_b_n']])} "
f"pf-a-nb={params['prefetch_a_nb']} "
f"pf-b-nb={params['prefetch_b_nb']} "
f"time(us): {elapsed:.2f} "
f"GFLOPS: {gflops:.2f}"
)
55 changes: 52 additions & 3 deletions examples/xegpu/mlp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --transpose-a | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --transpose-b | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --relu | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --bias | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --accumulate-c | FileCheck %s
Expand Down Expand Up @@ -46,6 +48,8 @@ def check_correctness(
ab_dtype: np.dtype,
has_bias: bool = False,
has_relu: bool = False,
transpose_a: bool = False,
transpose_b: bool = False,
verbose: int = 0,
) -> bool:
output_array, input_array, *rest = initial_host_arrays
Expand All @@ -63,7 +67,11 @@ def check_correctness(
biases = [b.astype(np.float32) for b in biases]

a_array = input_array
if transpose_a:
a_array = a_array.T
for i, W in enumerate(weights):
if transpose_b:
W = W.T
D_ref = a_array @ W
if has_bias:
D_ref += biases[i]
Expand Down Expand Up @@ -109,6 +117,8 @@ class XeGPUMLP:
hidden_layer_sizes: Optional[list[int]] = None
ab_type: ir.Type | str | None = None
acc_type: ir.Type | str | None = None
transpose_a: bool = False
transpose_b: bool = False
has_bias: bool = False
has_relu: bool = False
accumulate_c: bool = False
Expand All @@ -135,9 +145,13 @@ def __post_init__(self):
if self.hidden_layer_sizes is None:
self.hidden_layer_sizes = []
self.input_shape = (self.batch_size, self.input_size)
if self.transpose_a:
self.input_shape = self.input_shape[::-1]
self.output_shape = (self.batch_size, self.output_size)
layer_sizes = [self.input_size] + self.hidden_layer_sizes + [self.output_size]
self.weight_shapes = list(zip(layer_sizes[:-1], layer_sizes[1:]))
if self.transpose_b:
self.weight_shapes = [shape[::-1] for shape in self.weight_shapes]
self.matmul_layers = [(self.batch_size, o, i) for i, o in self.weight_shapes]
self.bias_shapes = [(o,) for o in layer_sizes[1:]] if self.has_bias else []

Expand All @@ -154,10 +168,16 @@ def gen_random(shape, dtype):
return np.random.rand(*shape).astype(dtype)

def gen_identity(shape, dtype):
# identity matrix, if cols > rows wrap to fill all columns
# identity matrix; the diagonal may be wrapped below for non-square shapes
a = np.zeros(shape, dtype=dtype)
np.fill_diagonal(a, 1)
if shape[1] > shape[0]:
if self.transpose_b:
if shape[0] > shape[1]:
# if rows > cols wrap to fill all rows
second_block = a[shape[1] :, :]
np.fill_diagonal(second_block, 1)
elif shape[1] > shape[0]:
# if cols > rows wrap to fill all columns
second_block = a[:, shape[0] :]
np.fill_diagonal(second_block, 1)
return a
Expand Down Expand Up @@ -208,6 +228,8 @@ def payload_module(self) -> ir.Module:
acc_type=self.acc_type,
bias_type=self.ab_type,
result_type=self.ab_type,
transpose_a=self.transpose_a,
transpose_b=self.transpose_b,
has_bias=self.has_bias,
has_relu=self.has_relu,
accumulate_c=self.accumulate_c,
Expand Down Expand Up @@ -299,6 +321,16 @@ def parse_cli():
action="store_true",
help="Add ReLU activation function to each layer except the output layer.",
)
parser.add_argument(
"--transpose-a",
action="store_true",
help="Transpose the input matrix A in the first matmul layer.",
)
parser.add_argument(
"--transpose-b",
action="store_true",
help="Transpose the weight matrices B in all matmul layers.",
)
parser.add_argument(
"--accumulate-c",
action="store_true",
Expand Down Expand Up @@ -358,13 +390,17 @@ def parse_cli():
with ir.Context(), ir.Location.unknown():
lh_dialects.register_and_load()

tr_a = args.transpose_a
tr_b = args.transpose_b
wload = XeGPUMLP(
batch_size=args.batch_size,
input_size=args.input_size,
output_size=args.output_size,
hidden_layer_sizes=args.hidden_sizes,
has_bias=args.bias,
has_relu=args.relu,
transpose_a=tr_a,
transpose_b=tr_b,
accumulate_c=args.accumulate_c,
identity_weights=identity_weights,
)
Expand All @@ -376,7 +412,16 @@ def parse_cli():
acc_type = wload.acc_type

# Initialize layer parameters
params = [{"m": M, "n": N, "k": K} for M, N, K in matmuls]
params = []
for i, (M, N, K) in enumerate(matmuls):
layer_params = {
"m": M,
"n": N,
"k": K,
"transpose_a": tr_a if i == 0 else False,
"transpose_b": tr_b,
}
params.append(layer_params)
if args.target:
for layer_params in params:
layer_params["device"] = args.target
Expand Down Expand Up @@ -422,6 +467,8 @@ def parse_cli():
wload.ab_dtype,
has_bias=wload.has_bias,
has_relu=wload.has_relu,
transpose_a=tr_a,
transpose_b=tr_b,
verbose=args.verbose,
)
if not success:
Expand All @@ -447,6 +494,8 @@ def list2str(a):
f"i={args.input_size} "
f"o={args.output_size} "
f"hs={list2str(hidden_sizes)} "
f"ta={int(tr_a)} "
f"tb={int(tr_b)} "
f"dt={ab_type},{acc_type} "
f"time(us): {elapsed:.2f} "
f"GFLOPS: {gflops:.2f}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def is_tileable_op(op: ir.Operation) -> bool:
linalg.MaxOp,
linalg.MinOp,
linalg.FillOp,
linalg.MatmulOp,
linalg.GenericOp,
]
return isinstance(op.opview, tuple(linalg_ops))
Expand Down
4 changes: 4 additions & 0 deletions lighthouse/ingress/mlir_gen/gpu_matmul_payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def generate_gpu_matmul_payload(
K: int,
ab_type: ir.Type,
c_type: ir.Type,
transpose_a: bool,
transpose_b: bool,
has_bias: bool,
has_relu: bool,
accumulate_c: bool,
Expand All @@ -24,6 +26,8 @@ def generate_gpu_matmul_payload(
acc_type=c_type,
bias_type=c_type,
result_type=c_type,
transpose_a=transpose_a,
transpose_b=transpose_b,
has_bias=has_bias,
has_relu=has_relu,
accumulate_c=accumulate_c,
Expand Down
Loading
Loading