Bug: NNlib.softmax fails with MissingPrimalError (isfinite check) #313

Description

@bdrhill

Problem

NNlib.softmax fails with MissingPrimalError when traced with TracerSparsityDetector, due to an internal isfinite check. LogExpFunctions.softmax works.

MWE

using SparseConnectivityTracer
using NNlib
using LogExpFunctions

detector = TracerSparsityDetector()

# NNlib.softmax fails
jacobian_sparsity(x -> NNlib.softmax(x), [1.0, 2.0, 3.0], detector)
# MissingPrimalError: isfinite

# LogExpFunctions.softmax works
jacobian_sparsity(x -> LogExpFunctions.softmax(x), [1.0, 2.0, 3.0], detector)
# OK: nnz=9

# Manual implementation works
jacobian_sparsity(x -> (ex = exp.(x); ex / sum(ex)), [1.0, 2.0, 3.0], detector)
# OK: nnz=9

Root cause

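NNlib.softmax! calls all(isfinite, out) to check for overflow:
https://github.com/FluxML/NNlib.jl/blob/master/src/softmax.jl#L62

isfinite returns a Bool, which can only be computed from an actual (primal) value. TracerSparsityDetector computes global sparsity patterns, and its tracers carry no primal values, so this call throws MissingPrimalError.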
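A small, self-contained toy model can illustrate why a Bool-valued check breaks global tracing. All names here are hypothetical: ToyTracer, ToyMissingPrimalError, guarded_softmax, plain_softmax and toy_jacobian_pattern are illustrative and are not SparseConnectivityTracer or NNlib APIs. guarded_softmax is also a simplified stand-in, not NNlib's actual softmax! code. The toy tracer records only which inputs a value depends on. Arithmetic therefore still works, but isfinite has no number to evaluate.

```julia
# Hypothetical toy model of a "global" sparsity tracer. It is NOT
# SparseConnectivityTracer's real implementation: each value records only
# which input indices it depends on and carries no numeric (primal) value.
struct ToyTracer <: Number
    deps::Set{Int}
end

# Hypothetical stand-in for MissingPrimalError.
struct ToyMissingPrimalError <: Exception
    fname::Symbol
end

# Arithmetic just merges dependency sets.
Base.:+(a::ToyTracer, b::ToyTracer) = ToyTracer(union(a.deps, b.deps))
Base.:/(a::ToyTracer, b::ToyTracer) = ToyTracer(union(a.deps, b.deps))
Base.exp(a::ToyTracer) = ToyTracer(copy(a.deps))

# isfinite must return a Bool, which depends on the actual number.
# Without a primal value the tracer can only throw.
Base.isfinite(::ToyTracer) = throw(ToyMissingPrimalError(:isfinite))

# Simplified stand-in for a softmax with a Bool-valued overflow guard
# (not NNlib's actual code).
function guarded_softmax(x)
    ex = exp.(x)
    all(isfinite, ex) || error("overflow in exp")
    return ex ./ sum(ex)
end

# Branch-free softmax, like the manual implementation in the MWE.
plain_softmax(x) = (ex = exp.(x); ex ./ sum(ex))

# Seed input i with dependency set {i}, run f, and collect the
# (output index, input index) pairs of the Jacobian pattern.
function toy_jacobian_pattern(f, n::Int)
    xs = [ToyTracer(Set([i])) for i in 1:n]
    ys = f(xs)
    return Set((i, j) for (i, y) in enumerate(ys) for j in y.deps)
end
```

Here toy_jacobian_pattern(plain_softmax, 3) returns all 9 index pairs, which mirrors nnz=9 in the MWE. toy_jacobian_pattern(guarded_softmax, 3) throws ToyMissingPrimalError(:isfinite).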

Workaround

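Use LogExpFunctions.softmax, or switch to TracerLocalSparsityDetector. The local detector tracks primal values, so it can evaluate isfinite. However, the pattern it returns is only guaranteed to hold at the given input point.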
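A similar toy model can show why the TracerLocalSparsityDetector workaround helps. The names used here (ToyLocalTracer, guarded_softmax, toy_local_pattern) are again hypothetical and are not part of SparseConnectivityTracer or NNlib. In this version, each tracer also carries a primal value, so isfinite can be answered by evaluating it on that value. This is a sketch of the idea, not the detector's actual implementation.

```julia
# Hypothetical toy model of a "local" sparsity tracer (NOT
# SparseConnectivityTracer's real implementation): each value carries a
# primal value alongside the set of input indices it depends on.
struct ToyLocalTracer <: Number
    val::Float64
    deps::Set{Int}
end

# Arithmetic computes the primal value and merges dependency sets.
Base.:+(a::ToyLocalTracer, b::ToyLocalTracer) =
    ToyLocalTracer(a.val + b.val, union(a.deps, b.deps))
Base.:/(a::ToyLocalTracer, b::ToyLocalTracer) =
    ToyLocalTracer(a.val / b.val, union(a.deps, b.deps))
Base.exp(a::ToyLocalTracer) = ToyLocalTracer(exp(a.val), copy(a.deps))

# With a primal value available, the Bool-valued predicate is simply
# evaluated on that value instead of throwing.
Base.isfinite(a::ToyLocalTracer) = isfinite(a.val)

# Simplified stand-in for a softmax with a Bool-valued overflow guard
# (not NNlib's actual code).
function guarded_softmax(x)
    ex = exp.(x)
    all(isfinite, ex) || error("overflow in exp")
    return ex ./ sum(ex)
end

# Seed input i with its value and dependency set {i}, run f, and collect
# the (output index, input index) pairs of the pattern at this point.
function toy_local_pattern(f, x::Vector{Float64})
    xs = [ToyLocalTracer(v, Set([i])) for (i, v) in enumerate(x)]
    ys = f(xs)
    return Set((i, j) for (i, y) in enumerate(ys) for j in y.deps)
end
```

For the input [1.0, 2.0, 3.0], toy_local_pattern(guarded_softmax, ...) returns all 9 index pairs. For [1000.0, 0.0, 0.0], exp overflows and the guard takes the other branch (an error in this toy). This illustrates why a pattern obtained this way is tied to the input point where it was computed.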

Metadata

Assignees: No one assigned

Labels: agent 🤖 (Agentically generated issue), array (Features regarding array overloads), new overloads (A new method on tracers is required by a user.)

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests