Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lib/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ module DiffEqBaseEnzymeExt
# `sensealg` is inactive (see augmented_primal note); skip its slot
# whether it arrived as Const or as a runtime-activity-promoted
# Duplicated/MixedDuplicated/Active.
rta = Enzyme.EnzymeRules.runtime_activity(config) # detect runtime activity
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
if ptr isa Enzyme.Const
continue
Expand All @@ -124,9 +125,18 @@ module DiffEqBaseEnzymeExt
if darg == ChainRulesCore.NoTangent()
continue
end
# Under `set_runtime_activity` (detected by rta), a runtime-inactive value arrives as
# Duplicated/MixedDuplicated with its shadow ALIASING the primal
# (`dval === val`). Accumulating the cotangent into such a shadow writes
# gradient values into the caller's primal data (e.g. the `u0` of a
# `Const` problem whose array was reused via `remake`), silently
# corrupting subsequent calls. Skip them: a runtime-inactive value
# accumulates nowhere, exactly as if it were `Const`.
if ptr isa MixedDuplicated
rta && ptr.dval[] === ptr.val && continue
_accumulate_tangent!(ptr.dval[], darg)
else
rta && ptr.dval === ptr.val && continue
_accumulate_tangent!(ptr.dval, darg)
end
end
Expand Down
4 changes: 4 additions & 0 deletions lib/DiffEqBase/test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
Expand All @@ -21,6 +22,7 @@ OrdinaryDiffEqSymplecticRK = "fa646aed-7ef9-47eb-84c4-9443fc8cbfa8"
OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2"
SDEProblemLibrary = "c72e72a9-a271-4b2b-8966-303ed956772e"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Expand All @@ -41,8 +43,10 @@ StochasticDiffEq = {path = "../../../StochasticDiffEq"}
[compat]
ADTypes = "1"
DifferentiationInterface = "0.7"
Enzyme = "0.13.100"
ForwardDiff = "0.10, 1"
GTPSA = "1.4"
MultiScaleArrays = "1.8"
OrdinaryDiffEq = "7"
SciMLSensitivity = "7"
StochasticDiffEq = "7"
57 changes: 57 additions & 0 deletions lib/DiffEqBase/test/downstream/enzyme_solve_up_rule.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Regression test for the DiffEqBaseEnzymeExt `solve_up` reverse rule guard
# (SciML/OrdinaryDiffEq.jl#3740): under `set_runtime_activity`, a u0 that aliases a
# Const-annotated problem's own array must not receive the du0 cotangent (its
# "shadow" IS the primal). First gradient call used to be correct while silently
# corrupting `prob.u0`; the second call then returned garbage.
#
# Deps: SciMLSensitivity (real adjoint), OrdinaryDiffEqVerner, Enzyme, ForwardDiff.

using DiffEqBase, Enzyme, Test
using SciMLSensitivity
using OrdinaryDiffEqVerner
using ForwardDiff

f_oop(u, p, t) = p .* u
u0_init = [2.0, 3.0]
p_init = [0.5, 0.7]
prob = ODEProblem(f_oop, copy(u0_init), (0.0, 1.0), copy(p_init))

solve_kwargs = (; saveat = 0.25, abstol = 1.0e-8, reltol = 1.0e-8)

@testset "runtime-activity aliased u0 is not corrupted" begin
# u0 is the Const problem's own array reused via remake — the
# `setsym_oop`/`remake` pattern of MTK loss functions.
function loss_aliased(p, q)
prob = q[1]
prob2 = remake(prob; u0 = prob.u0, p = p)
sol = solve(prob2, Vern7(); sensealg = GaussAdjoint(), solve_kwargs...)
return sum(abs2, Array(sol))
end

q = (prob,)
g_ref = ForwardDiff.gradient(p -> loss_aliased(p, q), p_init)

g1 = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_aliased, copy(p_init), Const(q))[1]
@test g1 ≈ g_ref rtol = 1.0e-5
@test prob.u0 == u0_init # primal problem must NOT have been mutated
g2 = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_aliased, copy(p_init), Const(q))[1]
@test g2 ≈ g_ref rtol = 1.0e-5 # second call sees uncorrupted state
end

@testset "genuinely active u0 still accumulates" begin
# The guard must not skip accumulation when a real shadow exists: here both
# u0 and p derive from the differentiated input, so du0 must flow.
function loss_active(x, q)
prob2 = remake(q[1]; u0 = x[1:2], p = x[3:4])
sol = solve(prob2, Vern7(); sensealg = GaussAdjoint(), solve_kwargs...)
return sum(abs2, Array(sol))
end

q = (prob,)
x0 = vcat(u0_init, p_init)
g_ref = ForwardDiff.gradient(x -> loss_active(x, q), x0)
@test maximum(abs, g_ref[1:2]) > 0 # sanity: u0 gradient is nonzero

g = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), loss_active, copy(x0), Const(q))[1]
@test g ≈ g_ref rtol = 1.0e-5
end
6 changes: 6 additions & 0 deletions lib/DiffEqBase/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ end
@time @safetestset "LabelledArrays Tests" include("downstream/labelledarrays.jl")
@time @safetestset "GTPSA Tests" include("downstream/gtpsa.jl")
@time @safetestset "SubArray Support" include("downstream/subarray_support.jl")
# Run ahead of the Unitful/FlexUnits tests: a @safetestset that errors
# (FlexUnits currently does) aborts the rest of this group, which would
# otherwise shadow this test. DiffEqBaseEnzymeExt is disabled on prerelease.
if isempty(VERSION.prerelease)
@time @safetestset "Enzyme solve_up rule" include("downstream/enzyme_solve_up_rule.jl")
end
@time @safetestset "Unitful" include("downstream/unitful.jl")
@time @safetestset "FlexUnits" include("downstream/flexunits.jl")
end
Expand Down
Loading