When I train a model in float16, I get the following error: `RuntimeError: Expected a.scalar_type() == torch::kBFloat16 to be true, but got false.` This seems to be related to the check at https://github.com/tgale96/grouped_gemm/blob/f4c08bc89e73b0b343e0815c680c4c9c9875302b/csrc/grouped_gemm.cu#L364. I look forward to your reply.