diff --git a/Project.toml b/Project.toml index c51aa400..9d29ebe6 100644 --- a/Project.toml +++ b/Project.toml @@ -4,20 +4,24 @@ authors = ["Chris Rackauckas "] version = "1.8.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] +ChainRulesCore = "0.9" +FiniteDifferences = "0.11" Requires = "1.0" julia = "1" [extras] +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ForwardDiff", "Test", "SafeTestsets", "Random"] +test = ["FiniteDifferences", "ForwardDiff", "Test", "SafeTestsets", "Random"] diff --git a/src/ExponentialUtilities.jl b/src/ExponentialUtilities.jl index 7a584b15..4cddebf4 100644 --- a/src/ExponentialUtilities.jl +++ b/src/ExponentialUtilities.jl @@ -1,5 +1,5 @@ module ExponentialUtilities -using LinearAlgebra, SparseArrays, Printf, Requires +using LinearAlgebra, SparseArrays, Printf, Requires, ChainRulesCore """ @diagview(A,d) -> view of the `d`th diagonal of `A`. @@ -20,6 +20,7 @@ include("krylov_phiv_adaptive.jl") include("kiops.jl") include("StegrWork.jl") include("krylov_phiv_error_estimate.jl") +include("krylov_phiv_chainrules.jl") export phi, phi!, KrylovSubspace, arnoldi, arnoldi!, lanczos!, ExpvCache, PhivCache, expv, expv!, exp_generic, phiv, phiv!, kiops, expv_timestep, expv_timestep!, phiv_timestep, phiv_timestep!, diff --git a/src/krylov_phiv_chainrules.jl b/src/krylov_phiv_chainrules.jl new file mode 100644 index 00000000..f44e6f09 --- /dev/null +++ b/src/krylov_phiv_chainrules.jl @@ -0,0 +1,42 @@ +function ChainRulesCore.frule((_, Δt, ΔA, Δb), ::typeof(expv), t, A, b; kwargs...) + w = expv(t, A, b; kwargs...) + ∂w = similar(w) + mul!(∂w, A, w) + ∂w .*= Δt + if !isa(Δb, AbstractZero) + ∂w .+= expv(t, A, Δb; kwargs...) + end + # TODO: handle ΔA + ΔA isa AbstractZero || error("ΔA currently cannot be pushed forward") + return w, ∂w +end + +function ChainRulesCore.rrule(::typeof(expv), t, A, b; kwargs...) + w = expv(t, A, b; kwargs...) + function expv_pullback(Δw) + ∂t = Thunk() do + t̄ = A isa AbstractMatrix ? conj(dot(Δw, A, w)) : dot(mul!(similar(w), A, w), Δw) + return t isa Real ? real(t̄) : t̄ + end + # TODO: handle ∂A + ∂A = @thunk error("Adjoint wrt A not yet implemented") + ∂b = Thunk() do + # using similar is necessary to ensure type-stability + b̄ = similar(b) + _copyto!(b̄, expv(t', A', Δw; kwargs...)) + return b̄ + end + return (NO_FIELDS, ∂t, ∂A, ∂b) + end + expv_pullback(::Zero) = (NO_FIELDS, Zero(), Zero(), Zero()) + return w, expv_pullback +end + +function _copyto!(x, y) + if eltype(x) <: Real && !(eltype(y) <: Real) + x .= real.(y) + else + copyto!(x, y) + end + return x +end diff --git a/test/runtests.jl b/test/runtests.jl index 4172affb..fa9e841c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using Test, LinearAlgebra, Random, SparseArrays, ExponentialUtilities using ExponentialUtilities: getH, getV, _exp! +using ChainRulesCore, FiniteDifferences using ForwardDiff @testset "Exp" begin @@ -228,3 +229,50 @@ struct OpnormFunctor end @test pv ≈ pv′ atol=1e-12 end end + +@testset "expv chain rules" begin + n = 30 + @testset "frule for T=$T" for T in (Float64, ComplexF64) + t = rand(T) + A = randn(T, n, n) + b = randn(T, n) + Δt = FiniteDifferences.rand_tangent(t) + Δb = FiniteDifferences.rand_tangent(b) + + w = expv(t, A, b) + w_ad, ∂w_ad = frule((NO_FIELDS, Δt, Zero(), Δb), expv, t, A, b) + @test w_ad == w + ∂w_fd = jvp(central_fdm(5, 1), (t, b) -> expv(t, A, b), (t, Δt), (b, Δb)) + @test ∂w_ad ≈ ∂w_fd + + w_ad, ∂w_ad = frule((NO_FIELDS, Δt, Zero(), Zero()), expv, t, A, b) + @test w_ad == w + ∂w_fd = jvp(central_fdm(5, 1), t -> expv(t, A, b), (t, Δt)) + @test ∂w_ad ≈ ∂w_fd + + ΔA = FiniteDifferences.rand_tangent(A) + @test_throws ErrorException frule((NO_FIELDS, Δt, ΔA, Δb), expv, t, A, b) + end + + @testset "rrule for T=$T" for T in (Float64, ComplexF64) + t = rand(T) + A = randn(T, n, n) + b = randn(T, n) + w = expv(t, A, b) + Δw = FiniteDifferences.rand_tangent(w) + + w_ad, back = rrule(expv, t, A, b) + @test w_ad == w + ∂self, ∂t_ad, ∂A_ad, ∂b_ad = @inferred back(Δw) + @test ∂self === NO_FIELDS + @test @inferred(extern(∂t_ad)) isa typeof(t) + @test @inferred(extern(∂b_ad)) isa typeof(b) + + ∂t_fd, ∂A_fd, ∂b_fd = j′vp(central_fdm(5, 1), expv, Δw, t, A, b) + @test extern(∂t_ad) ≈ ∂t_fd + @test extern(∂b_ad) ≈ ∂b_fd + @test_throws ErrorException unthunk(∂A_ad) + + @test @inferred(back(Zero())) === (NO_FIELDS, Zero(), Zero(), Zero()) + end +end