From 8e357b81f383646bb485296df238a20ccc5c31e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Wed, 10 Jun 2026 22:22:54 +0300 Subject: [PATCH 1/4] Guard solve_up Enzyme reverse rule against runtime-activity aliased shadows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Under `set_runtime_activity`, a runtime-inactive value reaches the rule as Duplicated/MixedDuplicated whose shadow IS the primal (dval === val). The reverse rule for `DiffEqBase.solve_up` accumulated every non-Const cotangent into `ptr.dval` unconditionally, so e.g. the du0 cotangent was broadcast-added into the caller's primal u0 whenever the solved problem's u0 aliased an array reachable from a Const argument (the common `setsym_oop`/`remake` pattern in MTK loss functions). The first gradient call returned correct results while silently corrupting the Const problem's u0; subsequent calls were garbage. Skip accumulation when the shadow aliases the primal — a runtime-inactive value accumulates nowhere, exactly as if it were Const. Found while reducing SciML/SciMLSensitivity.jl#1477 (failure mode B). Co-Authored-By: Claude Fable 5 --- lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl b/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl index c3f69716c0..60e736d2bf 100644 --- a/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl +++ b/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl @@ -124,9 +124,18 @@ module DiffEqBaseEnzymeExt if darg == ChainRulesCore.NoTangent() continue end + # Under `set_runtime_activity`, a runtime-inactive value arrives as + # Duplicated/MixedDuplicated with its shadow ALIASING the primal + # (`dval === val`). Accumulating the cotangent into such a shadow writes + # gradient values into the caller's primal data (e.g. the `u0` of a + # `Const` problem whose array was reused via `remake`), silently + # corrupting subsequent calls. Skip them: a runtime-inactive value + # accumulates nowhere, exactly as if it were `Const`. if ptr isa MixedDuplicated + ptr.dval[] === ptr.val && continue _accumulate_tangent!(ptr.dval[], darg) else + ptr.dval === ptr.val && continue _accumulate_tangent!(ptr.dval, darg) end end From a51b8bcb690491b08a3b82d1101f0c83fb59dca9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Thu, 11 Jun 2026 04:25:10 +0300 Subject: [PATCH 2/4] Add regression test for the solve_up Enzyme rule aliased-shadow guard Tests the runtime-activity scenario fixed in the previous commit: a loss that reuses the Const problem's own u0 via remake must keep correct gradients across repeated calls without mutating the primal problem, and genuinely active u0 must still receive its cotangent. Verified against ForwardDiff with a real SciMLSensitivity adjoint (GaussAdjoint); on the unpatched rule the first call silently corrupts prob.u0 and the second call's gradient is garbage. Runs in the Downstream group (adds SciMLSensitivity and Enzyme to that test environment), guarded out on prerelease Julia like DiffEqBaseEnzymeExt itself. Co-Authored-By: Claude Fable 5 --- lib/DiffEqBase/test/downstream/Project.toml | 4 ++ .../test/downstream/enzyme_solve_up_rule.jl | 57 +++++++++++++++++++ lib/DiffEqBase/test/runtests.jl | 4 ++ 3 files changed, 65 insertions(+) create mode 100644 lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl diff --git a/lib/DiffEqBase/test/downstream/Project.toml b/lib/DiffEqBase/test/downstream/Project.toml index 3183685d3b..d1b907b8c2 100644 --- a/lib/DiffEqBase/test/downstream/Project.toml +++ b/lib/DiffEqBase/test/downstream/Project.toml @@ -4,6 +4,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" @@ -21,6 +22,7 @@ OrdinaryDiffEqSymplecticRK = "fa646aed-7ef9-47eb-84c4-9443fc8cbfa8" OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2" SDEProblemLibrary = "c72e72a9-a271-4b2b-8966-303ed956772e" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" @@ -41,8 +43,10 @@ StochasticDiffEq = {path = "../../../StochasticDiffEq"} [compat] ADTypes = "1" DifferentiationInterface = "0.7" +Enzyme = "0.13.100" ForwardDiff = "0.10, 1" GTPSA = "1.4" MultiScaleArrays = "1.8" OrdinaryDiffEq = "7" +SciMLSensitivity = "7" StochasticDiffEq = "7" diff --git a/lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl b/lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl new file mode 100644 index 0000000000..e56400fe73 --- /dev/null +++ b/lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl @@ -0,0 +1,57 @@ +# Regression test for the DiffEqBaseEnzymeExt `solve_up` reverse rule guard +# (SciML/OrdinaryDiffEq.jl#3740): under `set_runtime_activity`, a u0 that aliases a +# Const-annotated problem's own array must not receive the du0 cotangent (its +# "shadow" IS the primal). First gradient call used to be correct while silently +# corrupting `prob.u0`; the second call then returned garbage. +# +# Deps: SciMLSensitivity (real adjoint), OrdinaryDiffEqVerner, Enzyme, ForwardDiff. + +using DiffEqBase, Enzyme, Test +using SciMLSensitivity +using OrdinaryDiffEqVerner +using ForwardDiff + +f_oop(u, p, t) = p .* u +u0_init = [2.0, 3.0] +p_init = [0.5, 0.7] +prob = ODEProblem(f_oop, copy(u0_init), (0.0, 1.0), copy(p_init)) + +solve_kwargs = (; saveat = 0.25, abstol = 1e-8, reltol = 1e-8) + +@testset "runtime-activity aliased u0 is not corrupted" begin + # u0 is the Const problem's own array reused via remake — the + # `setsym_oop`/`remake` pattern of MTK loss functions. + function loss_aliased(p, q) + prob = q[1] + prob2 = remake(prob; u0 = prob.u0, p = p) + sol = solve(prob2, Vern7(); sensealg = GaussAdjoint(), solve_kwargs...) + return sum(abs2, Array(sol)) + end + + q = (prob,) + g_ref = ForwardDiff.gradient(p -> loss_aliased(p, q), p_init) + + g1 = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_aliased, copy(p_init), Const(q))[1] + @test g1 ≈ g_ref rtol = 1e-5 + @test prob.u0 == u0_init # primal problem must NOT have been mutated + g2 = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_aliased, copy(p_init), Const(q))[1] + @test g2 ≈ g_ref rtol = 1e-5 # second call sees uncorrupted state +end + +@testset "genuinely active u0 still accumulates" begin + # The guard must not skip accumulation when a real shadow exists: here both + # u0 and p derive from the differentiated input, so du0 must flow. + function loss_active(x, q) + prob2 = remake(q[1]; u0 = x[1:2], p = x[3:4]) + sol = solve(prob2, Vern7(); sensealg = GaussAdjoint(), solve_kwargs...) + return sum(abs2, Array(sol)) + end + + q = (prob,) + x0 = vcat(u0_init, p_init) + g_ref = ForwardDiff.gradient(x -> loss_active(x, q), x0) + @test maximum(abs, g_ref[1:2]) > 0 # sanity: u0 gradient is nonzero + + g = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_active, copy(x0), Const(q))[1] + @test g ≈ g_ref rtol = 1e-5 +end diff --git a/lib/DiffEqBase/test/runtests.jl b/lib/DiffEqBase/test/runtests.jl index b1a9f02f07..0d3738aba0 100644 --- a/lib/DiffEqBase/test/runtests.jl +++ b/lib/DiffEqBase/test/runtests.jl @@ -83,6 +83,10 @@ end @time @safetestset "SubArray Support" include("downstream/subarray_support.jl") @time @safetestset "Unitful" include("downstream/unitful.jl") @time @safetestset "FlexUnits" include("downstream/flexunits.jl") + # DiffEqBaseEnzymeExt is disabled on prerelease Julia + if isempty(VERSION.prerelease) + @time @safetestset "Enzyme solve_up rule" include("downstream/enzyme_solve_up_rule.jl") + end end # Downstream2 tests — additional OrdinaryDiffEq integration tests From ee3b14d198a8283fa60cb5f804787e762fcc48c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Wed, 24 Jun 2026 09:54:37 +0300 Subject: [PATCH 3/4] detect runtime activity --- lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl b/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl index 60e736d2bf..8a0f7d4bfa 100644 --- a/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl +++ b/lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl @@ -114,6 +114,7 @@ module DiffEqBaseEnzymeExt # `sensealg` is inactive (see augmented_primal note); skip its slot # whether it arrived as Const or as a runtime-activity-promoted # Duplicated/MixedDuplicated/Active. + rta = Enzyme.EnzymeRules.runtime_activity(config) # detect runtime activity for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) if ptr isa Enzyme.Const continue @@ -124,7 +125,7 @@ module DiffEqBaseEnzymeExt if darg == ChainRulesCore.NoTangent() continue end - # Under `set_runtime_activity`, a runtime-inactive value arrives as + # Under `set_runtime_activity` (detected by rta), a runtime-inactive value arrives as # Duplicated/MixedDuplicated with its shadow ALIASING the primal # (`dval === val`). Accumulating the cotangent into such a shadow writes # gradient values into the caller's primal data (e.g. the `u0` of a @@ -132,10 +133,10 @@ module DiffEqBaseEnzymeExt # corrupting subsequent calls. Skip them: a runtime-inactive value # accumulates nowhere, exactly as if it were `Const`. if ptr isa MixedDuplicated - ptr.dval[] === ptr.val && continue + rta && ptr.dval[] === ptr.val && continue _accumulate_tangent!(ptr.dval[], darg) else - ptr.dval === ptr.val && continue + rta && ptr.dval === ptr.val && continue _accumulate_tangent!(ptr.dval, darg) end end From 3847a8924661bf79b0ac4100df1dd9c9310d47c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Miclu=C8=9Ba-C=C3=A2mpeanu?= Date: Wed, 24 Jun 2026 13:29:10 +0300 Subject: [PATCH 4/4] Fix Runic formatting and run Enzyme test ahead of FlexUnits Normalize the float literals in the new enzyme_solve_up_rule.jl test (1e-8 -> 1.0e-8, 1e-5 -> 1.0e-5) so it passes the Runic format check. Reorder the "Enzyme solve_up rule" @safetestset ahead of the Unitful/FlexUnits tests in runtests.jl: the FlexUnits test currently errors (a pre-existing FlexUnits unit-promotion failure, also failing on master), and a @safetestset that errors aborts the rest of the Downstream group, which was shadowing the new test so it never ran in CI. Co-Authored-By: Claude Opus 4.8 (1M context) --- lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl | 8 ++++---- lib/DiffEqBase/test/runtests.jl | 8 +++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl b/lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl index e56400fe73..8e887f59c0 100644 --- a/lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl +++ b/lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl @@ -16,7 +16,7 @@ u0_init = [2.0, 3.0] p_init = [0.5, 0.7] prob = ODEProblem(f_oop, copy(u0_init), (0.0, 1.0), copy(p_init)) -solve_kwargs = (; saveat = 0.25, abstol = 1e-8, reltol = 1e-8) +solve_kwargs = (; saveat = 0.25, abstol = 1.0e-8, reltol = 1.0e-8) @testset "runtime-activity aliased u0 is not corrupted" begin # u0 is the Const problem's own array reused via remake — the @@ -32,10 +32,10 @@ solve_kwargs = (; saveat = 0.25, abstol = 1e-8, reltol = 1e-8) g_ref = ForwardDiff.gradient(p -> loss_aliased(p, q), p_init) g1 = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_aliased, copy(p_init), Const(q))[1] - @test g1 ≈ g_ref rtol = 1e-5 + @test g1 ≈ g_ref rtol = 1.0e-5 @test prob.u0 == u0_init # primal problem must NOT have been mutated g2 = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_aliased, copy(p_init), Const(q))[1] - @test g2 ≈ g_ref rtol = 1e-5 # second call sees uncorrupted state + @test g2 ≈ g_ref rtol = 1.0e-5 # second call sees uncorrupted state end @testset "genuinely active u0 still accumulates" begin @@ -53,5 +53,5 @@ end @test maximum(abs, g_ref[1:2]) > 0 # sanity: u0 gradient is nonzero g = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_active, copy(x0), Const(q))[1] - @test g ≈ g_ref rtol = 1e-5 + @test g ≈ g_ref rtol = 1.0e-5 end diff --git a/lib/DiffEqBase/test/runtests.jl b/lib/DiffEqBase/test/runtests.jl index 0d3738aba0..faf829e8b9 100644 --- a/lib/DiffEqBase/test/runtests.jl +++ b/lib/DiffEqBase/test/runtests.jl @@ -81,12 +81,14 @@ end @time @safetestset "LabelledArrays Tests" include("downstream/labelledarrays.jl") @time @safetestset "GTPSA Tests" include("downstream/gtpsa.jl") @time @safetestset "SubArray Support" include("downstream/subarray_support.jl") - @time @safetestset "Unitful" include("downstream/unitful.jl") - @time @safetestset "FlexUnits" include("downstream/flexunits.jl") - # DiffEqBaseEnzymeExt is disabled on prerelease Julia + # Run ahead of the Unitful/FlexUnits tests: a @safetestset that errors + # (FlexUnits currently does) aborts the rest of this group, which would + # otherwise shadow this test. DiffEqBaseEnzymeExt is disabled on prerelease. if isempty(VERSION.prerelease) @time @safetestset "Enzyme solve_up rule" include("downstream/enzyme_solve_up_rule.jl") end + @time @safetestset "Unitful" include("downstream/unitful.jl") + @time @safetestset "FlexUnits" include("downstream/flexunits.jl") end # Downstream2 tests — additional OrdinaryDiffEq integration tests