Problem
`NNlib.softmax` fails with a `MissingPrimalError` under `TracerSparsityDetector` because of an internal `isfinite` check. `LogExpFunctions.softmax` works fine.
MWE
```julia
using SparseConnectivityTracer
using NNlib
using LogExpFunctions

detector = TracerSparsityDetector()

# NNlib.softmax fails
jacobian_sparsity(x -> NNlib.softmax(x), [1.0, 2.0, 3.0], detector)
# MissingPrimalError: isfinite

# LogExpFunctions.softmax works
jacobian_sparsity(x -> LogExpFunctions.softmax(x), [1.0, 2.0, 3.0], detector)
# OK: nnz=9

# Manual implementation works
jacobian_sparsity(x -> (ex = exp.(x); ex / sum(ex)), [1.0, 2.0, 3.0], detector)
# OK: nnz=9
```
Root cause
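`NNlib.softmax!` calls `all(isfinite, out)` to check for overflow:
https://github.com/FluxML/NNlib.jl/blob/master/src/softmax.jl#L62
`TracerSparsityDetector` traces without primal values. A predicate like `isfinite` therefore has no number to test, and the tracer throws `MissingPrimalError` instead.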
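The failure mode can be made concrete with a minimal, self-contained sketch in plain Julia. It does not use SparseConnectivityTracer's real types. `ToyTracer`, `ToyMissingPrimalError`, `toy_softmax`, and `toy_softmax_checked` are hypothetical names, and the toy only models the idea that a global tracer carries the set of inputs a value depends on but no primal value. Arithmetic can merge those sets, but a predicate such as `isfinite` has nothing to inspect.

```julia
# Hypothetical toy model of a global sparsity tracer (NOT SparseConnectivityTracer's API):
# it records which inputs a value depends on, but carries no primal (numeric) value.
struct ToyTracer <: Real
    deps::Set{Int}
end

# Hypothetical error type standing in for a "missing primal" failure.
struct ToyMissingPrimalError <: Exception
    fn::Symbol
end

# Arithmetic only needs to merge dependency sets, so it works without values.
Base.exp(t::ToyTracer) = ToyTracer(copy(t.deps))
Base.:+(a::ToyTracer, b::ToyTracer) = ToyTracer(union(a.deps, b.deps))
Base.:/(a::ToyTracer, b::ToyTracer) = ToyTracer(union(a.deps, b.deps))

# A predicate needs an actual number to answer, so the toy tracer must give up.
Base.isfinite(::ToyTracer) = throw(ToyMissingPrimalError(:isfinite))

# Branch-free softmax (no max-subtraction, for brevity): traces fine.
toy_softmax(x) = (ex = exp.(x); ex ./ sum(ex))

# Same computation plus an overflow check in the spirit of NNlib.softmax!.
function toy_softmax_checked(x)
    out = toy_softmax(x)
    all(isfinite, out) || error("softmax overflowed")
    return out
end

# Seed one tracer per input index and read off the Jacobian pattern.
x = [ToyTracer(Set([i])) for i in 1:3]
y = toy_softmax(x)
pattern = [i in y[j].deps for j in 1:3, i in 1:3]  # rows: outputs, columns: inputs
println(count(pattern))  # prints 9: every output depends on every input

try
    toy_softmax_checked(x)
catch err
    println(err isa ToyMissingPrimalError)  # prints true: the isfinite check is what fails
end
```

The unchecked version yields the same dense 3×3 pattern as the manual implementation in the MWE. Only the added `isfinite` branch breaks tracing.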
Workaround
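Use `LogExpFunctions.softmax`, or switch to `TracerLocalSparsityDetector`, whose tracers carry primal values so the `isfinite` check can run. A local detector's pattern is only guaranteed at the given input, not for all inputs.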
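Why a local detector gets past the check can be sketched with a small, self-contained toy model in plain Julia. The sketch assumes the simplified view that a local tracer carries the primal value alongside the set of inputs it depends on. `ToyLocalTracer` and `checked_softmax` are hypothetical names for illustration, not SparseConnectivityTracer's implementation.

```julia
# Hypothetical toy model of a local sparsity tracer (NOT SparseConnectivityTracer's API):
# it carries the primal value together with the set of inputs it depends on.
struct ToyLocalTracer <: Real
    val::Float64
    deps::Set{Int}
end

# Each operation updates the primal value and merges dependency sets.
Base.exp(t::ToyLocalTracer) = ToyLocalTracer(exp(t.val), copy(t.deps))
Base.:+(a::ToyLocalTracer, b::ToyLocalTracer) =
    ToyLocalTracer(a.val + b.val, union(a.deps, b.deps))
Base.:/(a::ToyLocalTracer, b::ToyLocalTracer) =
    ToyLocalTracer(a.val / b.val, union(a.deps, b.deps))

# With a primal value available, the predicate can simply be forwarded.
Base.isfinite(t::ToyLocalTracer) = isfinite(t.val)

# Softmax with an overflow check in the spirit of NNlib.softmax!
# (no max-subtraction, for brevity).
function checked_softmax(x)
    ex = exp.(x)
    out = ex ./ sum(ex)
    all(isfinite, out) || error("softmax overflowed")
    return out
end

# Seed tracers at the point [1.0, 2.0, 3.0]; the check now passes.
x = [ToyLocalTracer(v, Set([i])) for (i, v) in enumerate([1.0, 2.0, 3.0])]
y = checked_softmax(x)
pattern = [i in y[j].deps for j in 1:3, i in 1:3]  # rows: outputs, columns: inputs
println(count(pattern))               # prints 9: dense 3×3 pattern
println(sum(t.val for t in y) ≈ 1.0)  # prints true: primal outputs still sum to 1
```

The cost of carrying values is that any branch is resolved at the seeded point. That is why a local pattern is only guaranteed at that input.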