Skip to content

Commit cbbb3dc

Browse files
committed
optimized: add BFloat16 and Half support to opt_log_softmax_out
opt_log_softmax_out only handled Float; BFloat16 and Half fell through to the default case, ET_KERNEL_CHECK(false), which left the output unchanged. The underlying log_softmax_kernel<IN_T, OUT_T> is fully generic, and the ATen vectorized functions it delegates to already support BFloat16 and Half.
- Extend log_softmax_wrapper with an if constexpr branch for BFloat16/Half that calls log_softmax_kernel<T, T>.
- Add BFloat16 and Half dispatch cases in opt_log_softmax_out.
1 parent 16ba018 commit cbbb3dc

1 file changed

Lines changed: 32 additions & 13 deletions

File tree

kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,20 +98,27 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
9898
return;
9999
}
100100

101-
// OUT_T is the corresponding C++ type for out.scalar_type(). Only takes float
102-
// or double.
103-
template <
104-
typename OUT_T,
105-
std::enable_if_t<std::is_floating_point<OUT_T>::value, bool> = true>
101+
// OUT_T is the corresponding C++ type for out.scalar_type().
102+
template <typename OUT_T>
106103
bool log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
107-
auto input_scalar_type = X.scalar_type();
108-
switch (input_scalar_type) {
109-
// TODO: support Double as well
110-
case ScalarType::Float:
111-
log_softmax_kernel<float, OUT_T>(X, dim, out);
112-
return true;
113-
default:
114-
return false; // Unsupported input dtype
104+
if constexpr (
105+
std::is_same_v<OUT_T, executorch::aten::BFloat16> ||
106+
std::is_same_v<OUT_T, executorch::aten::Half>) {
107+
// Input dtype equals output dtype (enforced by check_log_softmax_args).
108+
// Use if constexpr to avoid instantiating cross-type combinations that
109+
// the ATen vectorized functions do not support.
110+
log_softmax_kernel<OUT_T, OUT_T>(X, dim, out);
111+
return true;
112+
} else {
113+
auto input_scalar_type = X.scalar_type();
114+
switch (input_scalar_type) {
115+
// TODO: support Double as well
116+
case ScalarType::Float:
117+
log_softmax_kernel<float, OUT_T>(X, dim, out);
118+
return true;
119+
default:
120+
return false; // Unsupported input dtype
121+
}
115122
}
116123
}
117124
} // namespace
@@ -148,6 +155,18 @@ Tensor& opt_log_softmax_out(
148155
ET_KERNEL_CHECK(context, success, InvalidArgument, out);
149156
break;
150157
}
158+
case ScalarType::BFloat16: {
159+
bool success =
160+
log_softmax_wrapper<executorch::aten::BFloat16>(self, dim, out);
161+
ET_KERNEL_CHECK(context, success, InvalidArgument, out);
162+
break;
163+
}
164+
case ScalarType::Half: {
165+
bool success =
166+
log_softmax_wrapper<executorch::aten::Half>(self, dim, out);
167+
ET_KERNEL_CHECK(context, success, InvalidArgument, out);
168+
break;
169+
}
151170
default:
152171
ET_KERNEL_CHECK(context, false, InvalidArgument, out);
153172
}

0 commit comments

Comments
 (0)