diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index acbe4753b..f3bf2a804 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -132,6 +132,14 @@ void multi_tensor_quantize_impl(const std::vector &input_list, } } + // Check if all quantizers are MXFP8 (for multi-stream quantize) + bool with_mxfp8_multi_stream = false; + if (!with_fused_kernel) { + with_mxfp8_multi_stream = std::all_of( + quantizer_py_list.begin(), quantizer_py_list.end(), + [](const py::handle &q) { return detail::IsMXFP8Quantizers(q.ptr()); }); + } + // Launch TE kernel if (with_fused_kernel) { // Fused kernel for multi-tensor quantize @@ -145,6 +153,20 @@ void multi_tensor_quantize_impl(const std::vector &input_list, nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); }); + } else if (with_mxfp8_multi_stream) { + // Multi-stream quantize for MXFP8 + std::vector nvte_tensor_input_list; + std::vector nvte_tensor_output_list; + for (size_t i = 0; i < num_tensors; ++i) { + nvte_tensor_input_list.push_back(input_list[i].data()); + nvte_tensor_output_list.push_back(output_list[i].data()); + } + QuantizationConfigWrapper quant_config; + NVTE_SCOPED_GIL_RELEASE({ + nvte_multi_tensor_quantize(nvte_tensor_input_list.data(), nvte_tensor_output_list.data(), + quant_config, num_tensors, + at::cuda::getCurrentCUDAStream()); + }); } else { // Quantize kernels individually for (size_t i = 0; i < num_tensors; ++i) {