CUTLASS Grouped GEMM#6
Conversation
|
NVIDIA/cutlass#1286 should be available on H100 now as well |
|
Hi! Thanks for the PR! We have users who currently rely on the cuBLAS path for Hopper, which this PR deletes, I think.
Since this is now available, it'd be great to support for SM90! It looks like it requires a very new version of CUDA so perhaps it would be best to keep around the simple cuBLAS implementation to fallback to if we can't support CUTLASS grouped GEMM? |
|
is there any critical reason for why grouped gemm is hardcoded to use BFloat16? or would a string replace of bfloat16 with float16 just work? |
|
There is no reason why we only support BFloat16. I implemented only bfloat because that was what our user who needed this feature uses. It would be relatively easy to template our helpers and dispatch based on input tensor type. |
|
@tgale96 For context: I'm working on a branch that removes the CPU<->GPU sync for I also stumbled upon a nasty CUTLASS bug when one of the elements in |
|
Hey! It would be great to have a full CUTLASS path but I do not personally have the cycles for it at the moment. Contributions would be very welcome, and I'd be happy to provide any guidance that is necessary! |
|
Cool! I opened #14 as a starting point. Any guidance would be much appreciated! :) |
Use CUTLASS for grouped GEMM (both no transposition, trans_a, trans_b).
~20% speedup on A100