Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .JuliaFormatter.toml

This file was deleted.

14 changes: 10 additions & 4 deletions .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
name: "Format Check"
name: format-check

on:
push:
branches:
- 'master'
- 'main'
- 'release-'
tags: '*'
pull_request:

jobs:
format-check:
name: "Format Check"
uses: "SciML/.github/.github/workflows/format-check.yml@v1"
runic:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: fredrikekre/runic-action@v1
with:
version: '1'
22 changes: 14 additions & 8 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,29 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)

pages = [
"Home" => "index.md",
"api.md"
"api.md",
]

makedocs(sitename = "FunctionProperties.jl",
makedocs(
sitename = "FunctionProperties.jl",
authors = "Chris Rackauckas",
modules = [FunctionProperties],
clean = true, doctest = false, linkcheck = true,
warnonly = [
:doctest,
:linkcheck,
:parse_error,
:example_block # Other available options are # :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block
:example_block, # Other available options are # :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block
],
format = Documenter.HTML(analytics = "UA-90474609-3",
format = Documenter.HTML(
analytics = "UA-90474609-3",
assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/FunctionProperties/stable/"),
pages = pages)
canonical = "https://docs.sciml.ai/FunctionProperties/stable/"
),
pages = pages
)

deploydocs(repo = "github.com/SciML/MultiScaleArrays.jl.git";
push_preview = true)
deploydocs(
repo = "github.com/SciML/MultiScaleArrays.jl.git";
push_preview = true
)
87 changes: 56 additions & 31 deletions src/FunctionProperties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ for (mod, f, n) in DiffRules.diffrules(; filter_modules = nothing)
if !(isdefined(@__MODULE__, mod) && isdefined(getfield(@__MODULE__, mod), f))
continue # Skip rules for methods not defined in the current scope
end
@eval function Cassette.overdub(::HasBranchingCtx, f::Core.Typeof($mod.$f),
x::Vararg{Any, $n})
f(x...)
@eval function Cassette.overdub(
::HasBranchingCtx, f::Core.Typeof($mod.$f),
x::Vararg{Any, $n}
)
return f(x...)
end
end

Expand All @@ -30,36 +32,57 @@ function _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection)

if any(x -> isa(x, GotoIfNot), ir.code)
printbranch && println("GotoIfNot detected in $(reflection.method)\nir = $ir\n")
Cassette.insert_statements!(ir.code, ir.codelocs,
Cassette.insert_statements!(
ir.code, ir.codelocs,
(stmt, i) -> i == 1 ? 3 : nothing,
(stmt,
i) -> Any[
Expr(:call,
Expr(:nooverdub,
GlobalRef(Base, :getfield)),
(
stmt,
i,
) -> Any[
Expr(
:call,
Expr(
:nooverdub,
GlobalRef(Base, :getfield)
),
Expr(:contextslot),
QuoteNode(:metadata)),
Expr(:call,
Expr(:nooverdub,
GlobalRef(Base, :setindex!)),
QuoteNode(:metadata)
),
Expr(
:call,
Expr(
:nooverdub,
GlobalRef(Base, :setindex!)
),
SSAValue(1), true,
QuoteNode(:has_branching)),
stmt])
Cassette.insert_statements!(ir.code, ir.codelocs,
QuoteNode(:has_branching)
),
stmt,
]
)
Cassette.insert_statements!(
ir.code, ir.codelocs,
(stmt, i) -> i > 2 && isa(stmt, Expr) ? 1 : nothing,
(stmt,
i) -> begin
(
stmt,
i,
) -> begin
callstmt = Meta.isexpr(stmt, :(=)) ? stmt.args[2] :
stmt
stmt
Meta.isexpr(stmt, :call) ||
Meta.isexpr(stmt, :invoke) || return Any[stmt]
callstmt = Expr(callstmt.head,
callstmt = Expr(
callstmt.head,
Expr(:nooverdub, callstmt.args[1]),
callstmt.args[2:end]...)
return Any[Meta.isexpr(stmt, :(=)) ?
Expr(:(=), stmt.args[1], callstmt) :
callstmt]
end)
callstmt.args[2:end]...
)
return Any[
Meta.isexpr(stmt, :(=)) ?
Expr(:(=), stmt.args[1], callstmt) :
callstmt,
]
end
)
end
return ir
end
Expand Down Expand Up @@ -105,10 +128,10 @@ end
Cassette.overdub(::HasBranchingCtx, ::typeof(+), x...) = +(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(*), x...) = *(x...)
function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.materialize), x...)
Base.materialize(x...)
return Base.materialize(x...)
end
function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.literal_pow), x...)
Base.literal_pow(x...)
return Base.literal_pow(x...)
end
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.getindex), x...) = Base.getindex(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.setindex!), x...) = Base.setindex!(x...)
Expand All @@ -123,11 +146,13 @@ Cassette.overdub(::HasBranchingCtx, ::typeof(Base.cat), x...) = Base.cat(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.stack), x...) = Base.stack(x...)

function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.Broadcast.broadcasted), x...)
Base.Broadcast.broadcasted(x...)
return Base.Broadcast.broadcasted(x...)
end
function Cassette.overdub(::HasBranchingCtx, ::Type{Base.OneTo{T}},
stop) where {T <: Integer}
Base.OneTo{T}(stop)
function Cassette.overdub(
::HasBranchingCtx, ::Type{Base.OneTo{T}},
stop
) where {T <: Integer}
return Base.OneTo{T}(stop)
end

export hasbranching
Expand Down
17 changes: 9 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,23 @@ end
f_branch() = true ? 1 : 0
@test FunctionProperties.hasbranching(f_branch)
function FunctionProperties.Cassette.overdub(
::FunctionProperties.HasBranchingCtx, ::typeof(f_branch), x...)
f_branch(x...)
::FunctionProperties.HasBranchingCtx, ::typeof(f_branch), x...
)
return f_branch(x...)
end
@test !FunctionProperties.hasbranching(f_branch)

# Test simple mutating functions
function f(dx, x)
@inbounds dx[1] = x[1]
return @inbounds dx[1] = x[1]
end
x = zeros(1)
dx = zeros(1)
@test !FunctionProperties.hasbranching(f, dx, x)

# Test broadcast
function f(x)
cos.(x .+ x .* x)
return cos.(x .+ x .* x)
end
x = [1.0]
@test !FunctionProperties.hasbranching(f, x)
Expand All @@ -43,17 +44,17 @@ x0 = [-4.0f0, 0.0f0]
t = [0.0]

function f(x, ps, st)
ps.weight * x
return ps.weight * x
end
@test !FunctionProperties.hasbranching(f, t, p, st)

function f(x, ps, st)
x .+ x
return x .+ x
end
@test !FunctionProperties.hasbranching(f, t, p, st)

function f2(x, ps, st)
Lux.apply_activation(identity, ps.weight * x .+ vec(ps.bias)), st
return Lux.apply_activation(identity, ps.weight * x .+ vec(ps.bias)), st
end
@test !FunctionProperties.hasbranching(f2, t, p, st)
@test !FunctionProperties.hasbranching(ann, t, p, st)
Expand All @@ -68,7 +69,7 @@ p = ComponentArray(ps)
function dxdt_(dx, x, p, t)
x1, x2 = x
dx[1] = x[2] + first(ann(x, p, st))[1]
dx[2] = first(ann([t, t], p, st))[1]
return dx[2] = first(ann([t, t], p, st))[1]
end
x0 = [-4.0f0, 0.0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
Expand Down
Loading