diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 0f93ea5..0000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1,3 +0,0 @@ -style = "sciml" -format_markdown = true -annotate_untyped_fields_with_any = false \ No newline at end of file diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index 7e46c8d..6762c6f 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -1,13 +1,19 @@ -name: "Format Check" +name: format-check on: push: branches: + - 'master' - 'main' + - 'release-' tags: '*' pull_request: jobs: - format-check: - name: "Format Check" - uses: "SciML/.github/.github/workflows/format-check.yml@v1" + runic: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: fredrikekre/runic-action@v1 + with: + version: '1' diff --git a/docs/make.jl b/docs/make.jl index 86c33d9..96f55d1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -5,10 +5,11 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true) pages = [ "Home" => "index.md", - "api.md" + "api.md", ] -makedocs(sitename = "FunctionProperties.jl", +makedocs( + sitename = "FunctionProperties.jl", authors = "Chris Rackauckas", modules = [FunctionProperties], clean = true, doctest = false, linkcheck = true, @@ -16,12 +17,17 @@ makedocs(sitename = "FunctionProperties.jl", :doctest, :linkcheck, :parse_error, - :example_block # Other available options are # :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block + :example_block, # Other available options are # :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block ], - format = Documenter.HTML(analytics = "UA-90474609-3", + format = Documenter.HTML( + analytics = "UA-90474609-3", assets = ["assets/favicon.ico"], - canonical = "https://docs.sciml.ai/FunctionProperties/stable/"), - pages = pages) + canonical = "https://docs.sciml.ai/FunctionProperties/stable/" + ), + pages = pages +) -deploydocs(repo = "github.com/SciML/MultiScaleArrays.jl.git"; - push_preview = true) +deploydocs( + repo = "github.com/SciML/MultiScaleArrays.jl.git"; + push_preview = true +) diff --git a/src/FunctionProperties.jl b/src/FunctionProperties.jl index 34768ca..91198d9 100644 --- a/src/FunctionProperties.jl +++ b/src/FunctionProperties.jl @@ -19,9 +19,11 @@ for (mod, f, n) in DiffRules.diffrules(; filter_modules = nothing) if !(isdefined(@__MODULE__, mod) && isdefined(getfield(@__MODULE__, mod), f)) continue # Skip rules for methods not defined in the current scope end - @eval function Cassette.overdub(::HasBranchingCtx, f::Core.Typeof($mod.$f), - x::Vararg{Any, $n}) - f(x...) + @eval function Cassette.overdub( + ::HasBranchingCtx, f::Core.Typeof($mod.$f), + x::Vararg{Any, $n} + ) + return f(x...) end end @@ -30,36 +32,57 @@ function _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection) if any(x -> isa(x, GotoIfNot), ir.code) printbranch && println("GotoIfNot detected in $(reflection.method)\nir = $ir\n") - Cassette.insert_statements!(ir.code, ir.codelocs, + Cassette.insert_statements!( + ir.code, ir.codelocs, (stmt, i) -> i == 1 ? 3 : nothing, - (stmt, - i) -> Any[ - Expr(:call, - Expr(:nooverdub, - GlobalRef(Base, :getfield)), + ( + stmt, + i, + ) -> Any[ + Expr( + :call, + Expr( + :nooverdub, + GlobalRef(Base, :getfield) + ), Expr(:contextslot), - QuoteNode(:metadata)), - Expr(:call, - Expr(:nooverdub, - GlobalRef(Base, :setindex!)), + QuoteNode(:metadata) + ), + Expr( + :call, + Expr( + :nooverdub, + GlobalRef(Base, :setindex!) + ), SSAValue(1), true, - QuoteNode(:has_branching)), - stmt]) - Cassette.insert_statements!(ir.code, ir.codelocs, + QuoteNode(:has_branching) + ), + stmt, + ] + ) + Cassette.insert_statements!( + ir.code, ir.codelocs, (stmt, i) -> i > 2 && isa(stmt, Expr) ? 1 : nothing, - (stmt, - i) -> begin + ( + stmt, + i, + ) -> begin callstmt = Meta.isexpr(stmt, :(=)) ? stmt.args[2] : - stmt + stmt Meta.isexpr(stmt, :call) || Meta.isexpr(stmt, :invoke) || return Any[stmt] - callstmt = Expr(callstmt.head, + callstmt = Expr( + callstmt.head, Expr(:nooverdub, callstmt.args[1]), - callstmt.args[2:end]...) - return Any[Meta.isexpr(stmt, :(=)) ? - Expr(:(=), stmt.args[1], callstmt) : - callstmt] - end) + callstmt.args[2:end]... + ) + return Any[ + Meta.isexpr(stmt, :(=)) ? + Expr(:(=), stmt.args[1], callstmt) : + callstmt, + ] + end + ) end return ir end @@ -105,10 +128,10 @@ end Cassette.overdub(::HasBranchingCtx, ::typeof(+), x...) = +(x...) Cassette.overdub(::HasBranchingCtx, ::typeof(*), x...) = *(x...) function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.materialize), x...) - Base.materialize(x...) + return Base.materialize(x...) end function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.literal_pow), x...) - Base.literal_pow(x...) + return Base.literal_pow(x...) end Cassette.overdub(::HasBranchingCtx, ::typeof(Base.getindex), x...) = Base.getindex(x...) Cassette.overdub(::HasBranchingCtx, ::typeof(Base.setindex!), x...) = Base.setindex!(x...) @@ -123,11 +146,13 @@ Cassette.overdub(::HasBranchingCtx, ::typeof(Base.cat), x...) = Base.cat(x...) Cassette.overdub(::HasBranchingCtx, ::typeof(Base.stack), x...) = Base.stack(x...) function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.Broadcast.broadcasted), x...) - Base.Broadcast.broadcasted(x...) + return Base.Broadcast.broadcasted(x...) end -function Cassette.overdub(::HasBranchingCtx, ::Type{Base.OneTo{T}}, - stop) where {T <: Integer} - Base.OneTo{T}(stop) +function Cassette.overdub( + ::HasBranchingCtx, ::Type{Base.OneTo{T}}, + stop + ) where {T <: Integer} + return Base.OneTo{T}(stop) end export hasbranching diff --git a/test/runtests.jl b/test/runtests.jl index fdea993..28aa10d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,14 +13,15 @@ end f_branch() = true ? 1 : 0 @test FunctionProperties.hasbranching(f_branch) function FunctionProperties.Cassette.overdub( - ::FunctionProperties.HasBranchingCtx, ::typeof(f_branch), x...) - f_branch(x...) + ::FunctionProperties.HasBranchingCtx, ::typeof(f_branch), x... + ) + return f_branch(x...) end @test !FunctionProperties.hasbranching(f_branch) # Test simple mutating functions function f(dx, x) - @inbounds dx[1] = x[1] + return @inbounds dx[1] = x[1] end x = zeros(1) dx = zeros(1) @@ -28,7 +29,7 @@ dx = zeros(1) # Test broadcast function f(x) - cos.(x .+ x .* x) + return cos.(x .+ x .* x) end x = [1.0] @test !FunctionProperties.hasbranching(f, x) @@ -43,17 +44,17 @@ x0 = [-4.0f0, 0.0f0] t = [0.0] function f(x, ps, st) - ps.weight * x + return ps.weight * x end @test !FunctionProperties.hasbranching(f, t, p, st) function f(x, ps, st) - x .+ x + return x .+ x end @test !FunctionProperties.hasbranching(f, t, p, st) function f2(x, ps, st) - Lux.apply_activation(identity, ps.weight * x .+ vec(ps.bias)), st + return Lux.apply_activation(identity, ps.weight * x .+ vec(ps.bias)), st end @test !FunctionProperties.hasbranching(f2, t, p, st) @test !FunctionProperties.hasbranching(ann, t, p, st) @@ -68,7 +69,7 @@ p = ComponentArray(ps) function dxdt_(dx, x, p, t) x1, x2 = x dx[1] = x[2] + first(ann(x, p, st))[1] - dx[2] = first(ann([t, t], p, st))[1] + return dx[2] = first(ann([t, t], p, st))[1] end x0 = [-4.0f0, 0.0f0] ts = Float32.(collect(0.0:0.01:tspan[2]))