# Patch summary (descriptive preamble placed ahead of the first `diff --git` header).
#
# Files touched:
#   1. lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl
#        In the loop that accumulates each `darg` into the matching argument's
#        shadow, reads `Enzyme.EnzymeRules.runtime_activity(config)` into `rta`
#        and, when `rta` is true, skips `_accumulate_tangent!` for any argument
#        whose shadow is identical (`===`) to its primal (`ptr.dval[] === ptr.val`
#        for `MixedDuplicated`, `ptr.dval === ptr.val` otherwise).
#   2. lib/DiffEqBase/test/downstream/Project.toml
#        Adds `Enzyme` and `SciMLSensitivity` as dependencies, with compat
#        entries "0.13.100" and "7".
#   3. lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl (new file, 57 lines)
#        Two testsets: one checks that `prob.u0` is unchanged and that two
#        consecutive Enzyme gradients match a ForwardDiff reference when `u0`
#        is the problem's own array; the other checks that a `u0` derived from
#        the differentiated input still matches the ForwardDiff reference.
#   4. lib/DiffEqBase/test/runtests.jl
#        Registers the new file as a `@safetestset` ahead of the Unitful and
#        FlexUnits tests, only when `VERSION.prerelease` is empty.
#
# NOTE(review): this patch was recovered from a copy whose newlines and leading
# indentation had been collapsed. Line breaks follow the diff markers and the
# hunk line counts; the indentation of context lines is a reconstruction and
# must be confirmed against the target files before applying (or apply with
# whitespace-insensitive matching).
# NOTE(review): the first comment line of the new test file was split between
# "reverse" and "rule guard" in the recovered copy; it is joined with a single
# space here -- confirm against the original.
diff --git a/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl b/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl
index c3f69716c08..8a0f7d4bfa5 100644
--- a/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl
+++ b/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl
@@ -114,6 +114,7 @@ module DiffEqBaseEnzymeExt
         # `sensealg` is inactive (see augmented_primal note); skip its slot
         # whether it arrived as Const or as a runtime-activity-promoted
         # Duplicated/MixedDuplicated/Active.
+        rta = Enzyme.EnzymeRules.runtime_activity(config) # detect runtime activity
         for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
             if ptr isa Enzyme.Const
                 continue
@@ -124,9 +125,18 @@ module DiffEqBaseEnzymeExt
             if darg == ChainRulesCore.NoTangent()
                 continue
             end
+            # Under `set_runtime_activity` (detected by rta), a runtime-inactive value arrives as
+            # Duplicated/MixedDuplicated with its shadow ALIASING the primal
+            # (`dval === val`). Accumulating the cotangent into such a shadow writes
+            # gradient values into the caller's primal data (e.g. the `u0` of a
+            # `Const` problem whose array was reused via `remake`), silently
+            # corrupting subsequent calls. Skip them: a runtime-inactive value
+            # accumulates nowhere, exactly as if it were `Const`.
             if ptr isa MixedDuplicated
+                rta && ptr.dval[] === ptr.val && continue
                 _accumulate_tangent!(ptr.dval[], darg)
             else
+                rta && ptr.dval === ptr.val && continue
                 _accumulate_tangent!(ptr.dval, darg)
             end
         end
diff --git a/lib/DiffEqBase/test/downstream/Project.toml b/lib/DiffEqBase/test/downstream/Project.toml
index 3183685d3bf..d1b907b8c2a 100644
--- a/lib/DiffEqBase/test/downstream/Project.toml
+++ b/lib/DiffEqBase/test/downstream/Project.toml
@@ -4,6 +4,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
 DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
@@ -21,6 +22,7 @@ OrdinaryDiffEqSymplecticRK = "fa646aed-7ef9-47eb-84c4-9443fc8cbfa8"
 OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2"
 SDEProblemLibrary = "c72e72a9-a271-4b2b-8966-303ed956772e"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
+SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
 SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
@@ -41,8 +43,10 @@ StochasticDiffEq = {path = "../../../StochasticDiffEq"}
 [compat]
 ADTypes = "1"
 DifferentiationInterface = "0.7"
+Enzyme = "0.13.100"
 ForwardDiff = "0.10, 1"
 GTPSA = "1.4"
 MultiScaleArrays = "1.8"
 OrdinaryDiffEq = "7"
+SciMLSensitivity = "7"
 StochasticDiffEq = "7"
diff --git a/lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl b/lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl
new file mode 100644
index 00000000000..8e887f59c0d
--- /dev/null
+++ b/lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl
@@ -0,0 +1,57 @@
+# Regression test for the DiffEqBaseEnzymeExt `solve_up` reverse rule guard
+# (SciML/OrdinaryDiffEq.jl#3740): under `set_runtime_activity`, a u0 that aliases a
+# Const-annotated problem's own array must not receive the du0 cotangent (its
+# "shadow" IS the primal). First gradient call used to be correct while silently
+# corrupting `prob.u0`; the second call then returned garbage.
+#
+# Deps: SciMLSensitivity (real adjoint), OrdinaryDiffEqVerner, Enzyme, ForwardDiff.
+
+using DiffEqBase, Enzyme, Test
+using SciMLSensitivity
+using OrdinaryDiffEqVerner
+using ForwardDiff
+
+f_oop(u, p, t) = p .* u
+u0_init = [2.0, 3.0]
+p_init = [0.5, 0.7]
+prob = ODEProblem(f_oop, copy(u0_init), (0.0, 1.0), copy(p_init))
+
+solve_kwargs = (; saveat = 0.25, abstol = 1.0e-8, reltol = 1.0e-8)
+
+@testset "runtime-activity aliased u0 is not corrupted" begin
+    # u0 is the Const problem's own array reused via remake — the
+    # `setsym_oop`/`remake` pattern of MTK loss functions.
+    function loss_aliased(p, q)
+        prob = q[1]
+        prob2 = remake(prob; u0 = prob.u0, p = p)
+        sol = solve(prob2, Vern7(); sensealg = GaussAdjoint(), solve_kwargs...)
+        return sum(abs2, Array(sol))
+    end
+
+    q = (prob,)
+    g_ref = ForwardDiff.gradient(p -> loss_aliased(p, q), p_init)
+
+    g1 = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_aliased, copy(p_init), Const(q))[1]
+    @test g1 ≈ g_ref rtol = 1.0e-5
+    @test prob.u0 == u0_init # primal problem must NOT have been mutated
+    g2 = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_aliased, copy(p_init), Const(q))[1]
+    @test g2 ≈ g_ref rtol = 1.0e-5 # second call sees uncorrupted state
+end
+
+@testset "genuinely active u0 still accumulates" begin
+    # The guard must not skip accumulation when a real shadow exists: here both
+    # u0 and p derive from the differentiated input, so du0 must flow.
+    function loss_active(x, q)
+        prob2 = remake(q[1]; u0 = x[1:2], p = x[3:4])
+        sol = solve(prob2, Vern7(); sensealg = GaussAdjoint(), solve_kwargs...)
+        return sum(abs2, Array(sol))
+    end
+
+    q = (prob,)
+    x0 = vcat(u0_init, p_init)
+    g_ref = ForwardDiff.gradient(x -> loss_active(x, q), x0)
+    @test maximum(abs, g_ref[1:2]) > 0 # sanity: u0 gradient is nonzero
+
+    g = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_active, copy(x0), Const(q))[1]
+    @test g ≈ g_ref rtol = 1.0e-5
+end
diff --git a/lib/DiffEqBase/test/runtests.jl b/lib/DiffEqBase/test/runtests.jl
index b1a9f02f074..faf829e8b90 100644
--- a/lib/DiffEqBase/test/runtests.jl
+++ b/lib/DiffEqBase/test/runtests.jl
@@ -81,6 +81,12 @@ end
         @time @safetestset "LabelledArrays Tests" include("downstream/labelledarrays.jl")
         @time @safetestset "GTPSA Tests" include("downstream/gtpsa.jl")
         @time @safetestset "SubArray Support" include("downstream/subarray_support.jl")
+        # Run ahead of the Unitful/FlexUnits tests: a @safetestset that errors
+        # (FlexUnits currently does) aborts the rest of this group, which would
+        # otherwise shadow this test. DiffEqBaseEnzymeExt is disabled on prerelease.
+        if isempty(VERSION.prerelease)
+            @time @safetestset "Enzyme solve_up rule" include("downstream/enzyme_solve_up_rule.jl")
+        end
         @time @safetestset "Unitful" include("downstream/unitful.jl")
         @time @safetestset "FlexUnits" include("downstream/flexunits.jl")
     end