diff --git a/Project.toml b/Project.toml index a4f45b2..919d655 100644 --- a/Project.toml +++ b/Project.toml @@ -1,16 +1,12 @@ name = "FunctionProperties" uuid = "f62d2435-5019-4c03-9749-2d4c77af0cbc" +version = "0.1.4" authors = ["SciML"] -version = "0.1.3" [deps] -Cassette = "7057c7e9-c182-5462-911a-8362d720325c" -DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" [compat] -Cassette = "0.3.12" ComponentArrays = "0.15" -DiffRules = "1.15" Random = "1.10" SafeTestsets = "0.1" Test = "1.10" diff --git a/src/FunctionProperties.jl b/src/FunctionProperties.jl index 91198d9..761e7fa 100644 --- a/src/FunctionProperties.jl +++ b/src/FunctionProperties.jl @@ -1,160 +1,60 @@ module FunctionProperties -using Cassette, DiffRules -using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot +using Core: GotoIfNot -const printbranch = false - -Cassette.@context HasBranchingCtx - -function Cassette.overdub(ctx::HasBranchingCtx, f, args...) - if Cassette.canrecurse(ctx, f, args...) - return Cassette.recurse(ctx, f, args...) - else - return Cassette.fallback(ctx, f, args...) - end -end +""" + is_leaf(f, args...) -> Bool -for (mod, f, n) in DiffRules.diffrules(; filter_modules = nothing) - if !(isdefined(@__MODULE__, mod) && isdefined(getfield(@__MODULE__, mod), f)) - continue # Skip rules for methods not defined in the current scope - end - @eval function Cassette.overdub( - ::HasBranchingCtx, f::Core.Typeof($mod.$f), - x::Vararg{Any, $n} - ) - return f(x...) - end -end +Override this to exempt a function from `hasbranching` analysis. +Return `true` to treat `f` as branch-free regardless of its implementation. -function _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection) - ir = reflection.code_info - - if any(x -> isa(x, GotoIfNot), ir.code) - printbranch && println("GotoIfNot detected in $(reflection.method)\nir = $ir\n") - Cassette.insert_statements!( - ir.code, ir.codelocs, - (stmt, i) -> i == 1 ? 3 : nothing, - ( - stmt, - i, - ) -> Any[ - Expr( - :call, - Expr( - :nooverdub, - GlobalRef(Base, :getfield) - ), - Expr(:contextslot), - QuoteNode(:metadata) - ), - Expr( - :call, - Expr( - :nooverdub, - GlobalRef(Base, :setindex!) - ), - SSAValue(1), true, - QuoteNode(:has_branching) - ), - stmt, - ] - ) - Cassette.insert_statements!( - ir.code, ir.codelocs, - (stmt, i) -> i > 2 && isa(stmt, Expr) ? 1 : nothing, - ( - stmt, - i, - ) -> begin - callstmt = Meta.isexpr(stmt, :(=)) ? stmt.args[2] : - stmt - Meta.isexpr(stmt, :call) || - Meta.isexpr(stmt, :invoke) || return Any[stmt] - callstmt = Expr( - callstmt.head, - Expr(:nooverdub, callstmt.args[1]), - callstmt.args[2:end]... - ) - return Any[ - Meta.isexpr(stmt, :(=)) ? - Expr(:(=), stmt.args[1], callstmt) : - callstmt, - ] - end - ) - end - return ir -end +## Example -const pass = Cassette.@pass _pass +```julia +FunctionProperties.is_leaf(::typeof(my_fn)) = true +``` +""" +is_leaf(f, args...) = false """ hasbranching(f, x...) -Checks whether the function `f` has branches (if statements) that are dependent on the value x -that would be taken in a tracing system, such as during AD tracing by a package like ReverseDiff.jl. +Checks whether the function `f` has branches (if statements) that are dependent on the +value `x` that would be taken in a tracing system, such as during AD tracing by a package +like ReverseDiff.jl. -## Arguments: +## Arguments - * `f`: the function to inspect - * `x`: test arguments for the inspection. These values do not need to be the values that - would be used in the actual calls to the function but instead prototype values which - match the types that would be used in the actual function call. This is used to trace to - the correct internal dispatches. + - `f`: the function to inspect. + - `x`: test arguments. These values don't need to match the actual call values, but their + *types* must match — they are used to select the right method specialization. -## Outputs: +## Outputs - Boolean for whether the function has branches. +Boolean for whether the function's immediate IR contains a conditional branch (`GotoIfNot`). -## Customizing and Removing Dispatches from the Checks +## Customizing and Removing Functions from the Checks -Some internal functions of a package may cause false positives because a branch may be known to -resolve at compile time. If this is known, then you can add a dispatch to opt that function out -of the analysis via: +Some functions may produce false positives because their internal branches are compile-time +constants. Override `FunctionProperties.is_leaf` to opt them out: ```julia -function FunctionProperties.Cassette.overdub(::FunctionProperties.HasBranchingCtx, ::typeof(f), x...) - f(x...) -end +FunctionProperties.is_leaf(::typeof(my_fn)) = true ``` """ function hasbranching(f, x...) - metadata = Dict(:has_branching => false) - Cassette.overdub(Cassette.disablehooks(HasBranchingCtx(; pass, metadata)), f, x...) - return metadata[:has_branching] -end - -Cassette.overdub(::HasBranchingCtx, ::typeof(+), x...) = +(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(*), x...) = *(x...) -function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.materialize), x...) - return Base.materialize(x...) -end -function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.literal_pow), x...) - return Base.literal_pow(x...) -end -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.getindex), x...) = Base.getindex(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.setindex!), x...) = Base.setindex!(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Core.Typeof), x...) = Core.Typeof(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.vec), x...) = Base.vec(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.vect), x...) = Base.vect(x...) - -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.vcat), x...) = Base.vcat(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.hcat), x...) = Base.hcat(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.hvcat), x...) = Base.hvcat(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.cat), x...) = Base.cat(x...) -Cassette.overdub(::HasBranchingCtx, ::typeof(Base.stack), x...) = Base.stack(x...) - -function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.Broadcast.broadcasted), x...) - return Base.Broadcast.broadcasted(x...) -end -function Cassette.overdub( - ::HasBranchingCtx, ::Type{Base.OneTo{T}}, - stop - ) where {T <: Integer} - return Base.OneTo{T}(stop) + is_leaf(f, x...) && return false + argtypes = Tuple{Core.Typeof.(x)...} + results = try + code_typed(f, argtypes; optimize = false) + catch + return false + end + isempty(results) && return false + ci = first(results)[1] + return any(isa(s, GotoIfNot) for s in ci.code) end -export hasbranching +export hasbranching, is_leaf end diff --git a/test/runtests.jl b/test/runtests.jl index f295781..675899d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,92 +13,88 @@ end if GROUP in ("All", "Core") -@test hasbranching(1, 2) do x, y - (x < 0 ? -x : x) + exp(y) -end - -@test !hasbranching(1, 2) do x, y - ifelse(x < 0, -x, x) + exp(y) -end - -# Test overloading - -f_branch() = true ? 1 : 0 -@test FunctionProperties.hasbranching(f_branch) -function FunctionProperties.Cassette.overdub( - ::FunctionProperties.HasBranchingCtx, ::typeof(f_branch), x... - ) - return f_branch(x...) -end -@test !FunctionProperties.hasbranching(f_branch) - -# Test simple mutating functions -function f(dx, x) - return @inbounds dx[1] = x[1] -end -x = zeros(1) -dx = zeros(1) -@test !FunctionProperties.hasbranching(f, dx, x) - -# Test broadcast -function f(x) - return cos.(x .+ x .* x) -end -x = [1.0] -@test !FunctionProperties.hasbranching(f, x) - -# Neural networks -# -# The relevant scenario is a neural-network-shaped ODE right-hand side (SciML/SciMLSensitivity.jl#997): -# `hasbranching` must report it as branch-free so a tracing AD like ReverseDiff can compile a tape. -# The forward pass is expressed here as explicit affine transforms plus broadcast activations, which -# is the value flow `hasbranching` actually inspects. We deliberately do not trace a real Lux layer: -# modern Lux layer dispatch routes through device-detection / type-introspection helpers that contain -# genuine (but value-independent, compile-time) `GotoIfNot` branches, which this syntactic IR scan -# cannot distinguish from value-dependent branches (SciML/FunctionProperties.jl#46). -rng = Random.default_rng() -W = randn(rng, Float32, 1, 1) -b = randn(rng, Float32, 1) -p = ComponentArray(; weight = W, bias = b) -t = [0.0] - -function f(x, ps) - return ps.weight * x -end -@test !FunctionProperties.hasbranching(f, t, p) - -function f(x, ps) - return x .+ x -end -@test !FunctionProperties.hasbranching(f, t, p) - -# Affine transform followed by a broadcast activation (the original `apply_activation` intent). -function f2(x, ps) - return identity.(ps.weight * x .+ vec(ps.bias)) -end -@test !FunctionProperties.hasbranching(f2, t, p) - -# A multi-layer perceptron forward pass built from broadcast `tanh` activations. -rng = Random.default_rng() -tspan = (0.0f0, 8.0f0) -W1 = randn(rng, Float32, 32, 2) -b1 = randn(rng, Float32, 32) -W2 = randn(rng, Float32, 32, 32) -b2 = randn(rng, Float32, 32) -W3 = randn(rng, Float32, 1, 32) -b3 = randn(rng, Float32, 1) -p = ComponentArray(; W1, b1, W2, b2, W3, b3) -θ, ax = getdata(p), getaxes(p) - -ann(x, p) = p.W3 * tanh.(p.W2 * tanh.(p.W1 * x .+ p.b1) .+ p.b2) .+ p.b3 - -function dxdt_(dx, x, p, t) - x1, x2 = x - dx[1] = x[2] + first(ann(x, p)) - return dx[2] = first(ann([t, t], p)) -end -x0 = [-4.0f0, 0.0f0] -ts = Float32.(collect(0.0:0.01:tspan[2])) -@test !FunctionProperties.hasbranching(dxdt_, copy(x0), x0, p, tspan[1]) + @test hasbranching(1, 2) do x, y + (x < 0 ? -x : x) + exp(y) + end + + @test !hasbranching(1, 2) do x, y + ifelse(x < 0, -x, x) + exp(y) + end + + # Test overloading via is_leaf + + f_branch() = true ? 1 : 0 + @test FunctionProperties.hasbranching(f_branch) + FunctionProperties.is_leaf(::typeof(f_branch)) = true + @test !FunctionProperties.hasbranching(f_branch) + + # Test simple mutating functions + function f(dx, x) + return @inbounds dx[1] = x[1] + end + x = zeros(1) + dx = zeros(1) + @test !FunctionProperties.hasbranching(f, dx, x) + + # Test broadcast + function f(x) + return cos.(x .+ x .* x) + end + x = [1.0] + @test !FunctionProperties.hasbranching(f, x) + + # Neural networks + # + # The relevant scenario is a neural-network-shaped ODE right-hand side (SciML/SciMLSensitivity.jl#997): + # `hasbranching` must report it as branch-free so a tracing AD like ReverseDiff can compile a tape. + # The forward pass is expressed here as explicit affine transforms plus broadcast activations, which + # is the value flow `hasbranching` actually inspects. We deliberately do not trace a real Lux layer: + # modern Lux layer dispatch routes through device-detection / type-introspection helpers that contain + # genuine (but value-independent, compile-time) `GotoIfNot` branches, which this syntactic IR scan + # cannot distinguish from value-dependent branches (SciML/FunctionProperties.jl#46). + rng = Random.default_rng() + W = randn(rng, Float32, 1, 1) + b = randn(rng, Float32, 1) + p = ComponentArray(; weight = W, bias = b) + t = [0.0] + + function f(x, ps) + return ps.weight * x + end + @test !FunctionProperties.hasbranching(f, t, p) + + function f(x, ps) + return x .+ x + end + @test !FunctionProperties.hasbranching(f, t, p) + + # Affine transform followed by a broadcast activation (the original `apply_activation` intent). + function f2(x, ps) + return identity.(ps.weight * x .+ vec(ps.bias)) + end + @test !FunctionProperties.hasbranching(f2, t, p) + + # A multi-layer perceptron forward pass built from broadcast `tanh` activations. + rng = Random.default_rng() + tspan = (0.0f0, 8.0f0) + W1 = randn(rng, Float32, 32, 2) + b1 = randn(rng, Float32, 32) + W2 = randn(rng, Float32, 32, 32) + b2 = randn(rng, Float32, 32) + W3 = randn(rng, Float32, 1, 32) + b3 = randn(rng, Float32, 1) + p = ComponentArray(; W1, b1, W2, b2, W3, b3) + θ, ax = getdata(p), getaxes(p) + + ann(x, p) = p.W3 * tanh.(p.W2 * tanh.(p.W1 * x .+ p.b1) .+ p.b2) .+ p.b3 + + function dxdt_(dx, x, p, t) + x1, x2 = x + dx[1] = x[2] + first(ann(x, p)) + return dx[2] = first(ann([t, t], p)) + end + x0 = [-4.0f0, 0.0f0] + ts = Float32.(collect(0.0:0.01:tspan[2])) + @test !FunctionProperties.hasbranching(dxdt_, copy(x0), x0, p, tspan[1]) end