[xegpu] Add transpose A/B support in mlp schedule#184
Conversation
87f27d8 to
adc7607
Compare
We need to split kernels between matmul ops. Could be added as an option.
adc7607 to
b8aadcb
Compare
b8aadcb to
1ebf952
Compare
| _, n = inputs[1].type.shape | ||
| matmuls.append((m, n, k)) | ||
| input_is_transpose = [ | ||
| has_producer(o, linalg.TransposeOp) for o in inputs |
There was a problem hiding this comment.
Wouldn't it be more generic to check matmul's indexing maps?
There was a problem hiding this comment.
Hmm, yeah it depends on how the IR is formulated. In KernelBench we get IR with explicit linag.transpose ops:
%transposed = linalg.transpose ins(%0 : ...) outs(%2 : tensor<2048x8192xf16>) permutation = [1, 0]
%5 = linalg.matmul ins(%transposed, %1 : tensor<2048x8192xf16>, ...) -> tensor<2048x4096xf32>which this approach detects. Matmul with transposed indexing map would be another variant, that pattern is currently not supported.
There was a problem hiding this comment.
I see, so the imported IR always has explicit transpose op before matmul.
Then maybe I'd rename the metadata field to be more specific to indicate that A/B has transpose producer and not implicit transpose through indexing maps.
Knowing these are separate operation vs performing C += A^T * B can be a meaningful difference.
There was a problem hiding this comment.
Yeah, on vector level however the explicit transpose op will be gone (using the current xegpu lowering): there's one vector.contract with transposed index map:
#map = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
%8 = vector.transfer_read %1[%arg6, %arg3], %0 {in_bounds = [true, true]} : tensor<8192x2048xf16>, vector<32x256xf16>
%9 = vector.transfer_read %2[%arg6, %arg4], %0 {in_bounds = [true, true]} : tensor<8192x4096xf16>, vector<32x256xf16>
%10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg7 : vector<32x256xf16>, vector<32x256xf16> into vector<256x256xf32>As such the two alternatives, either explicit linalg.transpose producer or linalg.matmul with transposed indexing map, would AFAIK be identical here.
For xegpu lowering we essentially need to know if there's a transpose op in the A/B tile producer chain. As such, I'm not sure if we need to differentiate the two variants in the metadata. If such differentiation becomes necessary at some point we can add it then.
There was a problem hiding this comment.
I'm approaching it from a perspective of payload inspector as a standalone tool.
I agree that in this case it makes no difference for the current use case.
So, I'm not pushing to support the indexing map check too. Just removing ambiguity from the returned metadata should be enough to avoid future confusion when somebody uses the inspector without Xe pipeline assumptions.
There was a problem hiding this comment.
I'm approaching it from a perspective of payload inspector as a standalone tool.
I agree that in this case it makes no difference for the current use case.
Yes, understand. I would put my "test driven development" hat on and keep things simple until we need to differentiate. We don't know if there will be use cases where differentiating the transpose variants matters. We also don't know if payload inspector can be generalized - xegpu case might need different kind of analysis than some other use case.
There was a problem hiding this comment.
Then I'd consider moving this whole logic into some Xe-specific utils.
Anyway, not a blocker right now.
There was a problem hiding this comment.
Yes, the layer matching could be moved to xegpu land. payload_inspector has one other use case examples/mpi/feed-forward-mpi.py but it only uses the function arg shapes.
inspect_payloadreturns metadata in generic"layers"nested dict. Currently returns metadata for matmul, batch_matmul and elemwise layers.generate_configshas a verbose flag which prints tile selection info.get_tileable_consumersreturns only matmul epilog, i.e., excludes next linalg.matmul op.--transpose-a/boptions.