diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index 303801a88..77ddf7bbe 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -68,6 +68,28 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, int per_tensor, int max_chunks_per_tensor, cudaStream_t stream); +/*! \brief Computes cumulative L2 norm for a list of tensors from precomputed chunk metadata. + * + * \warning This API is **experimental** and subject to change. + */ +void nvte_multi_tensor_l2norm_cuda_custom(int chunk_size, NVTETensor noop_flag, + const NVTEDType input_dtype, const int64_t *addresses, + const int *sizes, const int *block_to_tensor, + const int *chunk_offsets, int total_chunks, + NVTETensor output, + NVTETensor ret, cudaStream_t stream); + +/*! \brief Computes cumulative L2 norm for a list of tensors after unscaling from precomputed + * chunk metadata. + * + * \warning This API is **experimental** and subject to change. + */ +void nvte_multi_tensor_unscale_l2norm_cuda_custom( + int chunk_size, NVTETensor noop_flag, const NVTEDType input_dtype, + const int64_t *addresses, const int *sizes, const int *block_to_tensor, + const int *chunk_offsets, int total_chunks, + NVTETensor output, NVTETensor ret, NVTETensor inv_scale, cudaStream_t stream); + /*! \brief Compute and apply gradient update to parameters for Adam optimizer. * * \warning This API is **experimental** and subject to change. @@ -120,6 +142,34 @@ void nvte_multi_tensor_adam_param_remainder_cuda( const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * where the master parameters only store the remainder bits. 
+ * Uses precomputed chunk metadata instead of TensorListMetadata. + * + * \warning This API is **experimental** and subject to change. + */ +void nvte_multi_tensor_adam_param_remainder_cuda_custom( + int chunk_size, NVTETensor noop_flag, NVTEDType grad_dtype, + int64_t *addresses, int64_t *sizes, int *block_to_tensor, int *chunk_offsets, + int total_chunks, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, const int bias_correction, + const float weight_decay, cudaStream_t stream); + +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * (4-list: g, p, m, v). + * Uses precomputed chunk metadata instead of TensorListMetadata. + * + * \warning This API is **experimental** and subject to change. + */ +void nvte_multi_tensor_adam_cuda_custom( + int chunk_size, NVTETensor noop_flag, NVTEDType grad_dtype, NVTEDType param_dtype, + int64_t *addresses, int64_t *sizes, int *block_to_tensor, int *chunk_offsets, + int total_chunks, int has_master, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, const int bias_correction, + const float weight_decay, cudaStream_t stream); + /*! \brief Compute and apply gradient update to parameters for Adam optimizer * when model parameters are in Float8 precision. * diff --git a/transformer_engine/common/multi_tensor/adam.cu b/transformer_engine/common/multi_tensor/adam.cu index 2154102f0..bbafbbe78 100644 --- a/transformer_engine/common/multi_tensor/adam.cu +++ b/transformer_engine/common/multi_tensor/adam.cu @@ -927,6 +927,432 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK_CUDA(cudaGetLastError()); } +// --------------------------------------------------------------------------- +// Custom adam param-remainder kernel using device-side arrays instead of +// TensorListMetadata. 
Removes the 320-block limit and avoids packing the +// metadata struct on each launch. +// --------------------------------------------------------------------------- + +template +__device__ __forceinline__ bool is_aligned_n(const T *p) { + return ((uint64_t)p) % (N * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store_n(T *dst, const T *src, + int dst_offset, int src_offset) { + typedef typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((const LT *)src)[src_offset]; // NOLINT(*) +} + +static constexpr int CILP = 8; + +template +__global__ __launch_bounds__(BLOCK_SIZE) +void custom_adam_param_remainder_kernel( + const int chunk_size, + volatile int * __restrict__ noop_gmem, + const int64_t * __restrict__ addresses, + const int64_t * __restrict__ sizes, + const int * __restrict__ block_to_tensor, + const int * __restrict__ chunk_offsets, + const int total_chunks, + const float beta1, const float beta2, + const float step_size, const float beta2_corr_inv, + const float epsilon, + const float lr, const float decay) { + const int global_chunk = blockIdx.x; + if (global_chunk >= total_chunks) return; + + const int tensor_loc = block_to_tensor[global_chunk]; + const int chunk_idx = global_chunk - chunk_offsets[tensor_loc]; + + // addresses layout: [tensor_idx * 5 + list_idx] + // list 0 = grads, 1 = params(int16), 2 = exp_avg, 3 = exp_avg_sq, 4 = remainders + GRAD_T * __restrict__ g = + reinterpret_cast(addresses[tensor_loc * 5 + 0]); + int16_t * __restrict__ p = + reinterpret_cast(addresses[tensor_loc * 5 + 1]); + float * __restrict__ m = + reinterpret_cast(addresses[tensor_loc * 5 + 2]); + float * __restrict__ v = + reinterpret_cast(addresses[tensor_loc * 5 + 3]); + int16_t * __restrict__ p_remainder = + reinterpret_cast(addresses[tensor_loc * 5 + 4]); + + const int64_t elem_offset = (int64_t)chunk_idx * chunk_size; + g += elem_offset; + p += elem_offset; + m += elem_offset; + v += elem_offset; + p_remainder += elem_offset; + + 
const int n_this = static_cast(
+      min(sizes[tensor_loc] - elem_offset, (int64_t)chunk_size));
+
+  // Contiguous access: each thread processes CILP adjacent elements.
+  // This enables 128-bit vectorized loads for 16-bit types (CILP*2 = 16 bytes)
+  // and 2x 128-bit loads for float types (CILP*4 = 32 bytes).
+  for (int i_start = threadIdx.x * CILP; i_start < n_this;
+       i_start += blockDim.x * CILP) {
+    union fp32_or_int162 {
+      float fp32;
+      int16_t int16[2];
+    };
+    GRAD_T g_raw[CILP];
+    int16_t local_p[CILP];
+    int16_t local_p_rem[CILP];
+    MATH_T r_m[CILP];
+    MATH_T r_v[CILP];
+
+    // The vector path reads g, p, p_remainder, m and v through reinterpreted
+    // wide loads, so all five addresses must be suitably aligned; an aligned g
+    // says nothing about the other buffers. Anything that fails the test
+    // (including a partial group at the tail) takes the scalar path below.
+    const bool vec_load_ok = i_start + CILP <= n_this &&
+                             is_aligned_n<CILP>(g + i_start) &&
+                             is_aligned_n<CILP>(p + i_start) &&
+                             is_aligned_n<CILP>(p_remainder + i_start) &&
+                             is_aligned_n<4>(m + i_start) &&
+                             is_aligned_n<4>(v + i_start);
+    if (vec_load_ok) {
+      // Vectorized loads: 128-bit for 16-bit types
+      load_store_n(g_raw, g, 0, i_start / CILP);
+      load_store_n(local_p, p, 0, i_start / CILP);
+      load_store_n(local_p_rem, p_remainder, 0, i_start / CILP);
+      // 2x 128-bit for float types
+      load_store_n<4>(r_m, m, 0, i_start / 4);
+      load_store_n<4>(r_m, m, 1, i_start / 4 + 1);
+      load_store_n<4>(r_v, v, 0, i_start / 4);
+      load_store_n<4>(r_v, v, 1, i_start / 4 + 1);
+    } else {
+      // Scalar fallback. Lanes past n_this are zero-filled so the unrolled
+      // math that follows operates on defined values.
+#pragma unroll
+      for (int ii = 0; ii < CILP; ii++) {
+        int i = i_start + ii;
+        if (i < n_this) {
+          g_raw[ii] = g[i];
+          local_p[ii] = p[i];
+          local_p_rem[ii] = p_remainder[i];
+          r_m[ii] = m[i];
+          r_v[ii] = v[i];
+        } else {
+          g_raw[ii] = GRAD_T(0);
+          local_p[ii] = int16_t(0);
+          local_p_rem[ii] = int16_t(0);
+          r_m[ii] = MATH_T(0);
+          r_v[ii] = MATH_T(0);
+        }
+      }
+    }
+
+    // Convert grads bf16 -> float
+    MATH_T r_g[CILP];
+#pragma unroll
+    for (int ii = 0; ii < CILP; ii++) {
+      r_g[ii] = static_cast(g_raw[ii]);
+    }
+
+    // Reconstruct FP32 master params from BF16 + int16 remainder.
+    // int16[1] receives the stored 16-bit param and int16[0] the remainder;
+    // the param is decremented first whenever the remainder is negative.
+    fp32_or_int162 local_master_param[CILP];
+#pragma unroll
+    for (int ii = 0; ii < CILP; ii++) {
+      if (local_p_rem[ii] < 0) local_p[ii]--;
+      local_master_param[ii].int16[1] = local_p[ii];
+      local_master_param[ii].int16[0] = local_p_rem[ii];
+    }
+
+    MATH_T *r_p = reinterpret_cast(local_master_param);
+
+#pragma unroll
+    for (int ii = 0; ii < 
CILP; ii++) {
+      if (MODE == ADAM_MODE_0) {  // L2
+        r_g[ii] = r_g[ii] + (decay * r_p[ii]);
+      }
+      r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
+      r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
+      MATH_T denom = sqrtf(r_v[ii] * beta2_corr_inv) + epsilon;
+      if (MODE == ADAM_MODE_0) {  // L2
+        r_p[ii] = r_p[ii] - step_size * (r_m[ii] / denom);
+      } else {  // weight decay
+        r_p[ii] = r_p[ii] - step_size * (r_m[ii] / denom) - lr * decay * r_p[ii];
+      }
+    }
+
+    // Split into BF16 params (rounded-to-nearest) and remainders
+#pragma unroll
+    for (int ii = 0; ii < CILP; ii++) {
+      local_p[ii] = local_master_param[ii].int16[1];
+      local_p_rem[ii] = local_master_param[ii].int16[0];
+      if (local_p_rem[ii] < 0) local_p[ii]++;  // Round up
+    }
+
+    // Store
+    // The vector path writes p, p_remainder, m and v through reinterpreted
+    // wide stores, so each destination must be aligned -- checking p alone is
+    // not enough. Otherwise fall back to guarded element-wise stores.
+    const bool vec_store_ok = i_start + CILP <= n_this &&
+                              is_aligned_n<CILP>(p + i_start) &&
+                              is_aligned_n<CILP>(p_remainder + i_start) &&
+                              is_aligned_n<4>(m + i_start) &&
+                              is_aligned_n<4>(v + i_start);
+    if (vec_store_ok) {
+      load_store_n(p, local_p, i_start / CILP, 0);
+      load_store_n(p_remainder, local_p_rem, i_start / CILP, 0);
+      load_store_n<4>(m, r_m, i_start / 4, 0);
+      load_store_n<4>(m, r_m, i_start / 4 + 1, 1);
+      load_store_n<4>(v, r_v, i_start / 4, 0);
+      load_store_n<4>(v, r_v, i_start / 4 + 1, 1);
+    } else {
+#pragma unroll
+      for (int ii = 0; ii < CILP; ii++) {
+        int i = i_start + ii;
+        if (i < n_this) {
+          p[i] = local_p[ii];
+          p_remainder[i] = local_p_rem[ii];
+          m[i] = r_m[ii];
+          v[i] = r_v[ii];
+        }
+      }
+    }
+  }
+}
+
+// Host-side launcher for custom_adam_param_remainder_kernel.
+// Precomputes the bias-corrected step size and the inverse second-moment
+// correction on the host and selects the kernel instantiation from `mode`.
+// addresses / sizes / block_to_tensor / chunk_offsets are forwarded to the
+// kernel unchanged.
+void multi_tensor_adam_param_remainder_cuda_custom(
+    int chunk_size, Tensor noop_flag, DType grad_dtype,
+    int64_t *addresses, int64_t *sizes, int *block_to_tensor, int *chunk_offsets,
+    int total_chunks,
+    const float lr, const float beta1, const float beta2, const float epsilon,
+    const int step, const int mode, const int bias_correction,
+    const float weight_decay, cudaStream_t stream) {
+  // Nothing to update when there are no chunks. Returning here matches the
+  // total_chunks == 0 early-out of the custom L2-norm launchers and keeps an
+  // empty tensor list from reaching the kernel launch.
+  if (total_chunks <= 0) return;
+
+  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
+  if (bias_correction == 1) {
+    bias_correction1 = 1 - std::pow(beta1, step);
+    bias_correction2 = 1 - std::pow(beta2, step);
+  }
+
+  const float step_size = lr / bias_correction1;
+  const float 
beta2_corr_inv = 1.0f / bias_correction2; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + grad_dtype, grad_type, + if (mode == ADAM_MODE_0) { + custom_adam_param_remainder_kernel + <<>>( + chunk_size, reinterpret_cast(noop_flag.data.dptr), + addresses, sizes, block_to_tensor, chunk_offsets, total_chunks, + beta1, beta2, step_size, beta2_corr_inv, epsilon, lr, + weight_decay); + } else { + custom_adam_param_remainder_kernel + <<>>( + chunk_size, reinterpret_cast(noop_flag.data.dptr), + addresses, sizes, block_to_tensor, chunk_offsets, total_chunks, + beta1, beta2, step_size, beta2_corr_inv, epsilon, lr, + weight_decay); + };); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// --------------------------------------------------------------------------- +// Custom adam kernel (4-list: g, p, m, v) or (5-list: g, p, m, v, p_master) +// using device-side arrays. +// --------------------------------------------------------------------------- +template +__global__ __launch_bounds__(BLOCK_SIZE) +void custom_adam_kernel( + const int chunk_size, + volatile int * __restrict__ noop_gmem, + const int64_t * __restrict__ addresses, + const int64_t * __restrict__ sizes, + const int * __restrict__ block_to_tensor, + const int * __restrict__ chunk_offsets, + const int total_chunks, + const float beta1, const float beta2, + const float step_size, const float beta2_corr_inv, + const float epsilon, + const float lr, const float decay) { + constexpr int kDepth = HAS_MASTER ? 
5 : 4;
+  const int global_chunk = blockIdx.x;
+  if (global_chunk >= total_chunks) return;
+
+  // One block per chunk: map the block to its tensor and to the chunk index
+  // inside that tensor.
+  const int tensor_loc = block_to_tensor[global_chunk];
+  const int chunk_idx = global_chunk - chunk_offsets[tensor_loc];
+
+  // addresses holds kDepth consecutive entries per tensor: g, p, m, v and,
+  // when HAS_MASTER, p_master.
+  GRAD_T * __restrict__ g =
+      reinterpret_cast(addresses[tensor_loc * kDepth + 0]);
+  PARAM_T * __restrict__ p =
+      reinterpret_cast(addresses[tensor_loc * kDepth + 1]);
+  float * __restrict__ m =
+      reinterpret_cast(addresses[tensor_loc * kDepth + 2]);
+  float * __restrict__ v =
+      reinterpret_cast(addresses[tensor_loc * kDepth + 3]);
+  float * __restrict__ p_master = nullptr;
+  if constexpr (HAS_MASTER) {
+    p_master = reinterpret_cast(addresses[tensor_loc * kDepth + 4]);
+  }
+
+  const int64_t elem_offset = (int64_t)chunk_idx * chunk_size;
+  g += elem_offset;
+  p += elem_offset;
+  m += elem_offset;
+  v += elem_offset;
+  if constexpr (HAS_MASTER) {
+    p_master += elem_offset;
+  }
+
+  // Elements in this chunk (the last chunk of a tensor may be short).
+  const int n_this = static_cast(
+      min(sizes[tensor_loc] - elem_offset, (int64_t)chunk_size));
+
+  for (int i_start = threadIdx.x * CILP; i_start < n_this;
+       i_start += blockDim.x * CILP) {
+    GRAD_T g_raw[CILP];
+    MATH_T r_p[CILP];
+    MATH_T r_m[CILP];
+    MATH_T r_v[CILP];
+
+    // Wide loads require every source buffer to be aligned, not only g.
+    // The parameter source is p_master when HAS_MASTER, otherwise p, which is
+    // read CILP elements at a time for 2-byte PARAM_T and 4 at a time
+    // otherwise.
+    bool vec_load_ok = i_start + CILP <= n_this &&
+                       is_aligned_n<CILP>(g + i_start) &&
+                       is_aligned_n<4>(m + i_start) &&
+                       is_aligned_n<4>(v + i_start);
+    if constexpr (HAS_MASTER) {
+      vec_load_ok = vec_load_ok && is_aligned_n<4>(p_master + i_start);
+    } else {
+      constexpr int kParamVec = (sizeof(PARAM_T) == 2) ? CILP : 4;
+      vec_load_ok = vec_load_ok && is_aligned_n<kParamVec>(p + i_start);
+    }
+    if (vec_load_ok) {
+      // Vectorized loads
+      if constexpr (sizeof(GRAD_T) == 2) {
+        load_store_n(g_raw, g, 0, i_start / CILP);
+      } else {
+        load_store_n<4>(g_raw, g, 0, i_start / 4);
+        load_store_n<4>(g_raw, g, 1, i_start / 4 + 1);
+      }
+      if constexpr (HAS_MASTER) {
+        // Load from FP32 master params
+        load_store_n<4>(r_p, p_master, 0, i_start / 4);
+        load_store_n<4>(r_p, p_master, 1, i_start / 4 + 1);
+      } else {
+        PARAM_T p_raw[CILP];
+        if constexpr (sizeof(PARAM_T) == 2) {
+          load_store_n(p_raw, p, 0, i_start / CILP);
+        } else {
+          load_store_n<4>(p_raw, p, 0, i_start / 4);
+          load_store_n<4>(p_raw, p, 1, i_start / 4 + 1);
+        }
+#pragma unroll
+        for (int ii = 0; ii < CILP; ii++) {
+          r_p[ii] = static_cast(p_raw[ii]);
+        }
+      }
+      load_store_n<4>(r_m, m, 0, i_start / 4);
+
load_store_n<4>(r_m, m, 1, i_start / 4 + 1);
+      load_store_n<4>(r_v, v, 0, i_start / 4);
+      load_store_n<4>(r_v, v, 1, i_start / 4 + 1);
+    } else {
+      // Scalar fallback. Lanes past n_this are zero-filled so the unrolled
+      // math below operates on defined values.
+#pragma unroll
+      for (int ii = 0; ii < CILP; ii++) {
+        int i = i_start + ii;
+        if (i < n_this) {
+          g_raw[ii] = g[i];
+          if constexpr (HAS_MASTER) {
+            r_p[ii] = p_master[i];
+          } else {
+            r_p[ii] = static_cast(p[i]);
+          }
+          r_m[ii] = m[i];
+          r_v[ii] = v[i];
+        } else {
+          g_raw[ii] = GRAD_T(0);
+          r_p[ii] = MATH_T(0);
+          r_m[ii] = MATH_T(0);
+          r_v[ii] = MATH_T(0);
+        }
+      }
+    }
+
+    MATH_T r_g[CILP];
+#pragma unroll
+    for (int ii = 0; ii < CILP; ii++) {
+      r_g[ii] = static_cast(g_raw[ii]);
+    }
+
+#pragma unroll
+    for (int ii = 0; ii < CILP; ii++) {
+      if (MODE == ADAM_MODE_0) {  // L2
+        r_g[ii] = r_g[ii] + (decay * r_p[ii]);
+      }
+      r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
+      r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
+      MATH_T denom = sqrtf(r_v[ii] * beta2_corr_inv) + epsilon;
+      if (MODE == ADAM_MODE_0) {  // L2
+        r_p[ii] = r_p[ii] - step_size * (r_m[ii] / denom);
+      } else {  // weight decay
+        r_p[ii] = r_p[ii] - step_size * (r_m[ii] / denom) - lr * decay * r_p[ii];
+      }
+    }
+
+    // Store
+    // The vector path writes p, m, v (and p_master when HAS_MASTER) through
+    // reinterpreted wide stores, so every destination must be aligned --
+    // checking p alone is not enough. Otherwise store element by element.
+    bool vec_store_ok = i_start + CILP <= n_this &&
+                        is_aligned_n<CILP>(p + i_start) &&
+                        is_aligned_n<4>(m + i_start) &&
+                        is_aligned_n<4>(v + i_start);
+    if constexpr (HAS_MASTER) {
+      vec_store_ok = vec_store_ok && is_aligned_n<4>(p_master + i_start);
+    }
+    if (vec_store_ok) {
+      // Write p (PARAM_T)
+      PARAM_T p_out[CILP];
+#pragma unroll
+      for (int ii = 0; ii < CILP; ii++) {
+        p_out[ii] = static_cast(r_p[ii]);
+      }
+      if constexpr (sizeof(PARAM_T) == 2) {
+        load_store_n(p, p_out, i_start / CILP, 0);
+      } else {
+        load_store_n<4>(p, p_out, i_start / 4, 0);
+        load_store_n<4>(p, p_out, i_start / 4 + 1, 1);
+      }
+      if constexpr (HAS_MASTER) {
+        load_store_n<4>(p_master, r_p, i_start / 4, 0);
+        load_store_n<4>(p_master, r_p, i_start / 4 + 1, 1);
+      }
+      load_store_n<4>(m, r_m, i_start / 4, 0);
+      load_store_n<4>(m, r_m, i_start / 4 + 1, 1);
+      load_store_n<4>(v, r_v, i_start / 4, 0);
+      load_store_n<4>(v, r_v, i_start / 4 + 1, 1);
+    } else {
+#pragma unroll
+      for (int ii = 0; ii < CILP; ii++) {
+        int i = i_start + ii;
+        if (i < n_this) {
+          p[i] = 
static_cast(r_p[ii]);
+          if constexpr (HAS_MASTER) {
+            p_master[i] = r_p[ii];
+          }
+          m[i] = r_m[ii];
+          v[i] = r_v[ii];
+        }
+      }
+    }
+  }
+}
+
+// Host-side launcher for custom_adam_kernel.
+// Precomputes the bias-corrected step size and the inverse second-moment
+// correction, then dispatches on param/grad dtype, `mode` and `has_master`.
+// addresses / sizes / block_to_tensor / chunk_offsets are forwarded to the
+// kernel unchanged.
+void multi_tensor_adam_cuda_custom(
+    int chunk_size, Tensor noop_flag, DType grad_dtype, DType param_dtype,
+    int64_t *addresses, int64_t *sizes, int *block_to_tensor, int *chunk_offsets,
+    int total_chunks, bool has_master,
+    const float lr, const float beta1, const float beta2, const float epsilon,
+    const int step, const int mode, const int bias_correction,
+    const float weight_decay, cudaStream_t stream) {
+  // Nothing to update when there are no chunks. Returning here matches the
+  // total_chunks == 0 early-out of the custom L2-norm launchers and keeps an
+  // empty tensor list from reaching the kernel launch.
+  if (total_chunks <= 0) return;
+
+  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
+  if (bias_correction == 1) {
+    bias_correction1 = 1 - std::pow(beta1, step);
+    bias_correction2 = 1 - std::pow(beta2, step);
+  }
+
+  const float step_size = lr / bias_correction1;
+  const float beta2_corr_inv = 1.0f / bias_correction2;
+
+  // Local helper so the four mode/master combinations share one launch site.
+#define LAUNCH_CUSTOM_ADAM(g_type, p_type, adam_mode, master_flag) \
+  custom_adam_kernel \
+      <<>>( \
+          chunk_size, reinterpret_cast(noop_flag.data.dptr), \
+          addresses, sizes, block_to_tensor, chunk_offsets, total_chunks, \
+          beta1, beta2, step_size, beta2_corr_inv, epsilon, lr, \
+          weight_decay)
+
+  if (has_master) {
+    TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+        param_dtype, p_type,
+        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+            grad_dtype, g_type,
+            if (mode == ADAM_MODE_0) {
+              LAUNCH_CUSTOM_ADAM(g_type, p_type, ADAM_MODE_0, true);
+            } else {
+              LAUNCH_CUSTOM_ADAM(g_type, p_type, ADAM_MODE_1, true);
+            };););
+  } else {
+    TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+        param_dtype, p_type,
+        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+            grad_dtype, g_type,
+            if (mode == ADAM_MODE_0) {
+              LAUNCH_CUSTOM_ADAM(g_type, p_type, ADAM_MODE_0, false);
+            } else {
+              LAUNCH_CUSTOM_ADAM(g_type, p_type, ADAM_MODE_1, false);
+            };););
+  }
+
+#undef LAUNCH_CUSTOM_ADAM
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
 }  // namespace multi_tensor_adam
 }  // namespace transformer_engine
@@ -1004,3 +1430,37 @@ void 
nvte_multi_tensor_adam_capturable_master_cuda( *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode, bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream); } + +void nvte_multi_tensor_adam_param_remainder_cuda_custom( + int chunk_size, NVTETensor noop_flag, NVTEDType grad_dtype, + int64_t *addresses, int64_t *sizes, int *block_to_tensor, int *chunk_offsets, + int total_chunks, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, const int bias_correction, + const float weight_decay, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda_custom); + using namespace transformer_engine; + + multi_tensor_adam::multi_tensor_adam_param_remainder_cuda_custom( + chunk_size, *convertNVTETensorCheck(noop_flag), static_cast(grad_dtype), + addresses, sizes, block_to_tensor, chunk_offsets, total_chunks, + lr, beta1, beta2, epsilon, step, mode, bias_correction, weight_decay, stream); +} + +void nvte_multi_tensor_adam_cuda_custom( + int chunk_size, NVTETensor noop_flag, NVTEDType grad_dtype, NVTEDType param_dtype, + int64_t *addresses, int64_t *sizes, int *block_to_tensor, int *chunk_offsets, + int total_chunks, int has_master, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, const int bias_correction, + const float weight_decay, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_adam_cuda_custom); + using namespace transformer_engine; + + multi_tensor_adam::multi_tensor_adam_cuda_custom( + chunk_size, *convertNVTETensorCheck(noop_flag), + static_cast(grad_dtype), static_cast(param_dtype), + addresses, sizes, block_to_tensor, chunk_offsets, total_chunks, + has_master != 0, + lr, beta1, beta2, epsilon, step, mode, bias_correction, weight_decay, stream); +} diff --git a/transformer_engine/common/multi_tensor/l2norm.cu b/transformer_engine/common/multi_tensor/l2norm.cu index 
4bad90fb6..4742c7d0e 100644 --- a/transformer_engine/common/multi_tensor/l2norm.cu +++ b/transformer_engine/common/multi_tensor/l2norm.cu @@ -31,6 +31,17 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int s ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*) } +template +__device__ __forceinline__ bool is_aligned_n(T *p) { + return ((uint64_t)p) % (N * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store_n(T *dst, T *src, int dst_offset, int src_offset) { + typedef typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*) +} + template __device__ __forceinline__ T reduce_block_into_lanes(T *x, T val, int lanes = 1, @@ -261,6 +272,158 @@ struct UnscaleL2NormFunctor { } }; +template +__global__ void custom_multi_tensor_l2norm_kernel( + int chunk_size, volatile int * __restrict__ noop_gmem, + const int64_t * __restrict__ addresses, const int * __restrict__ sizes, + const int * __restrict__ block_to_tensor, const int * __restrict__ chunk_offsets, + int total_chunks, const float * __restrict__ inv_scale, + float * __restrict__ output) { + constexpr int CILP = 8; + const int global_chunk = blockIdx.x; + __shared__ float s_vals[512]; + + const float inv_scale_val = UNSCALE ? 
*inv_scale : 0.f;
+
+  // Blocks past total_chunks keep n == 0 / x == nullptr and contribute zero.
+  // They are not returned early; NOTE(review): presumably because
+  // reduce_block_into_lanes needs every thread of the block -- confirm.
+  int n = 0;
+  x_t *x = nullptr;
+  if (global_chunk < total_chunks) {
+    const int tensor_loc = block_to_tensor[global_chunk];
+    const int chunk_idx = global_chunk - chunk_offsets[tensor_loc];
+    n = sizes[tensor_loc];
+
+    x = reinterpret_cast(static_cast(addresses[tensor_loc]));
+    x += static_cast(chunk_idx) * chunk_size;
+    // n becomes the number of elements from the start of this chunk to the
+    // end of the tensor; the loops below additionally cap it at chunk_size.
+    n -= chunk_idx * chunk_size;
+  }
+
+  float vals[CILP];
+  x_t r_x[CILP];
+  for (int i = 0; i < CILP; i++) {
+    vals[i] = 0.f;
+    r_x[i] = 0.f;
+  }
+
+  if (global_chunk < total_chunks) {
+    if (n % CILP == 0 && chunk_size % CILP == 0 && is_aligned_n(x)) {
+      // Vector path: each thread loads CILP contiguous elements per step.
+      for (int i_start = threadIdx.x; i_start * CILP < n && i_start * CILP < chunk_size;
+           i_start += blockDim.x) {
+        load_store_n(r_x, x, 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < CILP; ii++) {
+          float next = static_cast(r_x[ii]);
+          if constexpr (UNSCALE) {
+            next *= inv_scale_val;
+          }
+          vals[ii] += next * next;
+        }
+      }
+    } else {
+      // Scalar path: consecutive threads read consecutive elements, CILP
+      // block-wide strides per outer step, with per-element bounds checks.
+      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * CILP) {
+#pragma unroll
+        for (int ii = 0; ii < CILP; ii++) {
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if (i < n && i < chunk_size) {
+            float next = static_cast(x[i]);
+            if constexpr (UNSCALE) {
+              next *= inv_scale_val;
+            }
+            vals[ii] += next * next;
+          }
+        }
+      }
+    }
+  }
+
+  float val = 0.f;
+  for (int i = 0; i < CILP; i++) val += vals[i];
+
+  float final = reduce_block_into_lanes(s_vals, val);
+
+  if (threadIdx.x == 0) {
+    if (!isfinite(final)) {
+      *noop_gmem = 1;
+    }
+    // Every read above is guarded by global_chunk < total_chunks; guard the
+    // store the same way so a block beyond the last chunk never writes past
+    // the per-chunk partial sums that the reduce kernel consumes.
+    if (global_chunk < total_chunks) {
+      output[global_chunk] = final;
+    }
+  }
+}
+
+// Final reduction: sums the per-chunk partial sums output[0, total_chunks)
+// and writes the square root of the total to *ret. Each thread strides over
+// the array by blockDim.x before the block-level reduction, so the result is
+// complete only when run as a single block. s_vals holds 512 floats.
+__global__ void custom_multi_tensor_l2norm_reduce(
+    const float * __restrict__ output, float * __restrict__ ret, int total_chunks) {
+  __shared__ float s_vals[512];
+
+  float val = 0.f;
+  for (int i = threadIdx.x; i < total_chunks; i += blockDim.x) {
+    val += output[i];
+  }
+
+  float final = reduce_block_into_lanes(s_vals, val);
+
+  if (threadIdx.x == 0) {
+    *ret = sqrtf(final);
+  }
+}
+
+void multi_tensor_l2norm_cuda_custom(int chunk_size, 
NVTETensor noop_flag_tensor, + DType input_dtype, const int64_t *addresses, + const int *sizes, const int *block_to_tensor, + const int *chunk_offsets, int total_chunks, + NVTETensor output_tensor, + NVTETensor ret_tensor, cudaStream_t stream) { + auto *noop_flag = convertNVTETensorCheck(noop_flag_tensor); + auto *output = convertNVTETensorCheck(output_tensor); + auto *ret = convertNVTETensorCheck(ret_tensor); + + if (total_chunks == 0) { + nvte_memset(ret->data.dptr, 0, sizeof(float), stream); + return; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input_dtype, dtype, + custom_multi_tensor_l2norm_kernel<<>>( + chunk_size, reinterpret_cast(noop_flag->data.dptr), addresses, sizes, + block_to_tensor, chunk_offsets, total_chunks, nullptr, + reinterpret_cast(output->data.dptr));) + NVTE_CHECK_CUDA(cudaGetLastError()); + + custom_multi_tensor_l2norm_reduce<<<1, BLOCK_SIZE, 0, stream>>>( + reinterpret_cast(output->data.dptr), + reinterpret_cast(ret->data.dptr), total_chunks); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void multi_tensor_unscale_l2norm_cuda_custom(int chunk_size, NVTETensor noop_flag_tensor, + DType input_dtype, const int64_t *addresses, + const int *sizes, const int *block_to_tensor, + const int *chunk_offsets, int total_chunks, + NVTETensor output_tensor, NVTETensor ret_tensor, + NVTETensor inv_scale_tensor, cudaStream_t stream) { + auto *noop_flag = convertNVTETensorCheck(noop_flag_tensor); + auto *output = convertNVTETensorCheck(output_tensor); + auto *ret = convertNVTETensorCheck(ret_tensor); + auto *inv_scale = convertNVTETensorCheck(inv_scale_tensor); + + if (total_chunks == 0) { + nvte_memset(ret->data.dptr, 0, sizeof(float), stream); + return; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input_dtype, dtype, + custom_multi_tensor_l2norm_kernel<<>>( + chunk_size, reinterpret_cast(noop_flag->data.dptr), addresses, sizes, + block_to_tensor, chunk_offsets, total_chunks, + reinterpret_cast(inv_scale->data.dptr), + 
reinterpret_cast(output->data.dptr));) + NVTE_CHECK_CUDA(cudaGetLastError()); + + custom_multi_tensor_l2norm_reduce<<<1, BLOCK_SIZE, 0, stream>>>( + reinterpret_cast(output->data.dptr), + reinterpret_cast(ret->data.dptr), total_chunks); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + // Probably better to template, but since we are not likely to support other norm template struct MaxNormFunctor { @@ -415,9 +578,6 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK_CUDA(cudaGetLastError()); - // This involves one more small kernel launches, but will be negligible end to end. - // I could get rid of these by hacking the functor + multi tensor harness with persistence - // logic, but keeping it simple for now cleanup<<>>( reinterpret_cast(output.data.dptr), per_tensor ? reinterpret_cast(output_per_tensor.data.dptr) : nullptr, @@ -443,9 +603,6 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK_CUDA(cudaGetLastError()); - // This involves one more small kernel launches, but will be negligible end to end. - // I could get rid of these by hacking the functor + multi tensor harness with persistence - // logic, but keeping it simple for now cleanup<<>>( reinterpret_cast(output.data.dptr), per_tensor ? 
reinterpret_cast(output_per_tensor.data.dptr) : nullptr, @@ -458,6 +615,33 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, } // namespace multi_tensor_l2norm } // namespace transformer_engine +void nvte_multi_tensor_l2norm_cuda_custom(int chunk_size, NVTETensor noop_flag, + const NVTEDType input_dtype, const int64_t *addresses, + const int *sizes, const int *block_to_tensor, + const int *chunk_offsets, int total_chunks, + NVTETensor output, + NVTETensor ret, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda_custom); + using namespace transformer_engine; + + multi_tensor_l2norm::multi_tensor_l2norm_cuda_custom( + chunk_size, noop_flag, static_cast(input_dtype), addresses, sizes, block_to_tensor, + chunk_offsets, total_chunks, output, ret, stream); +} + +void nvte_multi_tensor_unscale_l2norm_cuda_custom( + int chunk_size, NVTETensor noop_flag, const NVTEDType input_dtype, + const int64_t *addresses, const int *sizes, const int *block_to_tensor, + const int *chunk_offsets, int total_chunks, + NVTETensor output, NVTETensor ret, NVTETensor inv_scale, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda_custom); + using namespace transformer_engine; + + multi_tensor_l2norm::multi_tensor_unscale_l2norm_cuda_custom( + chunk_size, noop_flag, static_cast(input_dtype), addresses, sizes, block_to_tensor, + chunk_offsets, total_chunks, output, ret, inv_scale, stream); +} + void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret, diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp index 145e1d4b4..7d21906c2 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp @@ -8,18 +8,302 @@ namespace transformer_engine::pytorch { +// Cache for device-side mapping arrays used by the custom adam param-remainder +// kernel. During training the tensor list structure (shapes and data pointers) +// is typically identical across iterations, so we can avoid per-call device +// allocations and H2D memcpy by caching the arrays and only re-uploading +// when something changes. +struct CustomAdamParamRemainderCache { + static constexpr int kDepth = 5; // g, p, m, v, p_remainder + + std::vector addresses_host; // [ntensors * kDepth] + std::vector sizes_host; // [ntensors] + at::Tensor addresses_dev; + at::Tensor sizes_dev; + at::Tensor block_to_tensor_dev; + at::Tensor chunk_offsets_dev; + int total_chunks = 0; + int chunk_size = 0; + + bool shapes_valid(int ntensors, int cs, + const std::vector> &tensor_lists) const { + if (chunk_size != cs || static_cast(sizes_host.size()) != ntensors) + return false; + for (int t = 0; t < ntensors; t++) { + if (sizes_host[t] != static_cast(tensor_lists[0][t].numel())) + return false; + } + return true; + } + + bool addresses_valid(int ntensors, + const std::vector> &tensor_lists) const { + if (static_cast(addresses_host.size()) != ntensors * kDepth) + return false; + for (int t = 0; t < ntensors; t++) { + for (int d = 0; d < kDepth; d++) { + if (addresses_host[t * kDepth + d] != + reinterpret_cast(tensor_lists[d][t].data_ptr())) + return false; + } + } + return true; + } + + void rebuild(int ntensors, int cs, int tc, + const std::vector> &tensor_lists, + cudaStream_t stream) { + chunk_size = cs; + total_chunks = tc; + + addresses_host.clear(); + sizes_host.clear(); + std::vector block_to_tensor_host; + std::vector chunk_offsets_host; + addresses_host.reserve(ntensors * kDepth); + sizes_host.reserve(ntensors); + block_to_tensor_host.reserve(tc); + chunk_offsets_host.reserve(ntensors); + + int running_offset = 0; + for (int t = 0; t < ntensors; 
t++) { + const auto &tensor = tensor_lists[0][t]; + const int64_t tensor_numel = tensor.numel(); + const int chunks_this_tensor = static_cast( + (tensor_numel + cs - 1) / cs); + for (int d = 0; d < kDepth; d++) { + addresses_host.push_back( + reinterpret_cast(tensor_lists[d][t].data_ptr())); + } + sizes_host.push_back(tensor_numel); + chunk_offsets_host.push_back(running_offset); + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + block_to_tensor_host.push_back(t); + } + running_offset += chunks_this_tensor; + } + + auto int_options = tensor_lists[0][0].options().dtype(at::kInt); + auto long_options = tensor_lists[0][0].options().dtype(at::kLong); + addresses_dev = at::empty({ntensors * kDepth}, long_options); + sizes_dev = at::empty({ntensors}, long_options); + block_to_tensor_dev = at::empty({tc}, int_options); + chunk_offsets_dev = at::empty({ntensors}, int_options); + + NVTE_CHECK_CUDA(cudaMemcpyAsync(addresses_dev.data_ptr(), addresses_host.data(), + ntensors * kDepth * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(sizes_dev.data_ptr(), sizes_host.data(), + ntensors * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(block_to_tensor_dev.data_ptr(), + block_to_tensor_host.data(), + tc * sizeof(int), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(chunk_offsets_dev.data_ptr(), + chunk_offsets_host.data(), + ntensors * sizeof(int), + cudaMemcpyHostToDevice, stream)); + } + + void update_addresses(int ntensors, + const std::vector> &tensor_lists, + cudaStream_t stream) { + addresses_host.clear(); + addresses_host.reserve(ntensors * kDepth); + for (int t = 0; t < ntensors; t++) { + for (int d = 0; d < kDepth; d++) { + addresses_host.push_back( + reinterpret_cast(tensor_lists[d][t].data_ptr())); + } + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(addresses_dev.data_ptr(), + addresses_host.data(), + ntensors * kDepth * sizeof(int64_t), + cudaMemcpyHostToDevice, 
stream)); + } + + int ensure(int ntensors, int cs, int tc, + const std::vector> &tensor_lists, + cudaStream_t stream) { + if (!shapes_valid(ntensors, cs, tensor_lists)) { + rebuild(ntensors, cs, tc, tensor_lists, stream); + } else if (!addresses_valid(ntensors, tensor_lists)) { + update_addresses(ntensors, tensor_lists, stream); + } + return total_chunks; + } +}; + +static CustomAdamParamRemainderCache g_adam_param_remainder_cache; + +template +struct CustomAdamCache { + + std::vector addresses_host; + std::vector sizes_host; + at::Tensor addresses_dev; + at::Tensor sizes_dev; + at::Tensor block_to_tensor_dev; + at::Tensor chunk_offsets_dev; + int total_chunks = 0; + int chunk_size = 0; + + bool shapes_valid(int ntensors, int cs, + const std::vector> &tensor_lists) const { + if (chunk_size != cs || static_cast(sizes_host.size()) != ntensors) + return false; + for (int t = 0; t < ntensors; t++) { + if (sizes_host[t] != static_cast(tensor_lists[0][t].numel())) + return false; + } + return true; + } + + bool addresses_valid(int ntensors, + const std::vector> &tensor_lists) const { + if (static_cast(addresses_host.size()) != ntensors * kDepth) + return false; + for (int t = 0; t < ntensors; t++) { + for (int d = 0; d < kDepth; d++) { + if (addresses_host[t * kDepth + d] != + reinterpret_cast(tensor_lists[d][t].data_ptr())) + return false; + } + } + return true; + } + + void rebuild(int ntensors, int cs, int tc, + const std::vector> &tensor_lists, + cudaStream_t stream) { + chunk_size = cs; + total_chunks = tc; + + addresses_host.clear(); + sizes_host.clear(); + std::vector block_to_tensor_host; + std::vector chunk_offsets_host; + addresses_host.reserve(ntensors * kDepth); + sizes_host.reserve(ntensors); + block_to_tensor_host.reserve(tc); + chunk_offsets_host.reserve(ntensors); + + int running_offset = 0; + for (int t = 0; t < ntensors; t++) { + const auto &tensor = tensor_lists[0][t]; + const int64_t tensor_numel = tensor.numel(); + const int chunks_this_tensor = 
static_cast( + (tensor_numel + cs - 1) / cs); + for (int d = 0; d < kDepth; d++) { + addresses_host.push_back( + reinterpret_cast(tensor_lists[d][t].data_ptr())); + } + sizes_host.push_back(tensor_numel); + chunk_offsets_host.push_back(running_offset); + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + block_to_tensor_host.push_back(t); + } + running_offset += chunks_this_tensor; + } + + auto int_options = tensor_lists[0][0].options().dtype(at::kInt); + auto long_options = tensor_lists[0][0].options().dtype(at::kLong); + addresses_dev = at::empty({ntensors * kDepth}, long_options); + sizes_dev = at::empty({ntensors}, long_options); + block_to_tensor_dev = at::empty({tc}, int_options); + chunk_offsets_dev = at::empty({ntensors}, int_options); + + NVTE_CHECK_CUDA(cudaMemcpyAsync(addresses_dev.data_ptr(), addresses_host.data(), + ntensors * kDepth * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(sizes_dev.data_ptr(), sizes_host.data(), + ntensors * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(block_to_tensor_dev.data_ptr(), + block_to_tensor_host.data(), + tc * sizeof(int), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(chunk_offsets_dev.data_ptr(), + chunk_offsets_host.data(), + ntensors * sizeof(int), + cudaMemcpyHostToDevice, stream)); + } + + void update_addresses(int ntensors, + const std::vector> &tensor_lists, + cudaStream_t stream) { + addresses_host.clear(); + addresses_host.reserve(ntensors * kDepth); + for (int t = 0; t < ntensors; t++) { + for (int d = 0; d < kDepth; d++) { + addresses_host.push_back( + reinterpret_cast(tensor_lists[d][t].data_ptr())); + } + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(addresses_dev.data_ptr(), + addresses_host.data(), + ntensors * kDepth * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); + } + + int ensure(int ntensors, int cs, int tc, + const std::vector> &tensor_lists, + cudaStream_t stream) { + if 
(!shapes_valid(ntensors, cs, tensor_lists)) { + rebuild(ntensors, cs, tc, tensor_lists, stream); + } else if (!addresses_valid(ntensors, tensor_lists)) { + update_addresses(ntensors, tensor_lists, stream); + } + return total_chunks; + } +}; + +static CustomAdamCache<4> g_adam_cache; // g, p, m, v +static CustomAdamCache<5> g_adam_master_cache; // g, p, m, v, p_master + void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay) { + const size_t num_lists = tensor_lists.size(); + const int ntensors = tensor_lists[0].size(); + int total_chunks = 0; + for (int t = 0; t < ntensors; t++) { + total_chunks += static_cast( + (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size); + } + auto stream = at::cuda::getCurrentCUDAStream(); auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); - nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, - num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction, - weight_decay, at::cuda::getCurrentCUDAStream()); + if (num_lists == 4) { + g_adam_cache.ensure(ntensors, chunk_size, total_chunks, tensor_lists, stream); + nvte_multi_tensor_adam_cuda_custom( + chunk_size, noop_flag_cu.data(), + static_cast(GetTransformerEngineDType(tensor_lists[0][0].scalar_type())), + static_cast(GetTransformerEngineDType(tensor_lists[1][0].scalar_type())), + g_adam_cache.addresses_dev.data_ptr(), + g_adam_cache.sizes_dev.data_ptr(), + g_adam_cache.block_to_tensor_dev.data_ptr(), + g_adam_cache.chunk_offsets_dev.data_ptr(), + total_chunks, 0, + lr, beta1, beta2, epsilon, step, mode, bias_correction, + weight_decay, stream); + } else { + g_adam_master_cache.ensure(ntensors, chunk_size, total_chunks, 
tensor_lists, stream); + nvte_multi_tensor_adam_cuda_custom( + chunk_size, noop_flag_cu.data(), + static_cast(GetTransformerEngineDType(tensor_lists[0][0].scalar_type())), + static_cast(GetTransformerEngineDType(tensor_lists[1][0].scalar_type())), + g_adam_master_cache.addresses_dev.data_ptr(), + g_adam_master_cache.sizes_dev.data_ptr(), + g_adam_master_cache.block_to_tensor_dev.data_ptr(), + g_adam_master_cache.chunk_offsets_dev.data_ptr(), + total_chunks, 1, + lr, beta1, beta2, epsilon, step, mode, bias_correction, + weight_decay, stream); + } } void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag, @@ -27,13 +311,27 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay) { - auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); - auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = - makeTransformerEngineTensorList(tensor_lists); + const int ntensors = tensor_lists[0].size(); + int total_chunks = 0; + for (int t = 0; t < ntensors; t++) { + total_chunks += static_cast( + (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size); + } + + auto stream = at::cuda::getCurrentCUDAStream(); + g_adam_param_remainder_cache.ensure(ntensors, chunk_size, total_chunks, + tensor_lists, stream); - nvte_multi_tensor_adam_param_remainder_cuda( - chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1, - beta2, epsilon, step, mode, bias_correction, weight_decay, at::cuda::getCurrentCUDAStream()); + auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); + nvte_multi_tensor_adam_param_remainder_cuda_custom( + chunk_size, noop_flag_cu.data(), + static_cast(GetTransformerEngineDType(tensor_lists[0][0].scalar_type())), + g_adam_param_remainder_cache.addresses_dev.data_ptr(), + 
g_adam_param_remainder_cache.sizes_dev.data_ptr(), + g_adam_param_remainder_cache.block_to_tensor_dev.data_ptr(), + g_adam_param_remainder_cache.chunk_offsets_dev.data_ptr(), + total_chunks, lr, beta1, beta2, epsilon, step, mode, bias_correction, + weight_decay, stream); } void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp index b02cf1fbb..03f7a3d30 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp @@ -8,33 +8,182 @@ namespace transformer_engine::pytorch { +// Cache for device-side mapping arrays used by the custom l2norm kernel. +// During training the tensor list structure (shapes and data pointers) is +// typically identical across iterations, so we can avoid per-call device +// allocations and H2D memcpy by caching the arrays and only re-uploading +// when something changes. +struct CustomL2NormCache { + std::vector addresses_host; + std::vector sizes_host; + at::Tensor addresses_dev; + at::Tensor sizes_dev; + at::Tensor block_to_tensor_dev; + at::Tensor chunk_offsets_dev; + at::Tensor output_dev; + at::Tensor ret_dev; + int total_chunks = 0; + int chunk_size = 0; + + // Check whether the cached shape metadata still matches. + bool shapes_valid(int ntensors, int cs, + const std::vector> &tensor_lists) const { + if (chunk_size != cs || static_cast(sizes_host.size()) != ntensors) + return false; + for (int t = 0; t < ntensors; t++) { + if (sizes_host[t] != static_cast(tensor_lists[0][t].numel())) + return false; + } + return true; + } + + // Check whether cached data pointers still match. 
+ bool addresses_valid(int ntensors, + const std::vector> &tensor_lists) const { + if (static_cast(addresses_host.size()) != ntensors) + return false; + for (int t = 0; t < ntensors; t++) { + if (addresses_host[t] != + reinterpret_cast(tensor_lists[0][t].data_ptr())) + return false; + } + return true; + } + + // Full rebuild: shapes changed, so all arrays need to be re-created. + void rebuild(int ntensors, int cs, int tc, + const std::vector> &tensor_lists, + cudaStream_t stream) { + chunk_size = cs; + total_chunks = tc; + + addresses_host.clear(); + sizes_host.clear(); + std::vector block_to_tensor_host; + std::vector chunk_offsets_host; + addresses_host.reserve(ntensors); + sizes_host.reserve(ntensors); + block_to_tensor_host.reserve(tc); + chunk_offsets_host.reserve(ntensors); + + int running_offset = 0; + for (int t = 0; t < ntensors; t++) { + const auto &tensor = tensor_lists[0][t]; + const int tensor_numel = static_cast(tensor.numel()); + const int chunks_this_tensor = (tensor_numel + cs - 1) / cs; + addresses_host.push_back(reinterpret_cast(tensor.data_ptr())); + sizes_host.push_back(tensor_numel); + chunk_offsets_host.push_back(running_offset); + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + block_to_tensor_host.push_back(t); + } + running_offset += chunks_this_tensor; + } + + auto int_options = tensor_lists[0][0].options().dtype(at::kInt); + auto long_options = tensor_lists[0][0].options().dtype(at::kLong); + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); + addresses_dev = at::empty({ntensors}, long_options); + sizes_dev = at::empty({ntensors}, int_options); + block_to_tensor_dev = at::empty({tc}, int_options); + chunk_offsets_dev = at::empty({ntensors}, int_options); + output_dev = at::empty({tc}, float_options); + ret_dev = at::empty({1}, float_options); + + NVTE_CHECK_CUDA(cudaMemcpyAsync(addresses_dev.data_ptr(), addresses_host.data(), + ntensors * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); + 
NVTE_CHECK_CUDA(cudaMemcpyAsync(sizes_dev.data_ptr(), sizes_host.data(), + ntensors * sizeof(int), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(block_to_tensor_dev.data_ptr(), + block_to_tensor_host.data(), + tc * sizeof(int), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(chunk_offsets_dev.data_ptr(), + chunk_offsets_host.data(), + ntensors * sizeof(int), + cudaMemcpyHostToDevice, stream)); + } + + // Addresses-only update: shapes unchanged but data pointers moved. + void update_addresses(int ntensors, + const std::vector> &tensor_lists, + cudaStream_t stream) { + addresses_host.clear(); + addresses_host.reserve(ntensors); + for (int t = 0; t < ntensors; t++) { + addresses_host.push_back( + reinterpret_cast(tensor_lists[0][t].data_ptr())); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(addresses_dev.data_ptr(), + addresses_host.data(), + ntensors * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); + } + + // Ensure the cache is up to date. Returns total_chunks. + int ensure(int ntensors, int cs, int tc, + const std::vector> &tensor_lists, + cudaStream_t stream) { + if (!shapes_valid(ntensors, cs, tensor_lists)) { + rebuild(ntensors, cs, tc, tensor_lists, stream); + } else if (!addresses_valid(ntensors, tensor_lists)) { + update_addresses(ntensors, tensor_lists, stream); + } + // else: full cache hit, no memcpy needed + return total_chunks; + } +}; + +static CustomL2NormCache g_l2norm_cache; +static CustomL2NormCache g_unscale_l2norm_cache; + std::tuple multi_tensor_l2norm_cuda( int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::optional per_tensor_python) { bool per_tensor = per_tensor_python.has_value() ? 
per_tensor_python.value() : false; + const int ntensors = tensor_lists[0].size(); + int total_chunks = 0; + for (int t = 0; t < ntensors; t++) { + total_chunks += (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + } + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); - auto output = at::zeros({320}, float_options); - at::Tensor output_per_tensor; - at::Tensor ret_per_tensor; - auto ret = at::empty({1}, output.options()); + if (!per_tensor) { + auto stream = at::cuda::getCurrentCUDAStream(); + g_l2norm_cache.ensure(ntensors, chunk_size, total_chunks, tensor_lists, stream); - int ntensors = tensor_lists[0].size(); - int max_chunks_per_tensor = -1; + auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); + auto output_cu = makeTransformerEngineTensor(g_l2norm_cache.output_dev); + auto ret_cu = makeTransformerEngineTensor(g_l2norm_cache.ret_dev); + nvte_multi_tensor_l2norm_cuda_custom( + chunk_size, noop_flag_cu.data(), + static_cast(GetTransformerEngineDType(tensor_lists[0][0].scalar_type())), + g_l2norm_cache.addresses_dev.data_ptr(), + g_l2norm_cache.sizes_dev.data_ptr(), + g_l2norm_cache.block_to_tensor_dev.data_ptr(), + g_l2norm_cache.chunk_offsets_dev.data_ptr(), + total_chunks, output_cu.data(), ret_cu.data(), stream); + auto ret_per_tensor = at::empty({0}, float_options); + return std::tuple(g_l2norm_cache.ret_dev, ret_per_tensor); + } - if (per_tensor) { - for (int t = 0; t < ntensors; t++) { - int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; - if (max_chunks_this_tensor > max_chunks_per_tensor) - max_chunks_per_tensor = max_chunks_this_tensor; - } - output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options); - ret_per_tensor = at::empty({ntensors}, float_options); - } else { - output_per_tensor = at::empty({0}, float_options); - ret_per_tensor = at::empty({0}, float_options); + // per_tensor path: use multi_tensor_apply + const int output_size = total_chunks > 320 
? total_chunks : 320; + auto output = at::zeros({output_size}, float_options); + auto ret = at::empty({1}, float_options); + + int max_chunks_per_tensor = -1; + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) + max_chunks_per_tensor = max_chunks_this_tensor; } + auto output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options); + auto ret_per_tensor = at::zeros({ntensors}, float_options); auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = @@ -57,30 +206,48 @@ std::tuple multi_tensor_unscale_l2norm_cuda( at::Tensor inv_scale, at::optional per_tensor_python) { bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false; - auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); - auto output = at::zeros({320}, float_options); + const int ntensors = tensor_lists[0].size(); + int total_chunks = 0; + for (int t = 0; t < ntensors; t++) { + total_chunks += (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + } - at::Tensor output_per_tensor; - at::Tensor ret_per_tensor; + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); - int ntensors = tensor_lists[0].size(); - int max_chunks_per_tensor = -1; + if (!per_tensor) { + auto stream = at::cuda::getCurrentCUDAStream(); + g_unscale_l2norm_cache.ensure(ntensors, chunk_size, total_chunks, tensor_lists, stream); - // Create output tensors for multi scale L2 norm kernel. 
- if (per_tensor) { - for (int t = 0; t < ntensors; t++) { - int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; - if (max_chunks_this_tensor > max_chunks_per_tensor) - max_chunks_per_tensor = max_chunks_this_tensor; - } - output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options); - ret_per_tensor = at::empty({ntensors}, float_options); - } else { - output_per_tensor = at::empty({0}, float_options); - ret_per_tensor = at::empty({0}, float_options); + auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); + auto output_cu = makeTransformerEngineTensor(g_unscale_l2norm_cache.output_dev); + auto ret_cu = makeTransformerEngineTensor(g_unscale_l2norm_cache.ret_dev); + auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); + nvte_multi_tensor_unscale_l2norm_cuda_custom( + chunk_size, noop_flag_cu.data(), + static_cast(GetTransformerEngineDType(tensor_lists[0][0].scalar_type())), + g_unscale_l2norm_cache.addresses_dev.data_ptr(), + g_unscale_l2norm_cache.sizes_dev.data_ptr(), + g_unscale_l2norm_cache.block_to_tensor_dev.data_ptr(), + g_unscale_l2norm_cache.chunk_offsets_dev.data_ptr(), + total_chunks, output_cu.data(), ret_cu.data(), + inv_scale_cu.data(), stream); + auto ret_per_tensor = at::empty({0}, float_options); + return std::tuple(g_unscale_l2norm_cache.ret_dev, ret_per_tensor); } - auto ret = at::empty({1}, output.options()); + // per_tensor path: use multi_tensor_apply + const int output_size = total_chunks > 320 ? 
total_chunks : 320; + auto output = at::zeros({output_size}, float_options); + auto ret = at::empty({1}, float_options); + + int max_chunks_per_tensor = -1; + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) + max_chunks_per_tensor = max_chunks_this_tensor; + } + auto output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options); + auto ret_per_tensor = at::zeros({ntensors}, float_options); auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py index 792eab094..76c1e806b 100644 --- a/transformer_engine/pytorch/optimizers/__init__.py +++ b/transformer_engine/pytorch/optimizers/__init__.py @@ -3,10 +3,12 @@ # See LICENSE for license information. """Fused optimizers and multi-tensor kernels.""" +import torch + from transformer_engine_torch import ( - multi_tensor_scale, - multi_tensor_l2norm, - multi_tensor_unscale_l2norm, + multi_tensor_scale as _multi_tensor_scale, + multi_tensor_l2norm as _multi_tensor_l2norm, + multi_tensor_unscale_l2norm as _multi_tensor_unscale_l2norm, multi_tensor_adam, multi_tensor_adam_fp8, multi_tensor_adam_capturable, @@ -16,3 +18,53 @@ from .fused_adam import FusedAdam from .fused_sgd import FusedSGD from .multi_tensor_apply import MultiTensorApply, multi_tensor_applier + + +def _is_single_tensor_list(tensor_lists: list[list[torch.Tensor]], expected_lists: int) -> bool: + return len(tensor_lists) == expected_lists and all(len(tensor_list) == 1 for tensor_list in tensor_lists) + + +def _update_noop_flag(noop_flag_buffer: torch.Tensor | None, condition: torch.Tensor) -> None: + if noop_flag_buffer is None: + return + noop_flag_buffer.add_(condition.to(dtype=noop_flag_buffer.dtype, device=noop_flag_buffer.device)) + + 
def _single_tensor_norm_outputs(norm, per_tensor):
    """Turn a 0-dim float32 norm into the ``(norm, per_tensor_norm)`` pair.

    Args:
        norm: 0-dim tensor holding the norm of the only input tensor.
        per_tensor: truthy when the caller asked for per-tensor norms.

    Returns:
        ``(total, per_tensor_norm)`` where ``total`` has shape ``(1,)`` and
        ``per_tensor_norm`` has shape ``(1,)`` when ``per_tensor`` is truthy,
        otherwise shape ``(0,)``.
    """
    total = norm.reshape(1)
    if per_tensor:
        # Give the per-tensor result its own storage: returning `total` twice
        # would let an in-place edit of one output silently change the other.
        per_tensor_norm = total.clone()
    else:
        per_tensor_norm = torch.empty(0, device=norm.device, dtype=torch.float32)
    return total, per_tensor_norm


def multi_tensor_scale(chunk_size, noop_flag_buffer, tensor_lists, scale):
    """Write ``input * scale`` into the matching output tensor(s).

    ``tensor_lists[0]`` holds the inputs and ``tensor_lists[1]`` the outputs.
    When each list contains exactly one tensor and no flag buffer is supplied
    (``None`` or zero elements), the product is computed with ``torch.mul``
    directly into the output and ``None`` is returned. Every other call is
    forwarded unchanged to ``_multi_tensor_scale``.
    """
    if _is_single_tensor_list(tensor_lists, 2) and (
        noop_flag_buffer is None or noop_flag_buffer.numel() == 0
    ):
        source, destination = tensor_lists[0][0], tensor_lists[1][0]
        torch.mul(source, scale, out=destination)
        return None
    return _multi_tensor_scale(chunk_size, noop_flag_buffer, tensor_lists, scale)


def multi_tensor_l2norm(chunk_size, noop_flag_buffer, tensor_lists, per_tensor):
    """Return ``(norm, per_tensor_norm)`` L2 norms, accumulated in float32.

    For a single list holding a single tensor the norm is computed with
    ``torch.linalg.vector_norm`` and ``noop_flag_buffer`` (if given) is updated
    through ``_update_noop_flag`` with whether the norm is non-finite. Any
    other input is forwarded unchanged to ``_multi_tensor_l2norm``.

    On the single-tensor path ``norm`` has shape ``(1,)``; ``per_tensor_norm``
    is an independent copy of it when ``per_tensor`` is truthy and an empty
    float32 tensor otherwise.
    """
    if not _is_single_tensor_list(tensor_lists, 1):
        return _multi_tensor_l2norm(chunk_size, noop_flag_buffer, tensor_lists, per_tensor)

    input_tensor = tensor_lists[0][0]
    norm = torch.linalg.vector_norm(input_tensor, ord=2, dtype=torch.float32)
    _update_noop_flag(noop_flag_buffer, ~torch.isfinite(norm))
    return _single_tensor_norm_outputs(norm, per_tensor)


def multi_tensor_unscale_l2norm(chunk_size, noop_flag_buffer, tensor_lists, inv_scale, per_tensor):
    """Return ``(norm, per_tensor_norm)`` L2 norms of the inputs times ``inv_scale``.

    For a single list holding a single tensor the float32 norm of the raw
    tensor is multiplied by ``abs(inv_scale)``, which equals the norm of
    ``tensor * inv_scale``; ``inv_scale`` must therefore hold exactly one
    element. ``noop_flag_buffer`` (if given) is updated through
    ``_update_noop_flag`` with whether the scaled norm is non-finite. Any other
    input is forwarded unchanged to ``_multi_tensor_unscale_l2norm``.

    On the single-tensor path ``norm`` has shape ``(1,)``; ``per_tensor_norm``
    is an independent copy of it when ``per_tensor`` is truthy and an empty
    float32 tensor otherwise.
    """
    if not _is_single_tensor_list(tensor_lists, 1):
        return _multi_tensor_unscale_l2norm(
            chunk_size, noop_flag_buffer, tensor_lists, inv_scale, per_tensor
        )

    input_tensor = tensor_lists[0][0]
    scaled_norm = torch.linalg.vector_norm(input_tensor, ord=2, dtype=torch.float32)
    scaled_norm.mul_(torch.abs(inv_scale.reshape(())).to(dtype=torch.float32))
    _update_noop_flag(noop_flag_buffer, ~torch.isfinite(scaled_norm))
    return _single_tensor_norm_outputs(scaled_norm, per_tensor)