From a04da54df63d8502f6fd3f6a3e2ee9bfd3fa9314 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Sun, 29 Jun 2025 17:17:22 -0700 Subject: [PATCH] Parallelize optimized op_log_softmax Straightforward application of parallel_for. Differential Revision: [D76831122](https://our.internmc.facebook.com/intern/diff/D76831122/) [ghstack-poisoned] --- kernels/optimized/cpu/op_log_softmax.cpp | 52 ++++++++++++++++-------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/kernels/optimized/cpu/op_log_softmax.cpp b/kernels/optimized/cpu/op_log_softmax.cpp index e2e8dfbeca7..f57beac1dbb 100644 --- a/kernels/optimized/cpu/op_log_softmax.cpp +++ b/kernels/optimized/cpu/op_log_softmax.cpp @@ -54,33 +54,49 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) { } if (dim == input.dim() - 1) { - at::native::serial_vec_log_softmax_lastdim_range( - input_data_base, - output_data_base, - dim_size, - at::native::vec_log_softmax_lastdim_chunk_size( - executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size), - // TODO: parallelize. + ::executorch::extension::parallel_for( 0, - outer_size); + outer_size, + ::executorch::extension::internal::GRAIN_SIZE, + [&](const auto begin, const auto end) { + at::native::serial_vec_log_softmax_lastdim_range( + input_data_base, + output_data_base, + dim_size, + at::native::vec_log_softmax_lastdim_chunk_size( + executorch::extension::internal::GRAIN_SIZE, + outer_size, + dim_size), + begin, + end); + }); } else { // BLOCK_SIZE in PyTorch is intended for server CPUs; let's // halve it to try and have a better chance of fitting in mobile // chip caches. - const auto [chunk_size, num_chunks] = + const auto [chunk_size_binding, num_chunks_binding] = at::native::vec_logsoftmax_chunk_size_and_num_chunks< float, /*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size); - at::native::serial_vec_logsoftmax_range( - input_data_base, - output_data_base, - inner_size, - chunk_size, - num_chunks, - dim_size, - // TODO: parallelize + // Work around "capturing a structured binding is not yet supported in + // OpenMP". + const auto chunk_size = chunk_size_binding; + const auto num_chunks = num_chunks_binding; + ::executorch::extension::parallel_for( 0, - outer_size * num_chunks); + outer_size * num_chunks, + ::executorch::extension::internal::GRAIN_SIZE, + [&](const auto begin, const auto end) { + at::native::serial_vec_logsoftmax_range( + input_data_base, + output_data_base, + inner_size, + chunk_size, + num_chunks, + dim_size, + begin, + end); + }); } return; }