diff --git a/docs/Doxyfile b/docs/Doxyfile index f17ffc297b..2c593e4594 100644 --- a/docs/Doxyfile +++ b/docs/Doxyfile @@ -93,14 +93,6 @@ ALLOW_UNICODE_NAMES = NO OUTPUT_LANGUAGE = English -# The OUTPUT_TEXT_DIRECTION tag is used to specify the direction in which all -# documentation generated by doxygen is written. Doxygen will use this -# information to generate all generated output in the proper direction. -# Possible values are: None, LTR, RTL and Context. -# The default value is: None. - -OUTPUT_TEXT_DIRECTION = None - # If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member # descriptions after the members that are listed in the file and class # documentation (similar to Javadoc). Set to NO to disable this. @@ -263,12 +255,6 @@ TAB_SIZE = 2 ALIASES = -# This tag can be used to specify a number of word-keyword mappings (TCL only). -# A mapping has the form "name=value". For example adding "class=itcl::class" -# will allow you to use the command class in the itcl::class meaning. - -TCL_SUBST = - # Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources # only. Doxygen will then generate output that is more tailored for C. For # instance, some of the names that are used will be different. The list of all @@ -1156,13 +1142,6 @@ CLANG_DATABASE_PATH = ALPHABETICAL_INDEX = YES -# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in -# which the alphabetical index list will be split. -# Minimum value: 1, maximum value: 20, default value: 5. -# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. - -COLS_IN_ALPHA_INDEX = 5 - # In case all classes in a project start with a common prefix, all classes will # be put under the same header in the alphabetical index. The IGNORE_PREFIX tag # can be used to specify a prefix (or a list of prefixes) that should be ignored @@ -1290,15 +1269,6 @@ HTML_COLORSTYLE_SAT = 100 HTML_COLORSTYLE_GAMMA = 80 -# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML -# page will contain the date and time when the page was generated. Setting this -# to YES can help to show when doxygen was last run and thus if the -# documentation is up to date. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_TIMESTAMP = NO - # If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML # documentation will contain a main index with vertical navigation menus that # are dynamically created via JavaScript. If disabled, the navigation index will @@ -1580,17 +1550,6 @@ EXT_LINKS_IN_WINDOW = NO FORMULA_FONTSIZE = 10 -# Use the FORMULA_TRANSPARENT tag to determine whether or not the images -# generated for formulas are transparent PNGs. Transparent PNGs are not -# supported properly for IE 6.0, but are supported on all modern browsers. -# -# Note that when changing this option you need to delete any form_*.png files in -# the HTML output directory before the changes have effect. -# The default value is: YES. -# This tag requires that the tag GENERATE_HTML is set to YES. - -FORMULA_TRANSPARENT = YES - # The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands # to create new LaTeX commands to be used in formulas as building blocks. See # the section "Including formulas" for details. @@ -1889,16 +1848,6 @@ LATEX_BATCHMODE = NO LATEX_HIDE_INDICES = NO -# If the LATEX_SOURCE_CODE tag is set to YES then doxygen will include source -# code with syntax highlighting in the LaTeX output. -# -# Note that which sources are shown also depends on other settings such as -# SOURCE_BROWSER. -# The default value is: NO. -# This tag requires that the tag GENERATE_LATEX is set to YES. - -LATEX_SOURCE_CODE = NO - # The LATEX_BIB_STYLE tag can be used to specify the style to use for the # bibliography, e.g. plainnat, or ieeetr. See # https://en.wikipedia.org/wiki/BibTeX and \cite for more info. @@ -1907,14 +1856,6 @@ LATEX_SOURCE_CODE = NO LATEX_BIB_STYLE = plain -# If the LATEX_TIMESTAMP tag is set to YES then the footer of each generated -# page will contain the date and time when the page was generated. Setting this -# to NO can help when comparing the output of multiple runs. -# The default value is: NO. -# This tag requires that the tag GENERATE_LATEX is set to YES. - -LATEX_TIMESTAMP = NO - # The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute) # path from which the emoji images will be read. If a relative path is entered, # it will be relative to the LATEX_OUTPUT directory. If left blank the @@ -1979,16 +1920,6 @@ RTF_STYLESHEET_FILE = RTF_EXTENSIONS_FILE = -# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code -# with syntax highlighting in the RTF output. -# -# Note that which sources are shown also depends on other settings such as -# SOURCE_BROWSER. -# The default value is: NO. -# This tag requires that the tag GENERATE_RTF is set to YES. - -RTF_SOURCE_CODE = NO - #--------------------------------------------------------------------------- # Configuration options related to the man page output #--------------------------------------------------------------------------- @@ -2085,15 +2016,6 @@ GENERATE_DOCBOOK = NO DOCBOOK_OUTPUT = docbook -# If the DOCBOOK_PROGRAMLISTING tag is set to YES, doxygen will include the -# program listings (including syntax highlighting and cross-referencing -# information) to the DOCBOOK output. Note that enabling this will significantly -# increase the size of the DOCBOOK output. -# The default value is: NO. -# This tag requires that the tag GENERATE_DOCBOOK is set to YES. - -DOCBOOK_PROGRAMLISTING = NO - #--------------------------------------------------------------------------- # Configuration options for the AutoGen Definitions output #--------------------------------------------------------------------------- diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index db86498005..99d850d04d 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -38,7 +38,7 @@ PyTorch :members: reset, get_states, set_states, add, fork -.. autoapifunction:: transformer_engine.pytorch.autocast +.. autoapiclass:: transformer_engine.pytorch.autocast(enabled=True, calibrating=False, recipe=None, amax_reduction_group=None) .. autoapifunction:: transformer_engine.pytorch.quantized_model_init diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 8340d2010f..c0a095d6b5 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -143,5 +143,6 @@ wait # Final cleanup (trap will also call cleanup on exit) cleanup +wait exit $HAS_FAILURE diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 3c1f2ba1fb..4242c77c11 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -98,5 +98,6 @@ wait # Final cleanup (trap will also call cleanup on exit) cleanup +wait exit $HAS_FAILURE diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a0ae7dde82..1549a292d8 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -643,15 +643,35 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + // The specialized rowwise cast-only kernel vectorizes full 128-element chunks. + // Shapes with a partial row tail (for example, N=48) must use the generic kernel, + // otherwise the last chunk reads/writes past the logical end of the row. + using rowwise_traits = specialized::CastTraits; + using bidimensional_traits = specialized::CastTraits; + constexpr size_t max_grid_dim_y = 65535; + const bool rowwise_specialized_grid_fits = + ((rows + rowwise_traits::blockDimM - 1) / rowwise_traits::blockDimM) <= + max_grid_dim_y; + const bool bidimensional_specialized_grid_fits = + ((rows + bidimensional_traits::blockDIM::M - 1) / + bidimensional_traits::blockDIM::M) <= max_grid_dim_y; + + const bool is_full_rowwise_chunk = (cols % 128 == 0); + const bool scaling_type_has_specialized_support = + (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk && + rowwise_specialized_grid_fits) || + (scaling_type == ScalingType::BIDIMENSIONAL && + bidimensional_specialized_grid_fits); + if (specialized::hasSpec() && - !WITH_GEMM_SWIZZLED_SCALES) { + !WITH_GEMM_SWIZZLED_SCALES && scaling_type_has_specialized_support) { switch (scaling_type) { case ScalingType::ROWWISE: { using traits = specialized::CastTraits; auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - traits::smem); + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, traits::smem)); dim3 block(traits::threadLayout::num, traits::warpLayout::N, traits::warpLayout::M); @@ -664,16 +684,12 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } - case ScalingType::COLWISE: { - NVTE_WARN("Colwise scaling will fallback to original kernel."); - break; - } case ScalingType::BIDIMENSIONAL: { using traits = specialized::CastTraits; auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - traits::smem); + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, traits::smem)); // TMA for loading, so that we don't need STS for transposing alignas(64) CUtensorMap tensor_map_input{}; constexpr size_t input_type_bit_size = TypeInfo::size; @@ -710,6 +726,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, NVTE_ERROR("Invalid scaling type."); } } + NVTE_CHECK_CUDA(cudaGetLastError()); return; } @@ -789,7 +806,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); break; } case ScalingType::COLWISE: { @@ -804,7 +820,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); break; } case ScalingType::BIDIMENSIONAL: { @@ -819,10 +834,9 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); break; } - } + } NVTE_CHECK_CUDA(cudaGetLastError()); if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); diff --git a/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh index 41e62ac319..9459f0273a 100644 --- a/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh @@ -91,18 +91,6 @@ __device__ __forceinline__ e8m0_t to_e8m0(IType amax) { #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // anonymous namespace -inline bool is_cast_only_enabled() { - static bool enabled = []() { - const char *env = std::getenv("ENABLE_CAST_ONLY"); - return env != nullptr && (env[0] == '1'); - }(); - return enabled; - - // // FIXME: when finish debugging, remove this - // const char* env = std::getenv("ENABLE_CAST_ONLY"); - // return env != nullptr && (env[0] == '1'); -} - template inline bool hasSpec() { return false; @@ -112,19 +100,19 @@ inline bool hasSpec() { // OType could be [fp8e5m2, fp8e4m3] template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 9fe692dd2d..a99e0946ef 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -108,7 +108,7 @@ void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA /*! \brief Set an option in matrix multiplication configuration. * - * \param[in/out] config Matrix multiplication configuration. + * \param[in,out] config Matrix multiplication configuration. * \param[in] attr Option type. * \param[in] buf Memory address to read option value from. * \param[in] size_in_bytes Size of buf. @@ -298,39 +298,6 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); -/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ -/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C - * - * \note Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. - * Will error at runtime if compiled with an older cuBLAS version or run on - * a pre-Blackwell GPU. - * - * Performs batched GEMM on a collection of matrices with potentially different shapes. - * All tensors in the group must have compatible dimensions for matrix multiplication. - * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous - * memory layout and shape metadata. - * - * \param[in] A Input grouped tensor A. - * \param[in] transa Whether to transpose A matrices. - * \param[in] B Input grouped tensor B. - * \param[in] transb Whether to transpose B matrices. - * \param[in] C Input grouped tensor C (can be NULL for beta=0). - * \param[out] D Output grouped tensor D. - * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). - * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). - * \param[in] workspace_setup Workspace tensor for pointer array setup. - * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. - * \param[in] config Additional configuration (can be NULL for defaults). - * \param[in] stream CUDA stream for the operation. - * - * Requirements: - * - cuBLAS 13.2+ (CUDA 13.1+) - * - Blackwell (SM100) or newer GPU architecture - * - A, B, C (if provided), D must have the same num_tensors - * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] - * - Shape compatibility: if transa=false, transb=false: - * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) - */ /*! \brief Return the required size in bytes for the setup workspace of grouped GEMM. * * The setup workspace stores pointer arrays and per-matrix dimension arrays used @@ -385,6 +352,39 @@ void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *ds void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, size_t n_groups, int64_t last_dim, cudaStream_t stream); +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C + * + * \note Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. + * Will error at runtime if compiled with an older cuBLAS version or run on + * a pre-Blackwell GPU. + * + * Performs batched GEMM on a collection of matrices with potentially different shapes. + * All tensors in the group must have compatible dimensions for matrix multiplication. + * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous + * memory layout and shape metadata. + * + * \param[in] A Input grouped tensor A. + * \param[in] transa Whether to transpose A matrices. + * \param[in] B Input grouped tensor B. + * \param[in] transb Whether to transpose B matrices. + * \param[in] C Input grouped tensor C (can be NULL for beta=0). + * \param[out] D Output grouped tensor D. + * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). + * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). + * \param[in] workspace_setup Workspace tensor for pointer array setup. + * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. + * \param[in] config Additional configuration (can be NULL for defaults). + * \param[in] stream CUDA stream for the operation. + * + * Requirements: + * - cuBLAS 13.2+ (CUDA 13.1+) + * - Blackwell (SM100) or newer GPU architecture + * - A, B, C (if provided), D must have the same num_tensors + * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] + * - Shape compatibility: if transa=false, transb=false: + * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) + */ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, @@ -398,8 +398,19 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT * instead of NVTEGroupedTensor. This enables discrete per-expert weights as inputA * for Grouped GEMM. * - * \param[in] A_list List of A tensors (length = num_tensors). + * \param[in] A_list List of A tensors (length = num_a_tensors). * \param[in] num_a_tensors Number of tensors in A_list. + * \param[in] transa Whether to transpose A matrices. + * \param[in] B Input grouped tensor B. + * \param[in] transb Whether to transpose B matrices. + * \param[in] C Input grouped tensor C (can be NULL for beta=0). + * \param[out] D Output grouped tensor D. + * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). + * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). + * \param[in] workspace_setup Workspace tensor for pointer array setup. + * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. + * \param[in] config Additional configuration (can be NULL for defaults). + * \param[in] stream CUDA stream for the operation. */ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num_a_tensors, int transa, const NVTEGroupedTensor B, int transb, @@ -415,10 +426,20 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num * instead of NVTEGroupedTensor. This enables accumulation into non-contiguous * per-expert buffers (for wgrads). * -* \param[in] C_list Optional list of C tensors (length = num_tensors). +* \param[in] A Input grouped tensor A. +* \param[in] transa Whether to transpose A matrices. +* \param[in] B Input grouped tensor B. +* \param[in] transb Whether to transpose B matrices. +* \param[in] C_list Optional list of C tensors (length = num_c_tensors). * \param[in] num_c_tensors Number of tensors in C_list (Can be 0 if C is not provided). -* \param[out] D_list List of D tensors (length = num_tensors). +* \param[out] D_list List of D tensors (length = num_d_tensors). * \param[in] num_d_tensors Number of tensors in D_list. +* \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). +* \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). +* \param[in] workspace_setup Workspace tensor for pointer array setup. +* \param[in] workspace_cublas Workspace tensor for cuBLAS operations. +* \param[in] config Additional configuration (can be NULL for defaults). +* \param[in] stream CUDA stream for the operation. * \note All tensors in C_list and D_list must share the same dtype. */ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 488f259150..045ae88893 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -282,7 +282,7 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream); * * \warning Deprecated in favor of nvte_set_tensor_param_v2. * - * \param[in/out] tensor Tensor. + * \param[in,out] tensor Tensor. * \param[in] param_name The parameter to be set. * \param[in] param The value to be set. */ @@ -300,7 +300,7 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p /*! \brief Set a tensor parameter. * - * \param[in/out] tensor Tensor. + * \param[in,out] tensor Tensor. * \param[in] param Tensor parameter type. * \param[in] buf Memory address to read parameter value. * \param[in] size_in_bytes Size of buf. @@ -406,7 +406,7 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, /*! \brief Set an option in quantization config. * - * \param[in/out] config Quantization config. + * \param[in,out] config Quantization config. * \param[in] attr Option type. * \param[in] buf Memory address to read option value. * \param[in] size_in_bytes Size of buf. @@ -510,7 +510,7 @@ void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor); /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Set a grouped tensor parameter. * - * \param[in/out] tensor Grouped tensor. + * \param[in,out] tensor Grouped tensor. * \param[in] param Grouped tensor parameter type. * \param[in] buf Memory address to read parameter value. * \param[in] size_in_bytes Size of buf. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 995ecf31b4..3db0417bdb 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1064,11 +1064,10 @@ def cp_p2p_fwd_flash_attn( **fa_forward_kwargs, ) rng_states = None - if not fa_utils.v2_7_0_plus: + if not use_flash_attn_3 and not fa_utils.v2_7_0_plus: out_per_step = fa_outputs[4] softmax_lse_per_step = fa_outputs[5] - if not use_flash_attn_3: - rng_states = fa_outputs[7] + rng_states = fa_outputs[7] else: out_per_step = fa_outputs[0] softmax_lse_per_step = fa_outputs[1] @@ -3255,11 +3254,10 @@ def forward( causal=causal, **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: + if not use_flash_attn_3 and not fa_utils.v2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] + rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] @@ -4086,9 +4084,9 @@ def forward( causal=causal, **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: + if not use_flash_attn_3 and not fa_utils.v2_7_0_plus: out_, softmax_lse = fa_outputs[4], fa_outputs[5] - rng_state = fa_outputs[7] if not use_flash_attn_3 else None + rng_state = fa_outputs[7] else: out_, softmax_lse = fa_outputs[0], fa_outputs[1] rng_state = fa_outputs[3] if not use_flash_attn_3 else None diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index ed10909b8a..c81c18e64f 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -307,8 +307,9 @@ def start_offload(self): # needed to restore pre-offload state after reload. self.aux = aux - self.finish_offload_event = torch.cuda.Event() - self.finish_offload_event.record(self.offload_stream) + if len(self.fwd_gpu_tensor_group.tensor_list) > 0: + self.finish_offload_event = torch.cuda.Event() + self.finish_offload_event.record(self.offload_stream) def release_activation_forward_gpu_memory(self): """ @@ -319,13 +320,13 @@ def release_activation_forward_gpu_memory(self): func_name="release_activation_forward_gpu_memory", allowed_states=["offload_started"] ) self.state = "offload_finished" + if len(self.fwd_gpu_tensor_group.tensor_list) > 0: + torch.cuda.current_stream().wait_event(self.finish_offload_event) # type: ignore[arg-type] - torch.cuda.current_stream().wait_event(self.finish_offload_event) # type: ignore[arg-type] - - # GPU memory can be released safely after the offload. - # Notice that the memory needs to be kept alive when GPU->CPU copy is performed. - self.fwd_gpu_tensor_group = TensorGroup() - del self.finish_offload_event + # GPU memory can be released safely after the offload. + # Notice that the memory needs to be kept alive when GPU->CPU copy is performed. + self.fwd_gpu_tensor_group = TensorGroup() + del self.finish_offload_event def start_reload(self): """ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4a2ea7412b..9b10a9c5a4 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -309,6 +309,16 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w py::object ln_out, py::handle quantizer, DType otype, const int sm_margin, const bool zero_centered_gamma); +/*************************************************************************************************** + * Memory allocation + **************************************************************************************************/ + +// Allocates tensors all backed by a single contiguous buffer. +std::vector bulk_allocate(const std::vector> &shapes, + const std::vector &dtypes, + std::optional device = std::nullopt, + std::optional> alignments = std::nullopt); + /*************************************************************************************************** * Cast **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/allocate.cpp b/transformer_engine/pytorch/csrc/extensions/allocate.cpp new file mode 100644 index 0000000000..f972f8a2d2 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/allocate.cpp @@ -0,0 +1,86 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include "../extensions.h" + +namespace transformer_engine { +namespace pytorch { + +std::vector bulk_allocate(const std::vector> &shapes, + const std::vector &dtypes, + std::optional device, + std::optional> alignments) { + // Check shapes and dtypes + const size_t n = shapes.size(); + NVTE_CHECK(dtypes.size() == n, "Got ", shapes.size(), " shapes and ", dtypes.size(), " dtypes."); + NVTE_CHECK(!alignments || alignments->size() == n, "Got ", shapes.size(), " shapes and ", + alignments->size(), " alignments."); + + // Return immediately if no tensors are needed + if (n == 0) return {}; + + // Set defaults for optional arguments + if (!device) { + device = c10::Device(c10::kCUDA); + } + if (!alignments) { + alignments = std::vector{}; + alignments->reserve(n); + for (const auto &dtype : dtypes) { + alignments->push_back(c10::elementSize(dtype)); + } + } + + // Compute offsets in base buffer + std::vector byte_sizes(n); + std::vector offsets(n); + size_t base_byte_size = 0; + size_t base_alignment = 1; + for (size_t i = 0; i < n; ++i) { + byte_sizes[i] = product(shapes[i]) * at::elementSize(dtypes[i]); + offsets[i] = roundup(base_byte_size, (*alignments)[i]); + base_byte_size = offsets[i] + byte_sizes[i]; + base_alignment = std::max(base_alignment, (*alignments)[i]); + } + if (base_alignment > 1) { + // Pad in case data pointer is not aligned + base_byte_size += base_alignment; + } + + // Allocate base buffer + auto base_buffer = std::make_shared( + at::empty({static_cast(base_byte_size)}, at::device(*device).dtype(torch::kUInt8))); + uint8_t *base_ptr = base_buffer->data_ptr(); + base_ptr = + reinterpret_cast(roundup(reinterpret_cast(base_ptr), base_alignment)); + + // Create views into base buffer + std::vector out; + out.reserve(n); + std::vector shape_int64; + for (size_t i = 0; i < n; ++i) { + shape_int64.assign(shapes[i].begin(), shapes[i].end()); + if (byte_sizes[i] == 0) { + // Work around problems with from_blob when constructing an + // empty tensor. Passing a null pointer fails because it checks + // that the pointer is on GPU. Passing a non-null pointer can + // cause bugs in TE kernels. + out.emplace_back(at::empty(shape_int64, at::device(*device).dtype(dtypes[i]))); + } else { + // Construct tensor with custom deleter to keep base buffer alive + out.emplace_back(at::from_blob( + base_ptr + offsets[i], shape_int64, [base_buffer](void *) {}, + at::device(*device).dtype(dtypes[i]))); + } + } + return out; +} + +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 00f4383ab6..3ada2459c8 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -495,60 +495,30 @@ std::tuple, std::vector> bulk_allocate_fp const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto is_2D_scaled = scaling_mode == NVTE_BLOCK_SCALING_2D; const auto fp8_dtype = quantizer_cpp_list[0]->dtype; - constexpr size_t fp8_elem_size = 1; - constexpr size_t scale_elem_size = 4; - - // Helper function to construct tensor view - // Note: Deleter holds a shared_ptr for the buffer, so the buffer - // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, - size_t offset, at::ScalarType dtype) -> at::Tensor { - std::vector shape_int64(shape.begin(), shape.end()); - bool is_empty_shape = product(shape) == 0; - if (buffer->data_ptr() == nullptr || is_empty_shape) { - return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); - } - return at::from_blob( - buffer->data_ptr() + offset, shape_int64, - [buffer](void *) {}, // deleter holds shared_ptr - at::device(at::kCUDA).dtype(dtype)); - }; // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list; std::vector> rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { - // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); } - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets; - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 256); // align to 256B - data_offsets.push_back(buffer_size); - buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size; - } - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 16); // align to 16B - scale_offsets.push_back(buffer_size); - buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size; - } + // Bulk-allocate data and scale tensors + std::vector> shapes = rowwise_data_shapes; + std::vector dtypes(num_tensors, torch::kUInt8); + std::vector alignments(num_tensors, 256); + shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end()); + dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32); + alignments.insert(alignments.end(), num_tensors, 16); + auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments); - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - - // Construct tensor views + // Split data and scale tensors for (size_t i = 0; i < num_tensors; ++i) { - rowwise_data_list.emplace_back( - make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8)); - rowwise_scale_list.emplace_back( - make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kFloat32)); + rowwise_data_list.emplace_back(std::move(tensors[i])); + rowwise_scale_list.emplace_back(std::move(tensors[num_tensors + i])); } } @@ -556,7 +526,6 @@ std::tuple, std::vector> bulk_allocate_fp std::vector columnwise_data_list, columnwise_scale_list; std::vector> columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { - // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { columnwise_data_shapes.emplace_back(); auto &shape = columnwise_data_shapes.back(); @@ -568,30 +537,19 @@ std::tuple, std::vector> bulk_allocate_fp quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); } - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets; - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 256); // align to 256B - data_offsets.push_back(buffer_size); - buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size; - } - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 16); // align to 16B - scale_offsets.push_back(buffer_size); - buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size; - } - - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + // Bulk-allocate data and scale tensors + std::vector> shapes = columnwise_data_shapes; + std::vector dtypes(num_tensors, torch::kUInt8); + std::vector alignments(num_tensors, 256); + shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end()); + dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32); + alignments.insert(alignments.end(), num_tensors, 16); + auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments); - // Construct tensor views + // Split data and scale tensors for (size_t i = 0; i < num_tensors; ++i) { - columnwise_data_list.emplace_back( - make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8)); - columnwise_scale_list.emplace_back( - make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kFloat32)); + columnwise_data_list.push_back(tensors[i]); + columnwise_scale_list.push_back(tensors[num_tensors + i]); } } @@ -648,60 +606,29 @@ std::tuple, std::vector> bulk_allocate_mx const auto fp8_dtype = quantizer_cpp_list[0]->dtype; const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm; - constexpr size_t fp8_elem_size = 1; - constexpr size_t scale_elem_size = 1; - - // Helper function to construct tensor view - // Note: Deleter holds a shared_ptr for the buffer, so the buffer - // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, - size_t offset, at::ScalarType dtype) -> at::Tensor { - std::vector shape_int64(shape.begin(), shape.end()); - bool is_empty_shape = product(shape) == 0; - if (buffer->data_ptr() == nullptr || is_empty_shape) { - return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); - } - return at::from_blob( - buffer->data_ptr() + offset, shape_int64, - [buffer](void *) {}, // deleter holds shared_ptr - at::device(at::kCUDA).dtype(dtype)); - }; - // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list; std::vector> rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { - // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); } - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets; - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 256); // align to 256B - data_offsets.push_back(buffer_size); - buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size; - } - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 16); // align to 16B - scale_offsets.push_back(buffer_size); - buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size; - } - - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + // Bulk-allocate data and scale tensors + std::vector> shapes = rowwise_data_shapes; + std::vector dtypes(num_tensors, torch::kUInt8); + std::vector alignments(num_tensors, 256); + shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end()); + dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8); + alignments.insert(alignments.end(), num_tensors, 16); + auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments); - // Construct tensor views + // Split data and scale tensors for (size_t i = 0; i < num_tensors; ++i) { - rowwise_data_list.emplace_back( - make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8)); - rowwise_scale_list.emplace_back( - make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + rowwise_data_list.emplace_back(std::move(tensors[i])); + rowwise_scale_list.emplace_back(std::move(tensors[num_tensors + i])); } } @@ -709,7 +636,6 @@ std::tuple, std::vector> bulk_allocate_mx std::vector columnwise_data_list, columnwise_scale_list; std::vector> columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { - // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { // For MXFP8, the columnwise data doesn't need transpose // because of TN, NT, NN layout support in SM100 @@ -718,30 +644,19 @@ std::tuple, std::vector> bulk_allocate_mx quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); } - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets; - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 256); // align to 256B - data_offsets.push_back(buffer_size); - buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size; - } - for (size_t i = 0; i < num_tensors; ++i) { - buffer_size = roundup(buffer_size, 16); // align to 16B - scale_offsets.push_back(buffer_size); - buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size; - } - - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + // Bulk-allocate data and scale tensors + std::vector> shapes = columnwise_data_shapes; + std::vector dtypes(num_tensors, torch::kUInt8); + std::vector alignments(num_tensors, 256); + shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end()); + dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8); + alignments.insert(alignments.end(), num_tensors, 16); + auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments); - // Construct tensor views + // Split data and scale tensors for (size_t i = 0; i < num_tensors; ++i) { - columnwise_data_list.emplace_back( - make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8)); - columnwise_scale_list.emplace_back( - make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + columnwise_data_list.push_back(tensors[i]); + columnwise_scale_list.push_back(tensors[num_tensors + i]); } } @@ -808,103 +723,70 @@ std::tuple, std::vector, bool> bulk_alloc const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; - constexpr size_t scale_elem_size = 1; - - // Helper function to construct tensor view - // Note: Deleter holds a shared_ptr for the buffer, so the buffer - // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, - size_t offset, at::ScalarType dtype) -> at::Tensor { - std::vector shape_int64(shape.begin(), shape.end()); - bool is_empty_shape = product(shape) == 0; - if (buffer->data_ptr() == nullptr || is_empty_shape) { - return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); - } - return at::from_blob( - buffer->data_ptr() + offset, shape_int64, - [buffer](void *) {}, // deleter holds shared_ptr - at::device(at::kCUDA).dtype(dtype)); - }; - // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) - auto to_fp4_shape = [](const std::vector &shape) { - std::vector fp4_shape(shape.begin(), shape.end()); - if (!fp4_shape.empty()) { - fp4_shape.back() /= 2; - } - return fp4_shape; + // Helper function to get size of byte buffer holding FP4 data (last dim divided by 2) + auto fp4_byte_shape = [](const std::vector &shape) -> std::vector { + NVTE_CHECK(!shape.empty()); + NVTE_CHECK(shape.back() % 2 == 0); + std::vector out(shape.begin(), shape.end()); + out.back() /= 2; + return out; }; - auto flat_first_dim = [](const std::vector &shape) -> size_t { - if (shape.empty()) { - return 1; - } - size_t rows = 1; - for (size_t i = 0; i + 1 < shape.size(); ++i) { - rows *= shape[i]; + + // Helper function to get size of amax buffer + auto amax_shape = [](const std::vector &shape, + bool row_scaled = false) -> std::vector { + if (row_scaled) { + const auto [rows, _] = get_2d_dims(shape); + return {rows}; } - return rows; + return {1}; }; // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; std::vector> rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { - // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); } - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets, amax_offsets; + // Check whether data and scales can be packed in contiguous + // buffer. Amaxes are not contiguous since they are aligned to + // 16B. for (size_t i = 0; i < num_tensors; ++i) { - // FP4 data is aligned to 256B - const auto offset = roundup(buffer_size, 256); - if (offset != buffer_size) { + if (product(rowwise_data_shapes[i]) / 2 % 256 != 0) { contiguous_data_and_scale = false; } - data_offsets.push_back(offset); - buffer_size = offset + (product(rowwise_data_shapes[i]) + 1) / 2; - } - for (size_t i = 0; i < num_tensors; ++i) { - // Scales are aligned to 16B - const auto offset = roundup(buffer_size, 16); - if (offset != buffer_size) { + if (product(rowwise_scale_shapes[i]) % 16 != 0) { contiguous_data_and_scale = false; } - scale_offsets.push_back(offset); - buffer_size = offset + product(rowwise_scale_shapes[i]) * scale_elem_size; } + + // Bulk-allocate tensors data, scale, and amax tensors + std::vector> shapes; + for (size_t i = 0; i < num_tensors; ++i) { + shapes.emplace_back(fp4_byte_shape(rowwise_data_shapes[i])); + } + std::vector dtypes(num_tensors, torch::kUInt8); + std::vector alignments(num_tensors, 256); + shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end()); + dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8); + alignments.insert(alignments.end(), num_tensors, 16); for (size_t i = 0; i < num_tensors; ++i) { - // Amaxes (FP32) are aligned to 16B - // Note: Multi-quantize kernel does not require contiguous amaxes. - const auto offset = roundup(buffer_size, 16); - amax_offsets.push_back(offset); - size_t amax_size = 4; - if (row_scaled_nvfp4) { - amax_size *= flat_first_dim(rowwise_data_shapes[i]); - } - buffer_size = offset + amax_size; + shapes.emplace_back(amax_shape(rowwise_data_shapes[i], row_scaled_nvfp4)); } + dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32); + alignments.insert(alignments.end(), num_tensors, 16); + auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments); - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - - // Construct tensor views + // Split data, scale, and amax tensors for (size_t i = 0; i < num_tensors; ++i) { - rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), - data_offsets[i], torch::kUInt8)); - rowwise_scale_list.emplace_back( - make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); - std::vector amax_shape{1}; - if (row_scaled_nvfp4) { - amax_shape = {flat_first_dim(rowwise_data_shapes[i])}; - } - amax_rowwise_list.emplace_back( - make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); + rowwise_data_list.push_back(tensors[i]); + rowwise_scale_list.push_back(tensors[num_tensors + i]); + amax_rowwise_list.push_back(tensors[2 * num_tensors + i]); } } @@ -912,7 +794,6 @@ std::tuple, std::vector, bool> bulk_alloc std::vector columnwise_data_list, columnwise_scale_list, amax_columnwise_list; std::vector> columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { - // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { // push the transposed shape into NVFP4 columnwise shape // NVFP4 on SM100 is TN only @@ -926,47 +807,40 @@ std::tuple, std::vector, bool> bulk_alloc quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); } - // Offsets in full buffer - size_t buffer_size = 0; - std::vector data_offsets, scale_offsets, amax_offsets; + // Check whether data and scales can be packed in contiguous + // buffer. Amaxes are not contiguous since they are aligned to + // 16B. for (size_t i = 0; i < num_tensors; ++i) { - // FP4 data is aligned to 256B - const auto offset = roundup(buffer_size, 256); - if (offset != buffer_size) { + if (product(columnwise_data_shapes[i]) / 2 % 256 != 0) { contiguous_data_and_scale = false; } - data_offsets.push_back(offset); - buffer_size = offset + (product(columnwise_data_shapes[i]) + 1) / 2; - } - for (size_t i = 0; i < num_tensors; ++i) { - // Scales are aligned to 16B - const auto offset = roundup(buffer_size, 16); - if (offset != buffer_size) { + if (product(columnwise_scale_shapes[i]) % 16 != 0) { contiguous_data_and_scale = false; } - scale_offsets.push_back(offset); - buffer_size = offset + product(columnwise_scale_shapes[i]) * scale_elem_size; } + + // Bulk-allocate tensors data, scale, and amax tensors + std::vector> shapes; for (size_t i = 0; i < num_tensors; ++i) { - // Amaxes (FP32) are aligned to 16B - // Note: Multi-quantize kernel does not require contiguous amaxes. - const auto offset = roundup(buffer_size, 16); - amax_offsets.push_back(offset); - buffer_size = offset + 4; + shapes.emplace_back(fp4_byte_shape(columnwise_data_shapes[i])); } + std::vector dtypes(num_tensors, torch::kUInt8); + std::vector alignments(num_tensors, 256); + shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end()); + dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8); + alignments.insert(alignments.end(), num_tensors, 16); + for (size_t i = 0; i < num_tensors; ++i) { + shapes.emplace_back(amax_shape(columnwise_data_shapes[i])); + } + dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32); + alignments.insert(alignments.end(), num_tensors, 16); + auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments); - // Allocate full buffer - auto buffer = std::make_shared( - at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - - // Construct tensor views + // Split data, scale, and amax tensors for (size_t i = 0; i < num_tensors; ++i) { - columnwise_data_list.emplace_back(make_torch_view( - buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8)); - columnwise_scale_list.emplace_back( - make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); - amax_columnwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + columnwise_data_list.push_back(tensors[i]); + columnwise_scale_list.push_back(tensors[num_tensors + i]); + amax_columnwise_list.push_back(tensors[2 * num_tensors + i]); } } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index eb7576d905..a813f3119d 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -352,6 +352,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"), py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"), py::arg("out_dtype"), py::call_guard()); + // NVFP4 2D m.def("nvfp4_2d_compute_partial_amax", &transformer_engine::pytorch::nvfp4_2d_compute_partial_amax, @@ -404,6 +405,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "In-place swizzle of grouped tensor scales for GEMM", py::arg("tensor"), py::arg("rowwise"), py::arg("columnwise")); + // Tensor allocation + m.def("bulk_allocate", &transformer_engine::pytorch::bulk_allocate, + "Allocate tensors backed by a single contiguous buffer", py::arg("shapes"), + py::arg("dtypes"), py::arg("device") = py::none(), py::arg("alignments") = py::none(), + py::call_guard()); + // attention kernels m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd, "Prepare QKV for Flash Attention", py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 2b29f260e7..82dfe4d222 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -7,6 +7,7 @@ #include #include "common.h" +#include "common/util/cuda_runtime.h" #include "common/util/system.h" #include "pybind.h" #include "torch/torch.h" @@ -2264,7 +2265,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT bool eligible_for_rht_cast_fusion = - input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; + input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 && + transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110; // Stochastic rounding // When both rowwise and columnwise quantization are used with RHT, diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e950f26571..627144345c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -496,10 +496,13 @@ def backward( if ctx.fuse_wgrad_accumulation: wgrad_list = main_grads else: - wgrad_list = [ - torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) - for w in weights - ] + weight_shape = list(weights[0].size()) + wgrad_list = tex.bulk_allocate( + [weight_shape] * ctx.num_gemms, + [ctx.activation_dtype] * ctx.num_gemms, + ctx.device, + [256] * ctx.num_gemms, # alignment + ) if ctx.save_original_input: inp = inputmats[0] diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index e698c2697f..1f00d92284 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -1393,10 +1393,12 @@ def _fuser_backward_split_quantize( ] accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) else: - grad_weights = [ - torch.empty(weight_shape, dtype=ctx.dtype, device=device) - for _ in range(num_groups) - ] + grad_weights = tex.bulk_allocate( + [weight_shape] * num_groups, + [ctx.dtype] * num_groups, + device, + [256] * num_groups, # alignment + ) final_weight_grads = list(grad_weights) # Perform dgrad GEMMs diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 320c7c39e5..a11d0505c1 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -194,9 +194,12 @@ def _compute_grad_params( w_list = [get_main_grad_from_param(w, op_label=op_label) for w in weights] accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) else: - w_list = [ - torch.empty(weight_shape, dtype=dtype, device=device) for _ in range(num_groups) - ] + w_list = tex.bulk_allocate( + [weight_shape] * num_groups, + [dtype] * num_groups, + device, + [256] * num_groups, # alignment + ) wgrad_output = w_list if ctx.weight_requires_grad: