Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
}
}

// Check if all quantizers are MXFP8 (for multi-stream quantize)
bool with_mxfp8_multi_stream = false;
if (!with_fused_kernel) {
with_mxfp8_multi_stream = std::all_of(
quantizer_py_list.begin(), quantizer_py_list.end(),
[](const py::handle &q) { return detail::IsMXFP8Quantizers(q.ptr()); });
}

// Launch TE kernel
if (with_fused_kernel) {
// Fused kernel for multi-tensor quantize
Expand All @@ -145,6 +153,20 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
});
} else if (with_mxfp8_multi_stream) {
// Multi-stream quantize for MXFP8
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
for (size_t i = 0; i < num_tensors; ++i) {
nvte_tensor_input_list.push_back(input_list[i].data());
nvte_tensor_output_list.push_back(output_list[i].data());
}
QuantizationConfigWrapper quant_config;
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_tensor_quantize(nvte_tensor_input_list.data(), nvte_tensor_output_list.data(),
quant_config, num_tensors,
at::cuda::getCurrentCUDAStream());
});
} else {
// Quantize kernels individually
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down
Loading