diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000000..3087832fa4 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,24 @@ +# IMPORTANT: +# This file is ONLY used to subscribe for notifications for PRs +# related to a specific file path. Approvals from people in this +# file are not required for merges. + +# C API +/transformer_engine/common/include/ @ptrendx + +# TE/JAX +/transformer_engine/jax/ @jberchtold-nvidia + +# TE/PyTorch +/transformer_engine/pytorch/ @ksivaman + +# te.ops API +/transformer_engine/pytorch/ops/ @timmoon10 + +# Quantization kernels +/transformer_engine/common/cast/ @Oleg-Goncharov + +# Attention +/transformer_engine/pytorch/attention/ @cyanguwa +/transformer_engine/common/fused_attn/ @cyanguwa +/transformer_engine/jax/cpp_extensions/attention.py @KshitijLakhani diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 4b43940a51..504d761bb1 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include "../common.h" @@ -202,6 +203,49 @@ const std::string &include_directory(bool required) { return path; } +int include_directory_version(bool required) { + // Header path + const auto &include_dir = cuda::include_directory(false); + if (include_dir.empty()) { + if (required) { + NVTE_ERROR( + "Could not detect version of CUDA Toolkit headers " + "(CUDA Toolkit headers not found)."); + } + return -1; + } + + // Parse CUDART_VERSION from cuda_runtime_api.h. + const auto header_path = std::filesystem::path(include_dir) / "cuda_runtime_api.h"; + std::ifstream header_file(header_path); + if (header_file.is_open()) { + const std::string define_prefix = "#define CUDART_VERSION "; + std::string line; + while (std::getline(header_file, line)) { + const auto pos = line.find(define_prefix); + if (pos == std::string::npos) { + continue; + } + try { + const int version = std::stoi(line.substr(pos + define_prefix.size())); + if (version > 0) { + return version; + } + } catch (...) { + continue; + } + } + } + + if (required) { + NVTE_ERROR( + "Could not detect version of CUDA Toolkit headers " + "(Could not parse CUDART_VERSION from ", + header_path.string(), ")."); + } + return -1; +} + int cudart_version() { auto get_version = []() -> int { int version; diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index f0aa239622..0f35594001 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -67,6 +67,21 @@ bool supports_multicast(int device_id = -1); */ const std::string &include_directory(bool required = false); +/* \brief Version number of CUDA Toolkit headers + * + * The headers are accessed at run-time and its CUDA version may + * differ from compile-time and from the CUDA Runtime. The header path + * can be configured by setting NVTE_CUDA_INCLUDE_DIR in the + * environment (default is to search in common install paths). + * + * \param[in] required Whether to throw exception if headers are not + * found or if version cannot be determined. + * + * \return CUDA version encoded as major * 1000 + minor * 10, or -1 if + * it could not be determined. + */ +int include_directory_version(bool required = false); + /* \brief CUDA Runtime version number at run-time * * Versions may differ between compile-time and run-time. diff --git a/transformer_engine/common/util/rtc.cpp b/transformer_engine/common/util/rtc.cpp index 7925fdceea..70024a202c 100644 --- a/transformer_engine/common/util/rtc.cpp +++ b/transformer_engine/common/util/rtc.cpp @@ -12,6 +12,7 @@ #include "../common.h" #include "../util/cuda_driver.h" +#include "../util/cuda_runtime.h" #include "../util/string.h" #include "../util/system.h" @@ -175,14 +176,46 @@ void KernelManager::compile(const std::string& kernel_label, const std::string& const nvrtcResult compile_result = nvrtcCompileProgram(program, opts_ptrs.size(), opts_ptrs.data()); if (compile_result != NVRTC_SUCCESS) { - // Display log if compilation failed - std::string log = concat_strings("NVRTC compilation log for ", filename, ":\n"); + std::string log; + + // Decode CUDA version number to "major.minor" string + auto version_string = [](int v) -> std::string { + if (v < 0) { + return ""; + } + return concat_strings(v / 1000, ".", (v % 1000) / 10); + }; + + // Check CUDA versions + const int build_version = CUDA_VERSION; + int nvrtc_version = -1; + int nvrtc_version_major = 0, nvrtc_version_minor = 0; + if (nvrtcVersion(&nvrtc_version_major, &nvrtc_version_minor) == NVRTC_SUCCESS) { + nvrtc_version = nvrtc_version_major * 1000 + nvrtc_version_minor * 10; + } + const int header_version = cuda::include_directory_version(); + log += concat_strings("Compile-time CUDA version: ", version_string(build_version), "\n", + "Run-time NVRTC version: ", version_string(nvrtc_version), "\n", + "Run-time CUDA headers version: ", version_string(header_version), "\n"); + if (nvrtc_version != header_version) { + log += concat_strings( + "\nWarning: CUDA versions do not match between NVRTC and CUDA headers (", + cuda::include_directory(), + "). " + "Consider changing the CUDA header search path (by setting NVTE_CUDA_INCLUDE_DIR) " + "or the linked CUDA Runtime (by setting CUDA_HOME or LD_LIBRARY_PATH).\n\n"); + } + + // Get build log + log += concat_strings("NVRTC compilation log for ", filename, ":\n"); const size_t log_offset = log.size(); size_t log_size; NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size)); log.resize(log_offset + log_size); NVTE_CHECK_NVRTC(nvrtcGetProgramLog(program, &log[log_offset])); log.back() = '\n'; + + // Display log and throw error std::cerr << log; NVTE_CHECK_NVRTC(compile_result); } diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 35a459351b..94350da1e6 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -102,8 +102,9 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) const = 0; /*! @brief Construct a tensor with uninitialized data */ - virtual std::pair create_tensor(const std::vector& shape, - DType dtype) const = 0; + virtual std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional device = std::nullopt, bool pin_memory = false) const = 0; /*! @brief Construct a grouped tensor with uninitialized data */ virtual std::pair create_grouped_tensor( @@ -144,8 +145,9 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional device = std::nullopt, bool pin_memory = false) const override; std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, @@ -174,8 +176,9 @@ class Float8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional device = std::nullopt, bool pin_memory = false) const override; std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, @@ -183,10 +186,10 @@ class Float8Quantizer : public Quantizer { size_t logical_last_dim) const override; /*! @brief Construct a tensor with pre-initialized data */ - std::pair create_tensor(const std::vector& shape, DType dtype, - std::optional data, - std::optional transpose, - std::optional scale_inv) const; + std::pair create_tensor( + const std::vector& shape, DType dtype, std::optional data, + std::optional transpose, std::optional scale_inv, + std::optional device = std::nullopt, bool pin_memory = false) const; std::pair convert_and_update_tensor(py::object shape) const override; @@ -208,8 +211,9 @@ class Float8CurrentScalingQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional device = std::nullopt, bool pin_memory = false) const override; std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, @@ -270,8 +274,9 @@ class Float8BlockQuantizer : public Quantizer { // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional device = std::nullopt, bool pin_memory = false) const override; std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, @@ -294,8 +299,9 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional device = std::nullopt, bool pin_memory = false) const override; std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, @@ -333,8 +339,9 @@ class NVFP4Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional device = std::nullopt, bool pin_memory = false) const override; std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9b10a9c5a4..8082ff07ed 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -320,9 +320,12 @@ std::vector bulk_allocate(const std::vector> &sh std::optional> alignments = std::nullopt); /*************************************************************************************************** - * Cast + * Quantize **************************************************************************************************/ +py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector &shape, + at::ScalarType dtype, at::Device device, bool pin_memory); + py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, std::optional noop_flag); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 3ada2459c8..2b38339d67 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -65,6 +65,14 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob return output_py; } +py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector &shape, + at::ScalarType dtype, at::Device device, bool pin_memory) { + auto quantizer_cpp = convert_quantizer(quantizer); + auto te_dtype = GetTransformerEngineDType(dtype); + auto [_, output_py] = quantizer_cpp->create_tensor(shape, te_dtype, device, pin_memory); + return output_py; +} + namespace { // helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a813f3119d..a4571c64e2 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -139,6 +139,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output") = py::none(), py::arg("noop") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); + m.def("create_empty_quantized_tensor", + &transformer_engine::pytorch::create_empty_quantized_tensor, + "Create an empty quantized tensor", py::arg("quantizer"), py::arg("shape"), + py::arg("dtype"), py::arg("device"), py::arg("pin_memory")); m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); m.def("group_dequantize", transformer_engine::pytorch::group_dequantize, diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 82dfe4d222..7045995dd7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -16,6 +16,29 @@ namespace transformer_engine::pytorch { namespace { +/*! @brief Resolve an optional device to a concrete CUDA device + * + * If no device is provided, uses the current CUDA device. + */ +at::Device resolve_device(std::optional device, + const std::optional& data = std::nullopt) { + if (device.has_value() && data.has_value()) { + // Ensure that they are the same + const auto provided_device = *device; + const auto data_device = data->device(); + NVTE_CHECK(provided_device == data_device, + "Provided device and the device of the provided data tensor are not the same."); + return provided_device; + } + if (device.has_value()) { + return *device; + } + if (data.has_value()) { + return data->device(); + } + return at::Device(torch::kCUDA, c10::cuda::current_device()); +} + /*! @brief Transposed tensor shape * * The tensor is interpreted as a 2D matrix by flattening all but the @@ -129,10 +152,13 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti this->dtype = type; } -std::pair NoneQuantizer::create_tensor(const std::vector& shape, - DType dtype) const { +std::pair NoneQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional device_opt, + bool pin_memory) const { + const auto device = resolve_device(device_opt); const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); + const auto opts = + at::TensorOptions().dtype(GetATenDType(dtype)).device(device).pinned_memory(pin_memory); return create_tensor(shape, dtype, at::empty(shape_int64, opts)); } @@ -240,22 +266,29 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { } std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype) const { - const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + const std::vector& shape, DType dtype, std::optional device_opt, + bool pin_memory) const { + const auto device = resolve_device(device_opt); + const auto opts = + at::TensorOptions().dtype(torch::kFloat32).device(device).pinned_memory(pin_memory); at::Tensor scale_inv = at::empty(std::vector{1}, opts); - return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); + return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv), device, + pin_memory); } std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional data, - std::optional transpose, std::optional scale_inv) const { + std::optional transpose, std::optional scale_inv, + std::optional device_opt, bool pin_memory) const { + const auto device = resolve_device(device_opt, data); using namespace pybind11::literals; int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Initialize data tensor const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data && !data) { const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto opts = + at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory); data = at::empty(shape_int64, opts); } else if (!with_data && data) { data.reset(); @@ -266,7 +299,8 @@ std::pair Float8Quantizer::create_tensor( const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose && !transpose) { const auto transpose_shape = make_transpose_shape(shape); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto opts = + at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory); transpose = at::empty(transpose_shape, opts); } else if (!with_transpose && transpose) { transpose.reset(); @@ -277,10 +311,6 @@ std::pair Float8Quantizer::create_tensor( scale_inv = at::reciprocal(scale); } py::object scale_inv_py = py::cast(*scale_inv); - at::Device device = - with_data ? data->device() - : (with_transpose ? transpose->device() - : at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; if (internal) { @@ -555,7 +585,9 @@ Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& q void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype) const { + const std::vector& shape, DType dtype, std::optional device_opt, + bool pin_memory) const { + const auto device = resolve_device(device_opt); using namespace pybind11::literals; // Initialize data tensor @@ -564,7 +596,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto opts = + at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory); data_tensor = at::empty(shape_int64, opts); } @@ -573,20 +606,18 @@ std::pair Float8CurrentScalingQuantizer::create_tenso const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose) { const auto transpose_shape = make_transpose_shape(shape); - const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto opts = + at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory); transpose_tensor = at::empty(transpose_shape, opts); } // Initialize scale-inverse tensor at::Tensor scale_inv_tensor; { const std::vector scale_inv_shape = {1}; - const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + const auto opts = + at::TensorOptions().dtype(torch::kFloat32).device(device).pinned_memory(pin_memory); scale_inv_tensor = at::empty(scale_inv_shape, opts); } - at::Device device = - with_data ? data_tensor.device() - : (with_transpose ? transpose_tensor.device() - : at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; py::object scale_inv_py = py::cast(scale_inv_tensor); @@ -924,7 +955,9 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair Float8BlockQuantizer::create_tensor( - const std::vector& shape, DType dtype) const { + const std::vector& shape, DType dtype, std::optional device_opt, + bool pin_memory) const { + const auto device = resolve_device(device_opt); using namespace pybind11::literals; std::vector torch_shape; for (auto s : shape) { @@ -935,8 +968,8 @@ std::pair Float8BlockQuantizer::create_tensor( at::TensorOptions opts; at::TensorOptions scale_opts; at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); + opts = opts.dtype(torch::kUInt8).device(device).pinned_memory(pin_memory); + scale_opts = scale_opts.dtype(torch::kFloat32).device(device).pinned_memory(pin_memory); if (rowwise_usage) { data_rowwise = at::empty(torch_shape, opts); @@ -1015,6 +1048,7 @@ std::pair Float8BlockQuantizer::create_tensor( kwargs["fp8_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2); + kwargs["device"] = py::cast(device); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), @@ -1334,8 +1368,10 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} -std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, - DType dtype) const { +std::pair MXFP8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional device_opt, + bool pin_memory) const { + const auto device = resolve_device(device_opt); using namespace pybind11::literals; // Scaling factor format @@ -1353,7 +1389,8 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Allocate tensors at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor; at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor; - const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto uint8_tensor_opts = + at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory); if (rowwise_usage) { const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), rowwise_scale_inv_shape.end()); @@ -1413,6 +1450,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve kwargs["fp8_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["device"] = py::cast(device); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorPythonClass), @@ -1722,8 +1760,10 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { columnwise_data.shape); } -std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, - DType dtype) const { +std::pair NVFP4Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional device_opt, + bool pin_memory) const { + const auto device = resolve_device(device_opt); using namespace pybind11::literals; // Scaling factor format @@ -1749,8 +1789,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // Allocate tensors at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise; at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise; - const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + const auto bit8_tensor_opts = + at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory); + const auto bit32_tensor_opts = + at::TensorOptions().dtype(torch::kFloat32).device(device).pinned_memory(pin_memory); if (rowwise_usage) { const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), rowwise_scale_inv_shape.end()); @@ -1831,6 +1873,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve kwargs["fp4_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["device"] = py::cast(device); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index a7722f777e..7163e2b172 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -13,6 +13,8 @@ import torch from torch.utils._pytree import tree_map +import transformer_engine_torch as tex + from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch.tensor._quantization_helpers import ( _QuantizeFunc, @@ -311,13 +313,34 @@ def make_empty( shape: Iterable[int], *, dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, + device: Optional[Union[torch.device, str]] = None, + requires_grad: bool = False, + pin_memory: bool = False, ) -> QuantizedTensor: """Construct quantized tensor with uninitialized data""" - raise NotImplementedError( - f"{self.__class__.__name__} class does not implement make_empty function, " - "required for construction of unintialized quantized tensor" + + # Guard for custom quantizers that don't have a registered C++ converter. + # Without this, they would hit an opaque C++ NVTE_ERROR. + if getattr(self, "custom", False): + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement make_empty function, " + "required for construction of uninitialized quantized tensor" + ) + + if device is None: + device = torch.device("cuda") + # Handle the device passed as string + device = torch.device(device) + result = tex.create_empty_quantized_tensor( + self, + list(shape), + dtype, + device, + pin_memory, ) + if requires_grad: + result.requires_grad_(True) + return result def calibrate(self, tensor: torch.Tensor) -> None: """Calibrate quantizer state diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 914397b9b6..d0296902a9 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -202,62 +202,6 @@ def is_quantizable(self, inp: torch.Tensor) -> bool: return False return True - def make_empty( - self, - shape: Iterable[int], - *, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, - requires_grad: bool = False, - pin_memory: bool = False, - ) -> Float8BlockwiseQTensor: - """Construct quantized tensor with uninitialized data""" - - tensor_kwargs = { - "device": torch.device("cuda") if device is None else device, - "pin_memory": pin_memory, - } - - # Allocate buffers for row-scaled data - rowwise_data = None - rowwise_scale_inv = None - if self.rowwise_usage: - rowwise_data = torch.empty(shape, dtype=torch.uint8, **tensor_kwargs) - rowwise_scale_inv = torch.empty( - self.get_scale_shape(shape, columnwise=False), - dtype=torch.float32, - **tensor_kwargs, - ) - - # Allocate buffers for column-scaled data - columnwise_data = None - columnwise_scale_inv = None - if self.columnwise_usage: - columnwise_data = torch.empty( - self.get_columnwise_shape(shape), - dtype=torch.uint8, - **tensor_kwargs, - ) - columnwise_scale_inv = torch.empty( - self.get_scale_shape(shape, columnwise=True), - dtype=torch.float32, - **tensor_kwargs, - ) - - # Construct FP8 tensor - return Float8BlockwiseQTensor( - shape=shape, - dtype=dtype, - fp8_dtype=self.dtype, - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - quantizer=self, - is_2D_scaled=self.block_scaling_dim == 2, - requires_grad=requires_grad, - ) - def calibrate(self, tensor: torch.Tensor) -> None: # NOTE: This interface is specific to requirements like delayed scaling # where state from an estimator influences distribution parameters. diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index ed6091c85b..c4c5934f97 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -112,49 +112,6 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" return tex.quantize(tensor, self) - def make_empty( - self, - shape: Iterable[int], - *, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, - requires_grad: bool = False, - pin_memory: bool = False, - ) -> Float8Tensor: - - # Canonicalize tensor attributes - if device is None: - device = torch.device("cuda") - - # Allocate FP8 data - data = None - if self.rowwise_usage: - data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) - - # Allocate FP8 data transpose if needed - data_transpose = None - if self.columnwise_usage: - transpose_shape = [shape[-1]] + list(shape[:-1]) - data_transpose = torch.empty( - transpose_shape, - dtype=torch.uint8, - device=device, - pin_memory=pin_memory, - ) - - # Construct FP8 tensor - return Float8Tensor( - shape=shape, - dtype=dtype, - data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory), - fp8_dtype=self.dtype, - requires_grad=requires_grad, - data_transpose=data_transpose, - quantizer=self, - device=device, - ) - def calibrate(self, tensor: torch.Tensor) -> None: amin, amax = tensor.aminmax() self.amax.copy_(torch.max(-amin, amax)) @@ -337,48 +294,6 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" return tex.quantize(tensor, self) - def make_empty( - self, - shape: Iterable[int], - *, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, - requires_grad: bool = False, - pin_memory: bool = False, - ) -> Float8Tensor: - - # Canonicalize tensor attributes - if device is None: - device = torch.device("cuda") - - # Allocate FP8 data - data = None - if self.rowwise_usage: - data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) - - # Allocate FP8 data transpose if needed - data_transpose = None - if self.columnwise_usage: - transpose_shape = [shape[-1]] + list(shape[:-1]) - data_transpose = torch.empty( - transpose_shape, - dtype=torch.uint8, - device=device, - pin_memory=pin_memory, - ) - # Construct FP8 tensor - return Float8Tensor( - shape=shape, - dtype=dtype, - data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory), - fp8_dtype=self.dtype, - requires_grad=requires_grad, - data_transpose=data_transpose, - quantizer=self, - device=device, - ) - def calibrate(self, tensor: torch.Tensor) -> None: # current scaling don't need to calibrate return diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5cab519c79..134f8b5a61 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -96,70 +96,6 @@ def is_quantizable(self, inp: torch.Tensor) -> bool: return False return True - def make_empty( - self, - shape: Iterable[int], - *, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, - requires_grad: bool = False, - pin_memory: bool = False, - ) -> MXFP8Tensor: - - # Canonicalize tensor attributes - if device is None: - device = torch.device("cuda") - - assert ( - shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0 - and math.prod(shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE == 0 - ), ( - f"Incorrect shape {shape} for MXFP8. Tensor dims must be divisible by" - f" {MXFP8_BLOCK_SCALING_SIZE}" - ) - - # Allocate FP8 data - data = None - scale_inv = None - if self.rowwise_usage: - data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) - scale_inv = torch.empty( - round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), - dtype=torch.uint8, - device=device, - pin_memory=pin_memory, - ) - - # Allocate FP8 data transpose if needed - columnwise_data = None - columnwise_scale_inv = None - if self.columnwise_usage: - columnwise_data = torch.empty( - shape, dtype=torch.uint8, device=device, pin_memory=pin_memory - ) - columnwise_scale_inv = torch.empty( - round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), - round_up_to_nearest_multiple(shape[-1], 128), - dtype=torch.uint8, - device=device, - pin_memory=pin_memory, - ) - - # Construct FP8 tensor - return MXFP8Tensor( - shape=shape, - dtype=dtype, - fp8_dtype=self.dtype, - rowwise_data=data, - rowwise_scale_inv=scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - quantizer=self, - requires_grad=requires_grad, - with_gemm_swizzled_scales=self.optimize_for_gemm, - ) - def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 285a7f030a..df7a2b4bd3 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -297,99 +297,6 @@ def convert_shape_for_fp4(shape: Iterable[int]) -> Tuple[int, ...]: shape[-1] = shape[-1] // 2 return tuple(shape) - def make_empty( - self, - shape: Iterable[int], - *, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, - pin_memory: bool = False, - requires_grad: bool = False, - ) -> NVFP4Tensor: - - # Canonicalize tensor attributes - if device is None: - device = torch.device("cuda") - - assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, ( - f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" - f" {NVFP4_BLOCK_SCALING_SIZE}" - ) - - flat_first_dim = math.prod(shape[:-1]) - assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, ( - f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" - f" {NVFP4_BLOCK_SCALING_SIZE}" - ) - if self.row_scaled_nvfp4: - if not self.rowwise_usage: - raise ValueError("Row-scaled NVFP4 quantization requires rowwise usage.") - if self.columnwise_usage: - raise ValueError("Row-scaled NVFP4 quantization does not support columnwise usage.") - - # Allocate FP4 data - data = None - scale_inv = None - amax_rowwise = None - if self.rowwise_usage: - data = torch.empty( - self.convert_shape_for_fp4(shape), - dtype=torch.uint8, - device=device, - pin_memory=pin_memory, - ) - scale_shape = self.get_scale_shape(shape, columnwise=False) - scale_inv = torch.empty( - scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory - ) - # Allocate global amax metadata. Row-scaled NVFP4 stores one value per row. - amax_rows = flat_first_dim if self.row_scaled_nvfp4 else 1 - amax_rowwise = torch.zeros( - amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory - ) - - # Allocate FP8 data transpose if needed - columnwise_data = None - columnwise_scale_inv = None - amax_columnwise = None - if self.columnwise_usage: - # enforce 2D shape to avoid [S, B, H] shape and B and be 1 - # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero - shape_2d = tuple([flat_first_dim, shape[-1]]) - columnwise_data = torch.empty( - self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)), - dtype=torch.uint8, - device=device, - pin_memory=pin_memory, - ) - columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) - columnwise_scale_inv = torch.empty( - columnwise_scale_shape, - dtype=torch.uint8, - device=device, - pin_memory=pin_memory, - ) - amax_columnwise = torch.zeros( - 1, dtype=torch.float32, device=device, pin_memory=pin_memory - ) - - # Construct FP8 tensor - return NVFP4Tensor( - shape=shape, - dtype=dtype, - rowwise_data=data, - rowwise_scale_inv=scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - amax_rowwise=amax_rowwise, - amax_columnwise=amax_columnwise, - fp4_dtype=self.dtype, - quantizer=self, - requires_grad=requires_grad, - with_gemm_swizzled_scales=False, - row_scaled_nvfp4=self.row_scaled_nvfp4, - ) - def calibrate(self, tensor: torch.Tensor) -> None: pass # Calibration is no-op