diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 0000000..039bb60 --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,15 @@ +#!/usr/bin/env bash +# Pre-commit hook: format staged Julia files with Runic + +JULIA_FILES=$(git diff --cached --name-only --diff-filter=ACM -- '*.jl') + +if [ -z "$JULIA_FILES" ]; then + exit 0 +fi + +julia --project=@runic --startup-file=no -e 'using Runic; exit(Runic.main(ARGS))' -- --inplace $JULIA_FILES + +# Re-stage the formatted files +echo "$JULIA_FILES" | xargs git add + +exit 0 diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f6a93eb..f4c9e69 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -25,7 +25,6 @@ jobs: version: - '1' - 'lts' - - 'pre' uses: "SciML/.github/.github/workflows/tests.yml@v1" with: julia-version: "${{ matrix.version }}" diff --git a/.gitignore b/.gitignore index 0e34500..d92db27 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ Manifest.toml /.benchmarkci /benchmark/*.json LocalPreferences.toml -docs/build \ No newline at end of file +docs/build +benchmark/results.json diff --git a/Project.toml b/Project.toml index 5371ac7..578117a 100644 --- a/Project.toml +++ b/Project.toml @@ -3,29 +3,26 @@ uuid = "e0ca9c66-1f9e-11ec-127a-1304ce62169c" version = "1.1.0" authors = ["various contributors"] +[workspace] +projects = ["test", "benchmark"] + [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" -UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] -ChainRulesCore = "1.24" CommonSolve = "0.2" -DiffEqBase = "6.145, 7" -Distributions = "0.25" +ConcreteStructs = "0.2.3" +DiffEqBase = "6" LinearAlgebra = "1" -PDMats = "0.11" PrecompileTools = "1" -RecursiveArrayTools = "2.34, 4" -SciMLBase = "2, 3" +SciMLBase = "2" +StaticArrays = "1" SymbolicIndexingInterface = "0.3" -UnPack = "1" -julia = "1.8" +julia = "1.10" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 6be6e89..f5cd85d 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -1,5 +1,14 @@ [deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" DifferenceEquations = "e0ca9c66-1f9e-11ec-127a-1304ce62169c" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[sources] +DifferenceEquations = {path = ".."} diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 2e6df63..c3d55ed 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,5 +1,5 @@ -using DifferenceEquations, BenchmarkTools -using Test, LinearAlgebra, Random +using DifferenceEquations, BenchmarkTools, Enzyme +using LinearAlgebra, Random, StaticArrays # Check if MKL or not julia_mkl = @static if VERSION < v"1.7" @@ -13,14 +13,33 @@ if !julia_mkl BLAS.set_num_threads(openblas_threads) end -println("Running Testsuite with Threads.nthreads = $(Threads.nthreads()) MKL = $julia_mkl, and BLAS.num_threads = $(BLAS.get_num_threads()) \n") +println( + "Threads.nthreads = $(Threads.nthreads()), MKL = $julia_mkl, " * + "BLAS.num_threads = $(BLAS.get_num_threads())\n" +) -# Benchmark groups -BenchmarkTools.DEFAULT_PARAMETERS.seconds = 15.0 # 10 seconds per benchmark by default. -BenchmarkTools.DEFAULT_PARAMETERS.evals = 3 +BenchmarkTools.DEFAULT_PARAMETERS.seconds = 5.0 +BenchmarkTools.DEFAULT_PARAMETERS.evals = 1 -const SUITE = BenchmarkGroup() -SUITE["linear"] = include(pkgdir(DifferenceEquations) * "/benchmark/linear.jl") -SUITE["quadratic"] = include(pkgdir(DifferenceEquations) * "/benchmark/quadratic.jl") +# Enzyme reverse-mode AD corrupts GC metadata under repeated invocation, causing segfaults. +# GC disabled globally to prevent GC from running during Enzyme AD. +# Between benchmark samples, Enzyme @benchmarkable calls use a `teardown` to briefly +# re-enable GC, collect, and disable again — safe because Enzyme is not running at that point. +# This prevents OOM from leaked memory accumulating across samples. +# Upstream: https://github.com/EnzymeAD/Enzyme.jl/issues/2355 +GC.enable(false) -# results = run(SUITE; verbose = true) +const SUITE = BenchmarkGroup() +const _bdir = joinpath(pkgdir(DifferenceEquations), "benchmark") +SUITE["kalman"] = include(joinpath(_bdir, "enzyme_kalman.jl")) +SUITE["linear_likelihood"] = include(joinpath(_bdir, "enzyme_linear_likelihood.jl")) +SUITE["linear_simulation"] = include(joinpath(_bdir, "enzyme_linear_simulation.jl")) +SUITE["quadratic"] = include(joinpath(_bdir, "enzyme_quadratic.jl")) +SUITE["static_arrays"] = include(joinpath(_bdir, "static_arrays.jl")) +SUITE["ensemble"] = include(joinpath(_bdir, "ensemble.jl")) +SUITE["forwarddiff_kalman"] = include(joinpath(_bdir, "forwarddiff_kalman.jl")) +SUITE["forwarddiff_linear_likelihood"] = include(joinpath(_bdir, "forwarddiff_linear_likelihood.jl")) +SUITE["forwarddiff_linear_simulation"] = include(joinpath(_bdir, "forwarddiff_linear_simulation.jl")) +SUITE["conditional_likelihood"] = include(joinpath(_bdir, "enzyme_conditional_likelihood.jl")) +SUITE["forwarddiff_conditional_likelihood"] = include(joinpath(_bdir, "forwarddiff_conditional_likelihood.jl")) +SUITE["gradient_comparison"] = include(joinpath(_bdir, "gradient_comparison.jl")) diff --git a/benchmark/ensemble.jl b/benchmark/ensemble.jl new file mode 100644 index 0000000..61040c0 --- /dev/null +++ b/benchmark/ensemble.jl @@ -0,0 +1,249 @@ +# Manual ensemble loop with Enzyme AD +# NOT using EnsembleProblem — Enzyme cannot differentiate through DiffEqBase dispatch. +# Uses construct-inside + solve! in a tight loop over trajectories. +# +# Returns ENS_BENCH BenchmarkGroup + +using Enzyme: make_zero, make_zero! +using DifferenceEquations: init, solve!, StateSpaceWorkspace, fill_zero!! + +const ENS_BENCH = BenchmarkGroup() +ENS_BENCH["raw"] = BenchmarkGroup() +ENS_BENCH["forward"] = BenchmarkGroup() +ENS_BENCH["reverse"] = BenchmarkGroup() + +# ============================================================================= +# Problem sizes +# ============================================================================= + +const p_ens_small = (; N = 2, K = 1, M = 2, T = 10, N_traj = 20) +const p_ens_large = (; N = 5, K = 2, M = 3, T = 50, N_traj = 50) + +# ============================================================================= +# Problem setup +# ============================================================================= + +function make_ensemble_benchmark(; N, K, M, T, N_traj, seed = 42) + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + B = 0.1 * randn(N, K) + C = randn(M, N) + u0 = zeros(N) + + # Pre-generate noise for each trajectory + all_noise = [[randn(K) for _ in 1:T] for _ in 1:N_traj] + + # Pre-allocate sol/cache for each trajectory (simulation only, no obs) + prob_template = LinearStateSpaceProblem(A, B, u0, (0, T); C, noise = all_noise[1]) + all_sol = [deepcopy(init(prob_template, DirectIteration()).output) for _ in 1:N_traj] + all_cache = [deepcopy(init(prob_template, DirectIteration()).cache) for _ in 1:N_traj] + + # Shadows for AD + dA = make_zero(A) + dB = make_zero(B) + dC = make_zero(C) + du0 = make_zero(u0) + dall_noise = [[make_zero(all_noise[1][1]) for _ in 1:T] for _ in 1:N_traj] + dall_sol = [make_zero(s) for s in all_sol] + dall_cache = [make_zero(c) for c in all_cache] + + return (; + A, B, C, u0, all_noise, all_sol, all_cache, + dA, dB, dC, du0, dall_noise, dall_sol, dall_cache, + ) +end + +# ============================================================================= +# Wrapper functions +# ============================================================================= + +function ensemble_raw!(A, B, C, u0, all_noise, all_sol, all_cache) + total = 0.0 + for i in eachindex(all_noise) + prob = LinearStateSpaceProblem(A, B, u0, (0, length(all_noise[i])); C, noise = all_noise[i]) + ws = StateSpaceWorkspace(prob, DirectIteration(), all_sol[i], all_cache[i]) + solve!(ws) + total += sum(all_sol[i].u[end]) + end + return total / length(all_noise) +end + +function ensemble_forward_bench!(A, B, C, u0, all_noise, all_sol, all_cache) + # Same as raw — Enzyme differentiates through this + return ensemble_raw!(A, B, C, u0, all_noise, all_sol, all_cache) +end + +function ensemble_scalar!(A, B, C, u0, all_noise, all_sol, all_cache)::Float64 + return ensemble_raw!(A, B, C, u0, all_noise, all_sol, all_cache) +end + +# ============================================================================= +# Instantiate problems +# ============================================================================= + +const ens_s = make_ensemble_benchmark(; p_ens_small...) +const ens_l = make_ensemble_benchmark(; p_ens_large...) + +# ============================================================================= +# Raw benchmarks (primal solve through public API) +# ============================================================================= + +function raw_ens!(A, B, C, u0, all_noise, all_sol, all_cache) + return ensemble_raw!(A, B, C, u0, all_noise, all_sol, all_cache) +end + +# Warmup +raw_ens!( + ens_s.A, ens_s.B, ens_s.C, ens_s.u0, ens_s.all_noise, + ens_s.all_sol, ens_s.all_cache +) +raw_ens!( + ens_l.A, ens_l.B, ens_l.C, ens_l.u0, ens_l.all_noise, + ens_l.all_sol, ens_l.all_cache +) + +ENS_BENCH["raw"]["small"] = @benchmarkable raw_ens!( + $(ens_s.A), $(ens_s.B), $(ens_s.C), $(ens_s.u0), $(ens_s.all_noise), + $(ens_s.all_sol), $(ens_s.all_cache) +) +ENS_BENCH["raw"]["large"] = @benchmarkable raw_ens!( + $(ens_l.A), $(ens_l.B), $(ens_l.C), $(ens_l.u0), $(ens_l.all_noise), + $(ens_l.all_sol), $(ens_l.all_cache) +) + +# ============================================================================= +# Forward mode AD — perturb A[1,1], return computed arrays +# ============================================================================= + +function forward_ensemble_bench!( + A, B, C, u0, all_noise, all_sol, all_cache, + dA, dB, dC, du0, dall_noise, dall_sol, dall_cache + ) + # Zero all shadows + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC); du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dall_noise) + for j in eachindex(dall_noise[i]) + dall_noise[i][j] = fill_zero!!(dall_noise[i][j]) + end + end + @inbounds for i in eachindex(dall_sol) + make_zero!(dall_sol[i]) + end + @inbounds for i in eachindex(dall_cache) + make_zero!(dall_cache[i]) + end + # Set perturbation direction + dA[1, 1] = 1.0 + + autodiff( + Forward, ensemble_forward_bench!, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(u0, du0), Duplicated(all_noise, dall_noise), + Duplicated(all_sol, dall_sol), + Duplicated(all_cache, dall_cache) + ) + return nothing +end + +# Warmup +forward_ensemble_bench!( + copy(ens_s.A), copy(ens_s.B), copy(ens_s.C), copy(ens_s.u0), + [[copy(n) for n in traj] for traj in ens_s.all_noise], + ens_s.all_sol, ens_s.all_cache, + ens_s.dA, ens_s.dB, ens_s.dC, ens_s.du0, + ens_s.dall_noise, ens_s.dall_sol, ens_s.dall_cache +) + +ENS_BENCH["forward"]["small"] = @benchmarkable forward_ensemble_bench!( + $(copy(ens_s.A)), $(copy(ens_s.B)), $(copy(ens_s.C)), $(copy(ens_s.u0)), + $([[copy(n) for n in traj] for traj in ens_s.all_noise]), + $(ens_s.all_sol), $(ens_s.all_cache), + $(ens_s.dA), $(ens_s.dB), $(ens_s.dC), $(ens_s.du0), + $(ens_s.dall_noise), $(ens_s.dall_sol), $(ens_s.dall_cache) +) + +# Warmup large +forward_ensemble_bench!( + copy(ens_l.A), copy(ens_l.B), copy(ens_l.C), copy(ens_l.u0), + [[copy(n) for n in traj] for traj in ens_l.all_noise], + ens_l.all_sol, ens_l.all_cache, + ens_l.dA, ens_l.dB, ens_l.dC, ens_l.du0, + ens_l.dall_noise, ens_l.dall_sol, ens_l.dall_cache +) + +ENS_BENCH["forward"]["large"] = @benchmarkable forward_ensemble_bench!( + $(copy(ens_l.A)), $(copy(ens_l.B)), $(copy(ens_l.C)), $(copy(ens_l.u0)), + $([[copy(n) for n in traj] for traj in ens_l.all_noise]), + $(ens_l.all_sol), $(ens_l.all_cache), + $(ens_l.dA), $(ens_l.dB), $(ens_l.dC), $(ens_l.du0), + $(ens_l.dall_noise), $(ens_l.dall_sol), $(ens_l.dall_cache) +) + +# ============================================================================= +# Reverse mode AD — all Duplicated, scalar return with Active +# ============================================================================= + +function reverse_ensemble_bench!( + A, B, C, u0, all_noise, all_sol, all_cache, + dA, dB, dC, du0, dall_noise, dall_sol, dall_cache + ) + # Zero all shadows + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC); du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dall_noise) + for j in eachindex(dall_noise[i]) + dall_noise[i][j] = fill_zero!!(dall_noise[i][j]) + end + end + @inbounds for i in eachindex(dall_sol) + make_zero!(dall_sol[i]) + end + @inbounds for i in eachindex(dall_cache) + make_zero!(dall_cache[i]) + end + + autodiff( + Reverse, ensemble_scalar!, Active, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(u0, du0), Duplicated(all_noise, dall_noise), + Duplicated(all_sol, dall_sol), + Duplicated(all_cache, dall_cache) + ) + return nothing +end + +# Warmup +reverse_ensemble_bench!( + copy(ens_s.A), copy(ens_s.B), copy(ens_s.C), copy(ens_s.u0), + [[copy(n) for n in traj] for traj in ens_s.all_noise], + ens_s.all_sol, ens_s.all_cache, + ens_s.dA, ens_s.dB, ens_s.dC, ens_s.du0, + ens_s.dall_noise, ens_s.dall_sol, ens_s.dall_cache +) + +ENS_BENCH["reverse"]["small"] = @benchmarkable reverse_ensemble_bench!( + $(copy(ens_s.A)), $(copy(ens_s.B)), $(copy(ens_s.C)), $(copy(ens_s.u0)), + $([[copy(n) for n in traj] for traj in ens_s.all_noise]), + $(ens_s.all_sol), $(ens_s.all_cache), + $(ens_s.dA), $(ens_s.dB), $(ens_s.dC), $(ens_s.du0), + $(ens_s.dall_noise), $(ens_s.dall_sol), $(ens_s.dall_cache) +) + +# Warmup large +reverse_ensemble_bench!( + copy(ens_l.A), copy(ens_l.B), copy(ens_l.C), copy(ens_l.u0), + [[copy(n) for n in traj] for traj in ens_l.all_noise], + ens_l.all_sol, ens_l.all_cache, + ens_l.dA, ens_l.dB, ens_l.dC, ens_l.du0, + ens_l.dall_noise, ens_l.dall_sol, ens_l.dall_cache +) + +ENS_BENCH["reverse"]["large"] = @benchmarkable reverse_ensemble_bench!( + $(copy(ens_l.A)), $(copy(ens_l.B)), $(copy(ens_l.C)), $(copy(ens_l.u0)), + $([[copy(n) for n in traj] for traj in ens_l.all_noise]), + $(ens_l.all_sol), $(ens_l.all_cache), + $(ens_l.dA), $(ens_l.dB), $(ens_l.dC), $(ens_l.du0), + $(ens_l.dall_noise), $(ens_l.dall_sol), $(ens_l.dall_cache) +) + +ENS_BENCH diff --git a/benchmark/enzyme_conditional_likelihood.jl b/benchmark/enzyme_conditional_likelihood.jl new file mode 100644 index 0000000..a946e2f --- /dev/null +++ b/benchmark/enzyme_conditional_likelihood.jl @@ -0,0 +1,220 @@ +# Enzyme AD benchmarks for ConditionalLikelihood +# Returns CL_ENZYME BenchmarkGroup + +using Enzyme: make_zero, make_zero! +using DifferenceEquations: init, solve!, StateSpaceWorkspace, fill_zero!! + +const CL_ENZYME = BenchmarkGroup() +CL_ENZYME["raw"] = BenchmarkGroup() +CL_ENZYME["forward"] = BenchmarkGroup() +CL_ENZYME["reverse"] = BenchmarkGroup() + +# ============================================================================= +# Problem sizes +# ============================================================================= + +# CL requires fully-observed state: M = N (observations are state-dimensional) +const p_cl_small = (; N = 5, M = 5, T = 10) +const p_cl_large = (; N = 30, M = 30, T = 100) + +# ============================================================================= +# Problem setup +# ============================================================================= + +function make_cl_benchmark(p; seed = 42) + (; N, M, T) = p + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + C = randn(M, N) + H = 0.1 * randn(M, M) + R = H * H' + + # Generate observations: state evolves via A, observed via C + noise + x = zeros(N) + y = Vector{Vector{Float64}}(undef, T) + for t in 1:T + x = A * x + 0.1 * randn(N) + y[t] = C * x + H * randn(M) + end + + # Create problem and workspace (B=nothing, no process noise in prediction) + prob = LinearStateSpaceProblem( + A, nothing, zeros(N), (0, T); C, + observables_noise = R, observables = y + ) + ws = init(prob, ConditionalLikelihood()) + sol_out = ws.output + cache = ws.cache + + # Shadow copies for AD + dsol_out = make_zero(sol_out) + dcache = make_zero(cache) + dA = make_zero(A) + dC = make_zero(C) + dH = make_zero(H) + dy = [make_zero(y[1]) for _ in 1:T] + + return (; + A, C, H, R, y, prob, sol_out, cache, + dsol_out, dcache, dA, dC, dH, dy, + ) +end + +# ============================================================================= +# Scalar wrapper for reverse mode (returns logpdf) +# ============================================================================= + +function cl_loglik_bench!(A, C, H, y, sol_out, cache) + R = H * H' + prob = LinearStateSpaceProblem( + A, nothing, zeros(eltype(A), size(A, 1)), (0, length(y)); C, + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, ConditionalLikelihood(), sol_out, cache) + return solve!(ws).logpdf +end + +# ============================================================================= +# Forward wrapper (returns state and obs from cache) +# ============================================================================= + +function cl_forward_bench!(A, C, H, y, sol_out, cache) + R = H * H' + prob = LinearStateSpaceProblem( + A, nothing, zeros(eltype(A), size(A, 1)), (0, length(y)); C, + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, ConditionalLikelihood(), sol_out, cache) + solve!(ws) + return (sol_out.u[end], sol_out.z[end]) +end + +# ============================================================================= +# Instantiate problems +# ============================================================================= + +const cl_s = make_cl_benchmark(p_cl_small) +const cl_l = make_cl_benchmark(p_cl_large) + +# ============================================================================= +# Raw benchmarks (primal solve through public API) +# ============================================================================= + +function raw_cl!(prob, sol_out, cache) + ws = StateSpaceWorkspace(prob, ConditionalLikelihood(), sol_out, cache) + return solve!(ws).logpdf +end + +# Warmup +raw_cl!(cl_s.prob, cl_s.sol_out, cl_s.cache) +raw_cl!(cl_l.prob, cl_l.sol_out, cl_l.cache) + +CL_ENZYME["raw"]["small_mutable"] = @benchmarkable raw_cl!($(cl_s.prob), $(cl_s.sol_out), $(cl_s.cache)) +CL_ENZYME["raw"]["large_mutable"] = @benchmarkable raw_cl!($(cl_l.prob), $(cl_l.sol_out), $(cl_l.cache)) + +# ============================================================================= +# Forward mode AD — perturb A[1,1], return computed matrices +# ============================================================================= + +function forward_cl_bench!( + A, C, H, y, sol_out, cache, + dA, dC, dH, dy, dsol_out, dcache + ) + # Zero all shadows + make_zero!(dsol_out) + make_zero!(dcache) + dA = fill_zero!!(dA); dC = fill_zero!!(dC); dH = fill_zero!!(dH) + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + # Set perturbation direction + dA[1, 1] = 1.0 + + autodiff( + Forward, cl_forward_bench!, + Duplicated(A, dA), Duplicated(C, dC), + Duplicated(H, dH), Duplicated(y, dy), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# Warmup +forward_cl_bench!( + copy(cl_s.A), copy(cl_s.C), copy(cl_s.H), + [copy(yi) for yi in cl_s.y], cl_s.sol_out, cl_s.cache, + cl_s.dA, cl_s.dC, cl_s.dH, cl_s.dy, cl_s.dsol_out, cl_s.dcache +) + +CL_ENZYME["forward"]["small_mutable"] = @benchmarkable forward_cl_bench!( + $(copy(cl_s.A)), $(copy(cl_s.C)), $(copy(cl_s.H)), + $([copy(yi) for yi in cl_s.y]), $(cl_s.sol_out), $(cl_s.cache), + $(cl_s.dA), $(cl_s.dC), $(cl_s.dH), $(cl_s.dy), $(cl_s.dsol_out), $(cl_s.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +forward_cl_bench!( + copy(cl_l.A), copy(cl_l.C), copy(cl_l.H), + [copy(yi) for yi in cl_l.y], cl_l.sol_out, cl_l.cache, + cl_l.dA, cl_l.dC, cl_l.dH, cl_l.dy, cl_l.dsol_out, cl_l.dcache +) + +CL_ENZYME["forward"]["large_mutable"] = @benchmarkable forward_cl_bench!( + $(copy(cl_l.A)), $(copy(cl_l.C)), $(copy(cl_l.H)), + $([copy(yi) for yi in cl_l.y]), $(cl_l.sol_out), $(cl_l.cache), + $(cl_l.dA), $(cl_l.dC), $(cl_l.dH), $(cl_l.dy), $(cl_l.dsol_out), $(cl_l.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# ============================================================================= +# Reverse mode AD — all Duplicated, scalar logpdf output +# ============================================================================= + +function reverse_cl_bench!( + A, C, H, y, sol_out, cache, + dA, dC, dH, dy, dsol_out, dcache + ) + # Zero all shadows + make_zero!(dsol_out) + make_zero!(dcache) + dA = fill_zero!!(dA); dC = fill_zero!!(dC); dH = fill_zero!!(dH) + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + + autodiff( + Reverse, cl_loglik_bench!, Active, + Duplicated(A, dA), Duplicated(C, dC), + Duplicated(H, dH), Duplicated(y, dy), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# Warmup +reverse_cl_bench!( + copy(cl_s.A), copy(cl_s.C), copy(cl_s.H), + [copy(yi) for yi in cl_s.y], cl_s.sol_out, cl_s.cache, + cl_s.dA, cl_s.dC, cl_s.dH, cl_s.dy, cl_s.dsol_out, cl_s.dcache +) + +CL_ENZYME["reverse"]["small_mutable"] = @benchmarkable reverse_cl_bench!( + $(copy(cl_s.A)), $(copy(cl_s.C)), $(copy(cl_s.H)), + $([copy(yi) for yi in cl_s.y]), $(cl_s.sol_out), $(cl_s.cache), + $(cl_s.dA), $(cl_s.dC), $(cl_s.dH), $(cl_s.dy), $(cl_s.dsol_out), $(cl_s.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +reverse_cl_bench!( + copy(cl_l.A), copy(cl_l.C), copy(cl_l.H), + [copy(yi) for yi in cl_l.y], cl_l.sol_out, cl_l.cache, + cl_l.dA, cl_l.dC, cl_l.dH, cl_l.dy, cl_l.dsol_out, cl_l.dcache +) + +CL_ENZYME["reverse"]["large_mutable"] = @benchmarkable reverse_cl_bench!( + $(copy(cl_l.A)), $(copy(cl_l.C)), $(copy(cl_l.H)), + $([copy(yi) for yi in cl_l.y]), $(cl_l.sol_out), $(cl_l.cache), + $(cl_l.dA), $(cl_l.dC), $(cl_l.dH), $(cl_l.dy), $(cl_l.dsol_out), $(cl_l.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +CL_ENZYME diff --git a/benchmark/enzyme_kalman.jl b/benchmark/enzyme_kalman.jl new file mode 100644 index 0000000..054d224 --- /dev/null +++ b/benchmark/enzyme_kalman.jl @@ -0,0 +1,248 @@ +# Enzyme AD benchmarks for Kalman filter +# Returns KALMAN_ENZYME BenchmarkGroup + +using Enzyme: make_zero, make_zero! +using DifferenceEquations: init, solve!, StateSpaceWorkspace, fill_zero!! + +const KALMAN_ENZYME = BenchmarkGroup() +KALMAN_ENZYME["raw"] = BenchmarkGroup() +KALMAN_ENZYME["forward"] = BenchmarkGroup() +KALMAN_ENZYME["reverse"] = BenchmarkGroup() + +# ============================================================================= +# Problem sizes +# ============================================================================= + +const p_kf_small = (; N = 5, M = 2, K = 2, L = 2, T = 10) +const p_kf_large = (; N = 30, M = 10, K = 10, L = 10, T = 100) + +# ============================================================================= +# Problem setup +# ============================================================================= + +function make_kalman_benchmark(p; seed = 42) + (; N, M, K, L, T) = p + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + B = 0.1 * randn(N, K) + C = randn(M, N) + H = 0.1 * randn(M, L) + R = H * H' + mu_0 = zeros(N) + Sigma_0 = Matrix{Float64}(I, N, N) + + # Generate observations using package's solve + x0 = randn(N) + noise = [randn(K) for _ in 1:T] + sim = solve(LinearStateSpaceProblem(A, B, x0, (0, T); C, noise)) + y = [sim.z[t + 1] + H * randn(L) for t in 1:T] + + # Create problem and workspace + prob = LinearStateSpaceProblem( + A, B, zeros(N), (0, T); C, + u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) + ws = init(prob, KalmanFilter()) + sol_out = ws.output + cache = ws.cache + + # Shadow copies for AD (all Duplicated) + dsol_out = make_zero(sol_out) + dcache = make_zero(cache) + dA = make_zero(A) + dB = make_zero(B) + dC = make_zero(C) + dmu_0 = make_zero(mu_0) + dSigma_0 = make_zero(Sigma_0) + dR = make_zero(R) + dy = [make_zero(y[1]) for _ in 1:T] + + return (; + A, B, C, R, mu_0, Sigma_0, y, prob, sol_out, cache, + dsol_out, dcache, dA, dB, dC, dmu_0, dSigma_0, dR, dy, + ) +end + +# ============================================================================= +# Scalar wrapper for reverse mode (returns logpdf) +# ============================================================================= + +function kalman_loglik_bench!(A, B, C, mu_0, Sigma_0, R, y, sol_out, cache) + prob = LinearStateSpaceProblem( + A, B, zeros(eltype(A), size(A, 1)), (0, length(y)); C, + u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol_out, cache) + return solve!(ws).logpdf +end + +# ============================================================================= +# Forward wrapper (returns solution output matrices for tangent validation) +# ============================================================================= + +function kalman_forward_bench!(A, B, C, mu_0, Sigma_0, R, y, sol_out, cache) + prob = LinearStateSpaceProblem( + A, B, zeros(eltype(A), size(A, 1)), (0, length(y)); C, + u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol_out, cache) + solve!(ws) + return (sol_out.u[end], sol_out.P[end]) +end + +# ============================================================================= +# Instantiate problems +# ============================================================================= + +const kf_s = make_kalman_benchmark(p_kf_small) +const kf_l = make_kalman_benchmark(p_kf_large) + +# ============================================================================= +# Raw benchmarks (primal solve through public API) +# ============================================================================= + +function raw_kalman!(prob, sol_out, cache) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol_out, cache) + return solve!(ws).logpdf +end + +# Warmup +raw_kalman!(kf_s.prob, kf_s.sol_out, kf_s.cache) +raw_kalman!(kf_l.prob, kf_l.sol_out, kf_l.cache) + +KALMAN_ENZYME["raw"]["small_mutable"] = @benchmarkable raw_kalman!( + $(kf_s.prob), $(kf_s.sol_out), $(kf_s.cache) +) +KALMAN_ENZYME["raw"]["large_mutable"] = @benchmarkable raw_kalman!( + $(kf_l.prob), $(kf_l.sol_out), $(kf_l.cache) +) + +# ============================================================================= +# Forward mode AD — perturb A[1,1], return computed matrices +# ============================================================================= + +function forward_kalman_bench!( + A, B, C, mu_0, Sigma_0, R, y, sol_out, cache, + dA, dB, dC, dmu_0, dSigma_0, dR, dy, dsol_out, dcache + ) + # Zero all shadows + make_zero!(dsol_out) + make_zero!(dcache) + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC) + dmu_0 = fill_zero!!(dmu_0); dSigma_0 = fill_zero!!(dSigma_0); dR = fill_zero!!(dR) + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + # Set perturbation direction + dA[1, 1] = 1.0 + + autodiff( + Forward, kalman_forward_bench!, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(mu_0, dmu_0), Duplicated(Sigma_0, dSigma_0), + Duplicated(R, dR), Duplicated(y, dy), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# Warmup +forward_kalman_bench!( + copy(kf_s.A), copy(kf_s.B), copy(kf_s.C), + copy(kf_s.mu_0), copy(kf_s.Sigma_0), copy(kf_s.R), + [copy(yi) for yi in kf_s.y], kf_s.sol_out, kf_s.cache, + kf_s.dA, kf_s.dB, kf_s.dC, kf_s.dmu_0, kf_s.dSigma_0, kf_s.dR, + kf_s.dy, kf_s.dsol_out, kf_s.dcache +) + +KALMAN_ENZYME["forward"]["small_mutable"] = @benchmarkable forward_kalman_bench!( + $(copy(kf_s.A)), $(copy(kf_s.B)), $(copy(kf_s.C)), + $(copy(kf_s.mu_0)), $(copy(kf_s.Sigma_0)), $(copy(kf_s.R)), + $([copy(yi) for yi in kf_s.y]), $(kf_s.sol_out), $(kf_s.cache), + $(kf_s.dA), $(kf_s.dB), $(kf_s.dC), $(kf_s.dmu_0), $(kf_s.dSigma_0), $(kf_s.dR), + $(kf_s.dy), $(kf_s.dsol_out), $(kf_s.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +forward_kalman_bench!( + copy(kf_l.A), copy(kf_l.B), copy(kf_l.C), + copy(kf_l.mu_0), copy(kf_l.Sigma_0), copy(kf_l.R), + [copy(yi) for yi in kf_l.y], kf_l.sol_out, kf_l.cache, + kf_l.dA, kf_l.dB, kf_l.dC, kf_l.dmu_0, kf_l.dSigma_0, kf_l.dR, + kf_l.dy, kf_l.dsol_out, kf_l.dcache +) + +KALMAN_ENZYME["forward"]["large_mutable"] = @benchmarkable forward_kalman_bench!( + $(copy(kf_l.A)), $(copy(kf_l.B)), $(copy(kf_l.C)), + $(copy(kf_l.mu_0)), $(copy(kf_l.Sigma_0)), $(copy(kf_l.R)), + $([copy(yi) for yi in kf_l.y]), $(kf_l.sol_out), $(kf_l.cache), + $(kf_l.dA), $(kf_l.dB), $(kf_l.dC), $(kf_l.dmu_0), $(kf_l.dSigma_0), $(kf_l.dR), + $(kf_l.dy), $(kf_l.dsol_out), $(kf_l.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# ============================================================================= +# Reverse mode AD — all Duplicated, scalar logpdf output +# ============================================================================= + +function reverse_kalman_bench!( + A, B, C, mu_0, Sigma_0, R, y, sol_out, cache, + dA, dB, dC, dmu_0, dSigma_0, dR, dy, dsol_out, dcache + ) + # Zero all shadows + make_zero!(dsol_out) + make_zero!(dcache) + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC) + dmu_0 = fill_zero!!(dmu_0); dSigma_0 = fill_zero!!(dSigma_0); dR = fill_zero!!(dR) + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + + autodiff( + Reverse, kalman_loglik_bench!, Active, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(mu_0, dmu_0), Duplicated(Sigma_0, dSigma_0), + Duplicated(R, dR), Duplicated(y, dy), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# Warmup +reverse_kalman_bench!( + copy(kf_s.A), copy(kf_s.B), copy(kf_s.C), + copy(kf_s.mu_0), copy(kf_s.Sigma_0), copy(kf_s.R), + [copy(yi) for yi in kf_s.y], kf_s.sol_out, kf_s.cache, + kf_s.dA, kf_s.dB, kf_s.dC, kf_s.dmu_0, kf_s.dSigma_0, kf_s.dR, + kf_s.dy, kf_s.dsol_out, kf_s.dcache +) + +KALMAN_ENZYME["reverse"]["small_mutable"] = @benchmarkable reverse_kalman_bench!( + $(copy(kf_s.A)), $(copy(kf_s.B)), $(copy(kf_s.C)), + $(copy(kf_s.mu_0)), $(copy(kf_s.Sigma_0)), $(copy(kf_s.R)), + $([copy(yi) for yi in kf_s.y]), $(kf_s.sol_out), $(kf_s.cache), + $(kf_s.dA), $(kf_s.dB), $(kf_s.dC), $(kf_s.dmu_0), $(kf_s.dSigma_0), $(kf_s.dR), + $(kf_s.dy), $(kf_s.dsol_out), $(kf_s.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +reverse_kalman_bench!( + copy(kf_l.A), copy(kf_l.B), copy(kf_l.C), + copy(kf_l.mu_0), copy(kf_l.Sigma_0), copy(kf_l.R), + [copy(yi) for yi in kf_l.y], kf_l.sol_out, kf_l.cache, + kf_l.dA, kf_l.dB, kf_l.dC, kf_l.dmu_0, kf_l.dSigma_0, kf_l.dR, + kf_l.dy, kf_l.dsol_out, kf_l.dcache +) + +KALMAN_ENZYME["reverse"]["large_mutable"] = @benchmarkable reverse_kalman_bench!( + $(copy(kf_l.A)), $(copy(kf_l.B)), $(copy(kf_l.C)), + $(copy(kf_l.mu_0)), $(copy(kf_l.Sigma_0)), $(copy(kf_l.R)), + $([copy(yi) for yi in kf_l.y]), $(kf_l.sol_out), $(kf_l.cache), + $(kf_l.dA), $(kf_l.dB), $(kf_l.dC), $(kf_l.dmu_0), $(kf_l.dSigma_0), $(kf_l.dR), + $(kf_l.dy), $(kf_l.dsol_out), $(kf_l.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +KALMAN_ENZYME diff --git a/benchmark/enzyme_linear_likelihood.jl b/benchmark/enzyme_linear_likelihood.jl new file mode 100644 index 0000000..72e9058 --- /dev/null +++ b/benchmark/enzyme_linear_likelihood.jl @@ -0,0 +1,247 @@ +# Enzyme AD benchmarks for DirectIteration (joint likelihood) +# Returns DI_ENZYME BenchmarkGroup + +using Enzyme: make_zero, make_zero! +using DifferenceEquations: init, solve!, StateSpaceWorkspace, fill_zero!! + +const DI_ENZYME = BenchmarkGroup() +DI_ENZYME["raw"] = BenchmarkGroup() +DI_ENZYME["forward"] = BenchmarkGroup() +DI_ENZYME["reverse"] = BenchmarkGroup() + +# ============================================================================= +# Problem sizes +# ============================================================================= + +const p_di_small = (; N = 5, M = 2, K = 2, L = 2, T = 10) +const p_di_large = (; N = 30, M = 10, K = 10, L = 10, T = 100) + +# ============================================================================= +# Problem setup +# ============================================================================= + +function make_di_benchmark(p; seed = 42) + (; N, M, K, L, T) = p + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + B = 0.1 * randn(N, K) + C = randn(M, N) + H = 0.1 * randn(M, L) + R = H * H' + u0 = zeros(N) + noise = [randn(K) for _ in 1:T] + + # Generate observations using package's solve + sim = solve(LinearStateSpaceProblem(A, B, u0, (0, T); C, noise)) + y = [sim.z[t + 1] + H * randn(L) for t in 1:T] + + # Create problem and workspace + prob = LinearStateSpaceProblem( + A, B, u0, (0, T); C, + observables_noise = R, observables = y, noise + ) + ws = init(prob, DirectIteration()) + sol_out = ws.output + cache = ws.cache + + # Shadow copies for AD (all Duplicated) + dsol_out = make_zero(sol_out) + dcache = make_zero(cache) + dA = make_zero(A) + dB = make_zero(B) + dC = make_zero(C) + dH = make_zero(H) + du0 = make_zero(u0) + dnoise = [make_zero(noise[1]) for _ in 1:T] + dy = [make_zero(y[1]) for _ in 1:T] + + return (; + A, B, C, H, R, u0, noise, y, prob, sol_out, cache, + dsol_out, dcache, dA, dB, dC, dH, du0, dnoise, dy, + ) +end + +# ============================================================================= +# Scalar wrapper for reverse mode (returns logpdf) +# ============================================================================= + +function di_loglik_bench!(A, B, C, u0, noise, y, H, sol_out, cache) + R = H * H' + prob = LinearStateSpaceProblem( + A, B, u0, (0, length(y)); C, + observables_noise = R, observables = y, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + return solve!(ws).logpdf +end + +# ============================================================================= +# Forward wrapper (returns matrices from cache) +# ============================================================================= + +function di_forward_bench!(A, B, C, u0, noise, y, H, sol_out, cache) + R = H * H' + prob = LinearStateSpaceProblem( + A, B, u0, (0, length(y)); C, + observables_noise = R, observables = y, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + solve!(ws) + return (sol_out.u[end], sol_out.z[end]) +end + +# ============================================================================= +# Instantiate problems +# ============================================================================= + +const di_s = make_di_benchmark(p_di_small) +const di_l = make_di_benchmark(p_di_large) + +# ============================================================================= +# Raw benchmarks (primal solve through public API) +# ============================================================================= + +function raw_di!(prob, sol_out, cache) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + return solve!(ws).logpdf +end + +# Warmup +raw_di!(di_s.prob, di_s.sol_out, di_s.cache) +raw_di!(di_l.prob, di_l.sol_out, di_l.cache) + +DI_ENZYME["raw"]["small_mutable"] = @benchmarkable raw_di!($(di_s.prob), $(di_s.sol_out), $(di_s.cache)) +DI_ENZYME["raw"]["large_mutable"] = @benchmarkable raw_di!($(di_l.prob), $(di_l.sol_out), $(di_l.cache)) + +# ============================================================================= +# Forward mode AD — perturb A[1,1], return computed matrices +# ============================================================================= + +function forward_di_bench!( + A, B, C, u0, noise, y, H, sol_out, cache, + dA, dB, dC, du0, dnoise, dy, dH, dsol_out, dcache + ) + # Zero all shadows + make_zero!(dsol_out) + make_zero!(dcache) + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC); dH = fill_zero!!(dH) + du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + # Set perturbation direction + dA[1, 1] = 1.0 + + autodiff( + Forward, di_forward_bench!, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(u0, du0), Duplicated(noise, dnoise), Duplicated(y, dy), + Duplicated(H, dH), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# Warmup +forward_di_bench!( + copy(di_s.A), copy(di_s.B), copy(di_s.C), + copy(di_s.u0), [copy(n) for n in di_s.noise], [copy(yi) for yi in di_s.y], + copy(di_s.H), di_s.sol_out, di_s.cache, + di_s.dA, di_s.dB, di_s.dC, di_s.du0, di_s.dnoise, di_s.dy, di_s.dH, + di_s.dsol_out, di_s.dcache +) + +DI_ENZYME["forward"]["small_mutable"] = @benchmarkable forward_di_bench!( + $(copy(di_s.A)), $(copy(di_s.B)), $(copy(di_s.C)), + $(copy(di_s.u0)), $([copy(n) for n in di_s.noise]), $([copy(yi) for yi in di_s.y]), + $(copy(di_s.H)), $(di_s.sol_out), $(di_s.cache), + $(di_s.dA), $(di_s.dB), $(di_s.dC), $(di_s.du0), $(di_s.dnoise), $(di_s.dy), $(di_s.dH), + $(di_s.dsol_out), $(di_s.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +forward_di_bench!( + copy(di_l.A), copy(di_l.B), copy(di_l.C), + copy(di_l.u0), [copy(n) for n in di_l.noise], [copy(yi) for yi in di_l.y], + copy(di_l.H), di_l.sol_out, di_l.cache, + di_l.dA, di_l.dB, di_l.dC, di_l.du0, di_l.dnoise, di_l.dy, di_l.dH, + di_l.dsol_out, di_l.dcache +) + +DI_ENZYME["forward"]["large_mutable"] = @benchmarkable forward_di_bench!( + $(copy(di_l.A)), $(copy(di_l.B)), $(copy(di_l.C)), + $(copy(di_l.u0)), $([copy(n) for n in di_l.noise]), $([copy(yi) for yi in di_l.y]), + $(copy(di_l.H)), $(di_l.sol_out), $(di_l.cache), + $(di_l.dA), $(di_l.dB), $(di_l.dC), $(di_l.du0), $(di_l.dnoise), $(di_l.dy), $(di_l.dH), + $(di_l.dsol_out), $(di_l.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# ============================================================================= +# Reverse mode AD — all Duplicated, scalar logpdf output +# ============================================================================= + +function reverse_di_bench!( + A, B, C, u0, noise, y, H, sol_out, cache, + dA, dB, dC, du0, dnoise, dy, dH, dsol_out, dcache + ) + # Zero all shadows + make_zero!(dsol_out) + make_zero!(dcache) + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC); dH = fill_zero!!(dH) + du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + + autodiff( + Reverse, di_loglik_bench!, Active, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(u0, du0), Duplicated(noise, dnoise), Duplicated(y, dy), + Duplicated(H, dH), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# Warmup +reverse_di_bench!( + copy(di_s.A), copy(di_s.B), copy(di_s.C), + copy(di_s.u0), [copy(n) for n in di_s.noise], [copy(yi) for yi in di_s.y], + copy(di_s.H), di_s.sol_out, di_s.cache, + di_s.dA, di_s.dB, di_s.dC, di_s.du0, di_s.dnoise, di_s.dy, di_s.dH, + di_s.dsol_out, di_s.dcache +) + +DI_ENZYME["reverse"]["small_mutable"] = @benchmarkable reverse_di_bench!( + $(copy(di_s.A)), $(copy(di_s.B)), $(copy(di_s.C)), + $(copy(di_s.u0)), $([copy(n) for n in di_s.noise]), $([copy(yi) for yi in di_s.y]), + $(copy(di_s.H)), $(di_s.sol_out), $(di_s.cache), + $(di_s.dA), $(di_s.dB), $(di_s.dC), $(di_s.du0), $(di_s.dnoise), $(di_s.dy), $(di_s.dH), + $(di_s.dsol_out), $(di_s.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +reverse_di_bench!( + copy(di_l.A), copy(di_l.B), copy(di_l.C), + copy(di_l.u0), [copy(n) for n in di_l.noise], [copy(yi) for yi in di_l.y], + copy(di_l.H), di_l.sol_out, di_l.cache, + di_l.dA, di_l.dB, di_l.dC, di_l.du0, di_l.dnoise, di_l.dy, di_l.dH, + di_l.dsol_out, di_l.dcache +) + +DI_ENZYME["reverse"]["large_mutable"] = @benchmarkable reverse_di_bench!( + $(copy(di_l.A)), $(copy(di_l.B)), $(copy(di_l.C)), + $(copy(di_l.u0)), $([copy(n) for n in di_l.noise]), $([copy(yi) for yi in di_l.y]), + $(copy(di_l.H)), $(di_l.sol_out), $(di_l.cache), + $(di_l.dA), $(di_l.dB), $(di_l.dC), $(di_l.du0), $(di_l.dnoise), $(di_l.dy), $(di_l.dH), + $(di_l.dsol_out), $(di_l.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +DI_ENZYME diff --git a/benchmark/enzyme_linear_simulation.jl b/benchmark/enzyme_linear_simulation.jl new file mode 100644 index 0000000..0d478df --- /dev/null +++ b/benchmark/enzyme_linear_simulation.jl @@ -0,0 +1,241 @@ +# Enzyme AD benchmarks for Linear DirectIteration simulation (no observations/likelihood) +# Returns SIM_ENZYME BenchmarkGroup + +using Enzyme: make_zero, make_zero! +using DifferenceEquations: init, solve!, StateSpaceWorkspace, fill_zero!! + +const SIM_ENZYME = BenchmarkGroup() +SIM_ENZYME["raw"] = BenchmarkGroup() +SIM_ENZYME["forward"] = BenchmarkGroup() +SIM_ENZYME["reverse"] = BenchmarkGroup() + +# ============================================================================= +# Problem sizes +# ============================================================================= + +const p_sim_small = (; N = 5, M = 3, K = 2, T = 10) +const p_sim_large = (; N = 30, M = 10, K = 10, T = 100) + +# ============================================================================= +# Problem setup +# ============================================================================= + +function make_sim_benchmark(p; seed = 42) + (; N, M, K, T) = p + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + B = 0.1 * randn(N, K) + C = randn(M, N) + u0 = zeros(N) + noise = [randn(K) for _ in 1:T] + + # Create problem and workspace (no observables, no observables_noise) + prob = LinearStateSpaceProblem(A, B, u0, (0, T); C, noise) + ws = init(prob, DirectIteration()) + sol_out = ws.output + cache = ws.cache + + # Shadow copies for AD (all Duplicated) + dsol_out = make_zero(sol_out) + dcache = make_zero(cache) + dA = make_zero(A) + dB = make_zero(B) + dC = make_zero(C) + du0 = make_zero(u0) + dnoise = [make_zero(noise[1]) for _ in 1:T] + + return (; + A, B, C, u0, noise, prob, sol_out, cache, + dsol_out, dcache, dA, dB, dC, du0, dnoise, + ) +end + +# ============================================================================= +# Scalar wrapper for reverse mode (returns sum of terminal state) +# ============================================================================= + +function sim_scalar_bench!(A, B, C, u0, noise, sol_out, cache)::Float64 + prob = LinearStateSpaceProblem(A, B, u0, (0, length(noise)); C, noise) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + return sum(solve!(ws).u[end]) +end + +# ============================================================================= +# Forward wrapper (returns terminal state and observation) +# ============================================================================= + +function sim_forward_bench!(A, B, C, u0, noise, sol_out, cache) + prob = LinearStateSpaceProblem(A, B, u0, (0, length(noise)); C, noise) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + solve!(ws) + return (sol_out.u[end], sol_out.z[end]) +end + +# ============================================================================= +# Instantiate problems +# ============================================================================= + +const sim_s = make_sim_benchmark(p_sim_small) +const sim_l = make_sim_benchmark(p_sim_large) + +# ============================================================================= +# Raw benchmarks (primal solve through public API) +# ============================================================================= + +function raw_sim!(prob, sol_out, cache) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + return sum(solve!(ws).u[end]) +end + +# Warmup +raw_sim!(sim_s.prob, sim_s.sol_out, sim_s.cache) +raw_sim!(sim_l.prob, sim_l.sol_out, sim_l.cache) + +SIM_ENZYME["raw"]["small_mutable"] = @benchmarkable raw_sim!($(sim_s.prob), $(sim_s.sol_out), $(sim_s.cache)) +SIM_ENZYME["raw"]["large_mutable"] = @benchmarkable raw_sim!($(sim_l.prob), $(sim_l.sol_out), $(sim_l.cache)) + +# ============================================================================= +# Forward mode AD — perturb A[1,1], return terminal state and observation +# ============================================================================= + +function forward_sim_bench!( + A, B, C, u0, noise, sol_out, cache, + dA, dB, dC, du0, dnoise, dsol_out, dcache + ) + # Zero all shadows + make_zero!(dsol_out) + make_zero!(dcache) + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC) + du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + # Set perturbation direction + dA[1, 1] = 1.0 + + autodiff( + Forward, sim_forward_bench!, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# Warmup +forward_sim_bench!( + copy(sim_s.A), copy(sim_s.B), copy(sim_s.C), + copy(sim_s.u0), [copy(n) for n in sim_s.noise], + sim_s.sol_out, sim_s.cache, + sim_s.dA, sim_s.dB, sim_s.dC, sim_s.du0, sim_s.dnoise, + sim_s.dsol_out, sim_s.dcache +) + +SIM_ENZYME["forward"]["small_mutable"] = @benchmarkable forward_sim_bench!( + $(copy(sim_s.A)), $(copy(sim_s.B)), $(copy(sim_s.C)), + $(copy(sim_s.u0)), $([copy(n) for n in sim_s.noise]), + $(sim_s.sol_out), $(sim_s.cache), + $(sim_s.dA), $(sim_s.dB), $(sim_s.dC), $(sim_s.du0), $(sim_s.dnoise), + $(sim_s.dsol_out), $(sim_s.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +forward_sim_bench!( + copy(sim_l.A), copy(sim_l.B), copy(sim_l.C), + copy(sim_l.u0), [copy(n) for n in sim_l.noise], + sim_l.sol_out, sim_l.cache, + sim_l.dA, sim_l.dB, sim_l.dC, sim_l.du0, sim_l.dnoise, + sim_l.dsol_out, sim_l.dcache +) + +SIM_ENZYME["forward"]["large_mutable"] = @benchmarkable forward_sim_bench!( + $(copy(sim_l.A)), $(copy(sim_l.B)), $(copy(sim_l.C)), + $(copy(sim_l.u0)), $([copy(n) for n in sim_l.noise]), + $(sim_l.sol_out), $(sim_l.cache), + $(sim_l.dA), $(sim_l.dB), $(sim_l.dC), $(sim_l.du0), $(sim_l.dnoise), + $(sim_l.dsol_out), $(sim_l.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# ============================================================================= +# Reverse mode AD — all Duplicated, scalar sum(u[end]) output +# ============================================================================= + +function reverse_sim_bench!( + A, B, C, u0, noise, sol_out, cache, + dA, dB, dC, du0, dnoise, dsol_out, dcache + ) + # Zero all shadows + make_zero!(dsol_out) + make_zero!(dcache) + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC) + du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + + autodiff( + Reverse, sim_scalar_bench!, Active, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# Warmup +reverse_sim_bench!( + copy(sim_s.A), copy(sim_s.B), copy(sim_s.C), + copy(sim_s.u0), [copy(n) for n in sim_s.noise], + sim_s.sol_out, sim_s.cache, + sim_s.dA, sim_s.dB, sim_s.dC, sim_s.du0, sim_s.dnoise, + sim_s.dsol_out, sim_s.dcache +) + +SIM_ENZYME["reverse"]["small_mutable"] = @benchmarkable reverse_sim_bench!( + $(copy(sim_s.A)), $(copy(sim_s.B)), $(copy(sim_s.C)), + $(copy(sim_s.u0)), $([copy(n) for n in sim_s.noise]), + $(sim_s.sol_out), $(sim_s.cache), + $(sim_s.dA), $(sim_s.dB), $(sim_s.dC), $(sim_s.du0), $(sim_s.dnoise), + $(sim_s.dsol_out), $(sim_s.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +reverse_sim_bench!( + copy(sim_l.A), copy(sim_l.B), copy(sim_l.C), + copy(sim_l.u0), [copy(n) for n in sim_l.noise], + sim_l.sol_out, sim_l.cache, + sim_l.dA, sim_l.dB, sim_l.dC, sim_l.du0, sim_l.dnoise, + sim_l.dsol_out, sim_l.dcache +) + +SIM_ENZYME["reverse"]["large_mutable"] = @benchmarkable reverse_sim_bench!( + $(copy(sim_l.A)), $(copy(sim_l.B)), $(copy(sim_l.C)), + $(copy(sim_l.u0)), $([copy(n) for n in sim_l.noise]), + $(sim_l.sol_out), $(sim_l.cache), + $(sim_l.dA), $(sim_l.dB), $(sim_l.dC), $(sim_l.du0), $(sim_l.dnoise), + $(sim_l.dsol_out), $(sim_l.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# --- Edge cases: no noise, no observation equation (raw primal only) --- + +SIM_ENZYME["raw"]["no_noise"] = let + A = sim_s.A; C = sim_s.C; u0 = sim_s.u0 + prob = LinearStateSpaceProblem(A, nothing, u0, (0, p_sim_small.T); C) + ws = init(prob, DirectIteration()) + @benchmarkable bench_nn!(ws_nn) setup = (ws_nn = $ws) +end + +function bench_nn!(ws) + solve!(ws) + return nothing +end + +SIM_ENZYME["raw"]["no_obs_eq"] = let + A = sim_s.A; u0 = sim_s.u0 + prob = LinearStateSpaceProblem(A, nothing, u0, (0, p_sim_small.T)) + ws = init(prob, DirectIteration()) + @benchmarkable bench_nn!(ws_nn) setup = (ws_nn = $ws) +end + +SIM_ENZYME diff --git a/benchmark/enzyme_quadratic.jl b/benchmark/enzyme_quadratic.jl new file mode 100644 index 0000000..487ce91 --- /dev/null +++ b/benchmark/enzyme_quadratic.jl @@ -0,0 +1,421 @@ +# Enzyme AD benchmarks for QuadraticStateSpaceProblem / PrunedQuadraticStateSpaceProblem +# Two sub-groups: "unpruned" and "pruned", each with "raw", "forward", "reverse" × small/large +# Returns QUAD_ENZYME BenchmarkGroup + +using Enzyme: make_zero, make_zero! +using DifferenceEquations: init, solve!, StateSpaceWorkspace, fill_zero!! + +const QUAD_ENZYME = BenchmarkGroup() +QUAD_ENZYME["unpruned"] = BenchmarkGroup() +QUAD_ENZYME["unpruned"]["raw"] = BenchmarkGroup() +QUAD_ENZYME["unpruned"]["forward"] = BenchmarkGroup() +QUAD_ENZYME["unpruned"]["reverse"] = BenchmarkGroup() +QUAD_ENZYME["pruned"] = BenchmarkGroup() +QUAD_ENZYME["pruned"]["raw"] = BenchmarkGroup() +QUAD_ENZYME["pruned"]["forward"] = BenchmarkGroup() +QUAD_ENZYME["pruned"]["reverse"] = BenchmarkGroup() + +# ============================================================================= +# Problem sizes +# ============================================================================= + +const p_quad_small = (; N = 2, K = 1, M = 2, T = 10) +const p_quad_large = (; N = 10, K = 3, M = 6, T = 50) + +# ============================================================================= +# Problem setup +# ============================================================================= + +function make_quad_benchmark(; N, K, M, T, seed = 42, pruned = false) + Random.seed!(seed) + A_1_raw = randn(N, N) + A_1 = 0.5 * A_1_raw / maximum(abs.(eigvals(A_1_raw))) + A_0 = 0.001 * randn(N) + A_2 = 0.01 * randn(N, N, N) / N + B = 0.1 * randn(N, K) + C_0 = 0.001 * randn(M) + C_1 = randn(M, N) + C_2 = 0.01 * randn(M, N, N) / N + u0 = zeros(N) + noise = [randn(K) for _ in 1:T] + + ProbType = pruned ? PrunedQuadraticStateSpaceProblem : QuadraticStateSpaceProblem + prob = ProbType(A_0, A_1, A_2, B, u0, (0, T); C_0, C_1, C_2, noise) + ws = init(prob, DirectIteration()) + + # Shadows for AD (no dprob — prob constructed inside wrapper) + dA_0 = make_zero(A_0); dA_1 = make_zero(A_1); dA_2 = make_zero(A_2) + dB = make_zero(B); dC_0 = make_zero(C_0); dC_1 = make_zero(C_1); dC_2 = make_zero(C_2) + du0 = make_zero(u0); dnoise = [make_zero(noise[1]) for _ in 1:T] + dsol = make_zero(ws.output); dcache = make_zero(ws.cache) + + return (; + A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, prob, + sol = ws.output, cache = ws.cache, + dA_0, dA_1, dA_2, dB, dC_0, dC_1, dC_2, du0, dnoise, dsol, dcache, + ) +end + +# ============================================================================= +# Inner wrappers — construct prob inside (correct Enzyme pattern) +# ============================================================================= + +# --- Unpruned --- + +function quad_fwd!(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache) + prob = QuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + solve!(ws) + return (sol_out.u[end], sol_out.z[end]) +end + +function quad_rev!(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache)::Float64 + prob = QuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + return sum(solve!(ws).u[end]) +end + +# --- Pruned --- + +function pruned_quad_fwd!(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache) + prob = PrunedQuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + solve!(ws) + return (sol_out.u[end], sol_out.z[end]) +end + +function pruned_quad_rev!(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache)::Float64 + prob = PrunedQuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + return sum(solve!(ws).u[end]) +end + +# ============================================================================= +# Outer bench functions — zero shadows, call autodiff +# ============================================================================= + +function forward_quad!( + A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache, + dA_0, dA_1, dA_2, dB, dC_0, dC_1, dC_2, du0, dnoise, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + dA_0 = fill_zero!!(dA_0); dA_1 = fill_zero!!(dA_1); make_zero!(dA_2) + dB = fill_zero!!(dB); dC_0 = fill_zero!!(dC_0); dC_1 = fill_zero!!(dC_1) + make_zero!(dC_2); du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + dA_1[1, 1] = 1.0 + + autodiff( + Forward, quad_fwd!, + Duplicated(A_0, dA_0), Duplicated(A_1, dA_1), Duplicated(A_2, dA_2), + Duplicated(B, dB), Duplicated(C_0, dC_0), Duplicated(C_1, dC_1), + Duplicated(C_2, dC_2), Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +function reverse_quad!( + A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache, + dA_0, dA_1, dA_2, dB, dC_0, dC_1, dC_2, du0, dnoise, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + dA_0 = fill_zero!!(dA_0); dA_1 = fill_zero!!(dA_1); make_zero!(dA_2) + dB = fill_zero!!(dB); dC_0 = fill_zero!!(dC_0); dC_1 = fill_zero!!(dC_1) + make_zero!(dC_2); du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + + autodiff( + Reverse, quad_rev!, Active, + Duplicated(A_0, dA_0), Duplicated(A_1, dA_1), Duplicated(A_2, dA_2), + Duplicated(B, dB), Duplicated(C_0, dC_0), Duplicated(C_1, dC_1), + Duplicated(C_2, dC_2), Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +function forward_pruned_quad!( + A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache, + dA_0, dA_1, dA_2, dB, dC_0, dC_1, dC_2, du0, dnoise, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + dA_0 = fill_zero!!(dA_0); dA_1 = fill_zero!!(dA_1); make_zero!(dA_2) + dB = fill_zero!!(dB); dC_0 = fill_zero!!(dC_0); dC_1 = fill_zero!!(dC_1) + make_zero!(dC_2); du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + dA_1[1, 1] = 1.0 + + autodiff( + Forward, pruned_quad_fwd!, + Duplicated(A_0, dA_0), Duplicated(A_1, dA_1), Duplicated(A_2, dA_2), + Duplicated(B, dB), Duplicated(C_0, dC_0), Duplicated(C_1, dC_1), + Duplicated(C_2, dC_2), Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +function reverse_pruned_quad!( + A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache, + dA_0, dA_1, dA_2, dB, dC_0, dC_1, dC_2, du0, dnoise, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + dA_0 = fill_zero!!(dA_0); dA_1 = fill_zero!!(dA_1); make_zero!(dA_2) + dB = fill_zero!!(dB); dC_0 = fill_zero!!(dC_0); dC_1 = fill_zero!!(dC_1) + make_zero!(dC_2); du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + + autodiff( + Reverse, pruned_quad_rev!, Active, + Duplicated(A_0, dA_0), Duplicated(A_1, dA_1), Duplicated(A_2, dA_2), + Duplicated(B, dB), Duplicated(C_0, dC_0), Duplicated(C_1, dC_1), + Duplicated(C_2, dC_2), Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# ============================================================================= +# Instantiate problems +# ============================================================================= + +const quad_us = make_quad_benchmark(; p_quad_small..., pruned = false) +const quad_ul = make_quad_benchmark(; p_quad_large..., pruned = false) +const quad_ps = make_quad_benchmark(; p_quad_small..., pruned = true) +const quad_pl = make_quad_benchmark(; p_quad_large..., pruned = true) + +# ============================================================================= +# Raw benchmarks (primal solve through public API) +# ============================================================================= + +function raw_quad!(prob, sol_out, cache) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + solve!(ws) + return nothing +end + +# Warmup +raw_quad!(quad_us.prob, quad_us.sol, quad_us.cache) +raw_quad!(quad_ul.prob, quad_ul.sol, quad_ul.cache) +raw_quad!(quad_ps.prob, quad_ps.sol, quad_ps.cache) +raw_quad!(quad_pl.prob, quad_pl.sol, quad_pl.cache) + +QUAD_ENZYME["unpruned"]["raw"]["small_mutable"] = @benchmarkable raw_quad!( + $(quad_us.prob), $(quad_us.sol), $(quad_us.cache) +) +QUAD_ENZYME["unpruned"]["raw"]["large_mutable"] = @benchmarkable raw_quad!( + $(quad_ul.prob), $(quad_ul.sol), $(quad_ul.cache) +) +QUAD_ENZYME["pruned"]["raw"]["small_mutable"] = @benchmarkable raw_quad!( + $(quad_ps.prob), $(quad_ps.sol), $(quad_ps.cache) +) +QUAD_ENZYME["pruned"]["raw"]["large_mutable"] = @benchmarkable raw_quad!( + $(quad_pl.prob), $(quad_pl.sol), $(quad_pl.cache) +) + +# ============================================================================= +# Forward mode AD — unpruned +# ============================================================================= + +# Warmup small +forward_quad!( + copy(quad_us.A_0), copy(quad_us.A_1), copy(quad_us.A_2), + copy(quad_us.B), copy(quad_us.C_0), copy(quad_us.C_1), copy(quad_us.C_2), + copy(quad_us.u0), [copy(n) for n in quad_us.noise], + quad_us.sol, quad_us.cache, + quad_us.dA_0, quad_us.dA_1, quad_us.dA_2, quad_us.dB, + quad_us.dC_0, quad_us.dC_1, quad_us.dC_2, quad_us.du0, quad_us.dnoise, + quad_us.dsol, quad_us.dcache +) + +QUAD_ENZYME["unpruned"]["forward"]["small_mutable"] = @benchmarkable forward_quad!( + $(copy(quad_us.A_0)), $(copy(quad_us.A_1)), $(copy(quad_us.A_2)), + $(copy(quad_us.B)), $(copy(quad_us.C_0)), $(copy(quad_us.C_1)), $(copy(quad_us.C_2)), + $(copy(quad_us.u0)), $([copy(n) for n in quad_us.noise]), + $(quad_us.sol), $(quad_us.cache), + $(quad_us.dA_0), $(quad_us.dA_1), $(quad_us.dA_2), $(quad_us.dB), + $(quad_us.dC_0), $(quad_us.dC_1), $(quad_us.dC_2), $(quad_us.du0), $(quad_us.dnoise), + $(quad_us.dsol), $(quad_us.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +forward_quad!( + copy(quad_ul.A_0), copy(quad_ul.A_1), copy(quad_ul.A_2), + copy(quad_ul.B), copy(quad_ul.C_0), copy(quad_ul.C_1), copy(quad_ul.C_2), + copy(quad_ul.u0), [copy(n) for n in quad_ul.noise], + quad_ul.sol, quad_ul.cache, + quad_ul.dA_0, quad_ul.dA_1, quad_ul.dA_2, quad_ul.dB, + quad_ul.dC_0, quad_ul.dC_1, quad_ul.dC_2, quad_ul.du0, quad_ul.dnoise, + quad_ul.dsol, quad_ul.dcache +) + +QUAD_ENZYME["unpruned"]["forward"]["large_mutable"] = @benchmarkable forward_quad!( + $(copy(quad_ul.A_0)), $(copy(quad_ul.A_1)), $(copy(quad_ul.A_2)), + $(copy(quad_ul.B)), $(copy(quad_ul.C_0)), $(copy(quad_ul.C_1)), $(copy(quad_ul.C_2)), + $(copy(quad_ul.u0)), $([copy(n) for n in quad_ul.noise]), + $(quad_ul.sol), $(quad_ul.cache), + $(quad_ul.dA_0), $(quad_ul.dA_1), $(quad_ul.dA_2), $(quad_ul.dB), + $(quad_ul.dC_0), $(quad_ul.dC_1), $(quad_ul.dC_2), $(quad_ul.du0), $(quad_ul.dnoise), + $(quad_ul.dsol), $(quad_ul.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# ============================================================================= +# Reverse mode AD — unpruned +# ============================================================================= + +# Warmup small +reverse_quad!( + copy(quad_us.A_0), copy(quad_us.A_1), copy(quad_us.A_2), + copy(quad_us.B), copy(quad_us.C_0), copy(quad_us.C_1), copy(quad_us.C_2), + copy(quad_us.u0), [copy(n) for n in quad_us.noise], + quad_us.sol, quad_us.cache, + quad_us.dA_0, quad_us.dA_1, quad_us.dA_2, quad_us.dB, + quad_us.dC_0, quad_us.dC_1, quad_us.dC_2, quad_us.du0, quad_us.dnoise, + quad_us.dsol, quad_us.dcache +) + +QUAD_ENZYME["unpruned"]["reverse"]["small_mutable"] = @benchmarkable reverse_quad!( + $(copy(quad_us.A_0)), $(copy(quad_us.A_1)), $(copy(quad_us.A_2)), + $(copy(quad_us.B)), $(copy(quad_us.C_0)), $(copy(quad_us.C_1)), $(copy(quad_us.C_2)), + $(copy(quad_us.u0)), $([copy(n) for n in quad_us.noise]), + $(quad_us.sol), $(quad_us.cache), + $(quad_us.dA_0), $(quad_us.dA_1), $(quad_us.dA_2), $(quad_us.dB), + $(quad_us.dC_0), $(quad_us.dC_1), $(quad_us.dC_2), $(quad_us.du0), $(quad_us.dnoise), + $(quad_us.dsol), $(quad_us.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +reverse_quad!( + copy(quad_ul.A_0), copy(quad_ul.A_1), copy(quad_ul.A_2), + copy(quad_ul.B), copy(quad_ul.C_0), copy(quad_ul.C_1), copy(quad_ul.C_2), + copy(quad_ul.u0), [copy(n) for n in quad_ul.noise], + quad_ul.sol, quad_ul.cache, + quad_ul.dA_0, quad_ul.dA_1, quad_ul.dA_2, quad_ul.dB, + quad_ul.dC_0, quad_ul.dC_1, quad_ul.dC_2, quad_ul.du0, quad_ul.dnoise, + quad_ul.dsol, quad_ul.dcache +) + +QUAD_ENZYME["unpruned"]["reverse"]["large_mutable"] = @benchmarkable reverse_quad!( + $(copy(quad_ul.A_0)), $(copy(quad_ul.A_1)), $(copy(quad_ul.A_2)), + $(copy(quad_ul.B)), $(copy(quad_ul.C_0)), $(copy(quad_ul.C_1)), $(copy(quad_ul.C_2)), + $(copy(quad_ul.u0)), $([copy(n) for n in quad_ul.noise]), + $(quad_ul.sol), $(quad_ul.cache), + $(quad_ul.dA_0), $(quad_ul.dA_1), $(quad_ul.dA_2), $(quad_ul.dB), + $(quad_ul.dC_0), $(quad_ul.dC_1), $(quad_ul.dC_2), $(quad_ul.du0), $(quad_ul.dnoise), + $(quad_ul.dsol), $(quad_ul.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# ============================================================================= +# Forward mode AD — pruned +# ============================================================================= + +# Warmup small +forward_pruned_quad!( + copy(quad_ps.A_0), copy(quad_ps.A_1), copy(quad_ps.A_2), + copy(quad_ps.B), copy(quad_ps.C_0), copy(quad_ps.C_1), copy(quad_ps.C_2), + copy(quad_ps.u0), [copy(n) for n in quad_ps.noise], + quad_ps.sol, quad_ps.cache, + quad_ps.dA_0, quad_ps.dA_1, quad_ps.dA_2, quad_ps.dB, + quad_ps.dC_0, quad_ps.dC_1, quad_ps.dC_2, quad_ps.du0, quad_ps.dnoise, + quad_ps.dsol, quad_ps.dcache +) + +QUAD_ENZYME["pruned"]["forward"]["small_mutable"] = @benchmarkable forward_pruned_quad!( + $(copy(quad_ps.A_0)), $(copy(quad_ps.A_1)), $(copy(quad_ps.A_2)), + $(copy(quad_ps.B)), $(copy(quad_ps.C_0)), $(copy(quad_ps.C_1)), $(copy(quad_ps.C_2)), + $(copy(quad_ps.u0)), $([copy(n) for n in quad_ps.noise]), + $(quad_ps.sol), $(quad_ps.cache), + $(quad_ps.dA_0), $(quad_ps.dA_1), $(quad_ps.dA_2), $(quad_ps.dB), + $(quad_ps.dC_0), $(quad_ps.dC_1), $(quad_ps.dC_2), $(quad_ps.du0), $(quad_ps.dnoise), + $(quad_ps.dsol), $(quad_ps.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +forward_pruned_quad!( + copy(quad_pl.A_0), copy(quad_pl.A_1), copy(quad_pl.A_2), + copy(quad_pl.B), copy(quad_pl.C_0), copy(quad_pl.C_1), copy(quad_pl.C_2), + copy(quad_pl.u0), [copy(n) for n in quad_pl.noise], + quad_pl.sol, quad_pl.cache, + quad_pl.dA_0, quad_pl.dA_1, quad_pl.dA_2, quad_pl.dB, + quad_pl.dC_0, quad_pl.dC_1, quad_pl.dC_2, quad_pl.du0, quad_pl.dnoise, + quad_pl.dsol, quad_pl.dcache +) + +QUAD_ENZYME["pruned"]["forward"]["large_mutable"] = @benchmarkable forward_pruned_quad!( + $(copy(quad_pl.A_0)), $(copy(quad_pl.A_1)), $(copy(quad_pl.A_2)), + $(copy(quad_pl.B)), $(copy(quad_pl.C_0)), $(copy(quad_pl.C_1)), $(copy(quad_pl.C_2)), + $(copy(quad_pl.u0)), $([copy(n) for n in quad_pl.noise]), + $(quad_pl.sol), $(quad_pl.cache), + $(quad_pl.dA_0), $(quad_pl.dA_1), $(quad_pl.dA_2), $(quad_pl.dB), + $(quad_pl.dC_0), $(quad_pl.dC_1), $(quad_pl.dC_2), $(quad_pl.du0), $(quad_pl.dnoise), + $(quad_pl.dsol), $(quad_pl.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# ============================================================================= +# Reverse mode AD — pruned +# ============================================================================= + +# Warmup small +reverse_pruned_quad!( + copy(quad_ps.A_0), copy(quad_ps.A_1), copy(quad_ps.A_2), + copy(quad_ps.B), copy(quad_ps.C_0), copy(quad_ps.C_1), copy(quad_ps.C_2), + copy(quad_ps.u0), [copy(n) for n in quad_ps.noise], + quad_ps.sol, quad_ps.cache, + quad_ps.dA_0, quad_ps.dA_1, quad_ps.dA_2, quad_ps.dB, + quad_ps.dC_0, quad_ps.dC_1, quad_ps.dC_2, quad_ps.du0, quad_ps.dnoise, + quad_ps.dsol, quad_ps.dcache +) + +QUAD_ENZYME["pruned"]["reverse"]["small_mutable"] = @benchmarkable reverse_pruned_quad!( + $(copy(quad_ps.A_0)), $(copy(quad_ps.A_1)), $(copy(quad_ps.A_2)), + $(copy(quad_ps.B)), $(copy(quad_ps.C_0)), $(copy(quad_ps.C_1)), $(copy(quad_ps.C_2)), + $(copy(quad_ps.u0)), $([copy(n) for n in quad_ps.noise]), + $(quad_ps.sol), $(quad_ps.cache), + $(quad_ps.dA_0), $(quad_ps.dA_1), $(quad_ps.dA_2), $(quad_ps.dB), + $(quad_ps.dC_0), $(quad_ps.dC_1), $(quad_ps.dC_2), $(quad_ps.du0), $(quad_ps.dnoise), + $(quad_ps.dsol), $(quad_ps.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# Warmup large +reverse_pruned_quad!( + copy(quad_pl.A_0), copy(quad_pl.A_1), copy(quad_pl.A_2), + copy(quad_pl.B), copy(quad_pl.C_0), copy(quad_pl.C_1), copy(quad_pl.C_2), + copy(quad_pl.u0), [copy(n) for n in quad_pl.noise], + quad_pl.sol, quad_pl.cache, + quad_pl.dA_0, quad_pl.dA_1, quad_pl.dA_2, quad_pl.dB, + quad_pl.dC_0, quad_pl.dC_1, quad_pl.dC_2, quad_pl.du0, quad_pl.dnoise, + quad_pl.dsol, quad_pl.dcache +) + +QUAD_ENZYME["pruned"]["reverse"]["large_mutable"] = @benchmarkable reverse_pruned_quad!( + $(copy(quad_pl.A_0)), $(copy(quad_pl.A_1)), $(copy(quad_pl.A_2)), + $(copy(quad_pl.B)), $(copy(quad_pl.C_0)), $(copy(quad_pl.C_1)), $(copy(quad_pl.C_2)), + $(copy(quad_pl.u0)), $([copy(n) for n in quad_pl.noise]), + $(quad_pl.sol), $(quad_pl.cache), + $(quad_pl.dA_0), $(quad_pl.dA_1), $(quad_pl.dA_2), $(quad_pl.dB), + $(quad_pl.dC_0), $(quad_pl.dC_1), $(quad_pl.dC_2), $(quad_pl.du0), $(quad_pl.dnoise), + $(quad_pl.dsol), $(quad_pl.dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +QUAD_ENZYME diff --git a/benchmark/forwarddiff_conditional_likelihood.jl b/benchmark/forwarddiff_conditional_likelihood.jl new file mode 100644 index 0000000..6e14313 --- /dev/null +++ b/benchmark/forwarddiff_conditional_likelihood.jl @@ -0,0 +1,140 @@ +# ForwardDiff AD benchmarks for ConditionalLikelihood +# Returns CL_FD BenchmarkGroup + +using ForwardDiff +using DifferenceEquations: init, solve!, StateSpaceWorkspace + +const CL_FD = BenchmarkGroup() +CL_FD["gradient"] = BenchmarkGroup() + +# ============================================================================= +# Type promotion helper +# ============================================================================= + +_fd_promote_cl(::Type{T}, x::AbstractArray{T}) where {T} = x +_fd_promote_cl(::Type{T}, x::AbstractArray) where {T} = T.(x) + +# ============================================================================= +# Problem sizes (same as enzyme_conditional_likelihood.jl) +# ============================================================================= + +# CL requires fully-observed state: M = N +const p_cl_fd_small = (; N = 5, M = 5, T = 10) +const p_cl_fd_large = (; N = 30, M = 30, T = 100) + +# ============================================================================= +# Problem setup +# ============================================================================= + +function make_cl_fd_benchmark(p; seed = 42) + (; N, M, T) = p + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + C = randn(M, N) + H = 0.1 * randn(M, M) + R = H * H' + + x = zeros(N) + y = Vector{Vector{Float64}}(undef, T) + for t in 1:T + x = A * x + 0.1 * randn(N) + y[t] = C * x + H * randn(M) + end + + return (; A, C, H, R, y) +end + +# ============================================================================= +# ForwardDiff wrapper — gradient of loglik w.r.t. vec(A) +# ============================================================================= + +function cl_loglik_fd_bench(A_vec, C, H, y, N) + T_el = eltype(A_vec) + A = reshape(A_vec, N, N) + R = _fd_promote_cl(T_el, H) * _fd_promote_cl(T_el, H)' + prob = LinearStateSpaceProblem( + A, nothing, + zeros(T_el, N), (0, length(y)); + C = _fd_promote_cl(T_el, C), + observables_noise = R, + observables = y, + ) + sol = solve(prob, ConditionalLikelihood()) + return sol.logpdf +end + +function fd_gradient_cl!(A_vec, C, H, y, N) + return ForwardDiff.gradient( + a -> cl_loglik_fd_bench(a, C, H, y, N), A_vec + ) +end + +# ============================================================================= +# Instantiate problems +# ============================================================================= + +const cl_fd_s = make_cl_fd_benchmark(p_cl_fd_small) +const cl_fd_l = make_cl_fd_benchmark(p_cl_fd_large) + +# ============================================================================= +# Warmup and benchmarks +# ============================================================================= + +fd_gradient_cl!( + vec(copy(cl_fd_s.A)), cl_fd_s.C, cl_fd_s.H, cl_fd_s.y, p_cl_fd_small.N +) +fd_gradient_cl!( + vec(copy(cl_fd_l.A)), cl_fd_l.C, cl_fd_l.H, cl_fd_l.y, p_cl_fd_large.N +) + +CL_FD["gradient"]["small_mutable"] = @benchmarkable fd_gradient_cl!( + $(vec(copy(cl_fd_s.A))), $(cl_fd_s.C), $(cl_fd_s.H), + $(cl_fd_s.y), $(p_cl_fd_small.N) +) + +CL_FD["gradient"]["large_mutable"] = @benchmarkable fd_gradient_cl!( + $(vec(copy(cl_fd_l.A))), $(cl_fd_l.C), $(cl_fd_l.H), + $(cl_fd_l.y), $(p_cl_fd_large.N) +) + +# ============================================================================= +# StaticArrays variant (small only) +# ============================================================================= + +# StaticArrays CL: no C matrix (identity observation, state = obs) +CL_FD["gradient"]["small_static"] = let + (; A, H, y) = cl_fd_s + N = p_cl_fd_small.N + + # For CL without C, observables must be state-dimensional + H_s = SMatrix{N, N}(0.1 * I(N)) + y_s = [SVector{N}(yi) for yi in y] + + function _cl_loglik_static( + A_vec, H_s, y_s, + ::Val{N_} + ) where {N_} + T_el = eltype(A_vec) + A_d = SMatrix{N_, N_}(reshape(A_vec, N_, N_)) + H_d = SMatrix{N_, N_}(T_el.(H_s)) + R_d = H_d * H_d' + u0_d = SVector{N_}(zeros(T_el, N_)) + prob = LinearStateSpaceProblem( + A_d, nothing, u0_d, (0, length(y_s)); + observables_noise = R_d, + observables = y_s, + ) + sol = solve(prob, ConditionalLikelihood()) + return sol.logpdf + end + + _cl_fd_static_grad(a) = _cl_loglik_static(a, H_s, y_s, Val(N)) + + A_vec = collect(vec(Matrix(A))) + ForwardDiff.gradient(_cl_fd_static_grad, A_vec) + + @benchmarkable ForwardDiff.gradient($_cl_fd_static_grad, $(copy(A_vec))) +end + +CL_FD diff --git a/benchmark/forwarddiff_kalman.jl b/benchmark/forwarddiff_kalman.jl new file mode 100644 index 0000000..4f782d8 --- /dev/null +++ b/benchmark/forwarddiff_kalman.jl @@ -0,0 +1,149 @@ +# ForwardDiff AD benchmarks for Kalman filter +# Returns KALMAN_FD BenchmarkGroup + +using ForwardDiff +using DifferenceEquations: init, solve!, StateSpaceWorkspace + +const KALMAN_FD = BenchmarkGroup() +KALMAN_FD["gradient"] = BenchmarkGroup() + +# ============================================================================= +# Type promotion helper +# ============================================================================= + +_fd_promote(::Type{T}, x::AbstractArray{T}) where {T} = x +_fd_promote(::Type{T}, x::AbstractArray) where {T} = T.(x) + +# ============================================================================= +# Problem sizes (same as enzyme_kalman.jl) +# ============================================================================= + +const p_kf_fd_small = (; N = 5, M = 2, K = 2, L = 2, T = 10) +const p_kf_fd_large = (; N = 30, M = 10, K = 10, L = 10, T = 100) + +# ============================================================================= +# Problem setup +# ============================================================================= + +function make_kalman_fd_benchmark(p; seed = 42) + (; N, M, K, L, T) = p + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + B = 0.1 * randn(N, K) + C = randn(M, N) + H = 0.1 * randn(M, L) + R = H * H' + mu_0 = zeros(N) + Sigma_0 = Matrix{Float64}(I, N, N) + + x0 = randn(N) + noise = [randn(K) for _ in 1:T] + sim = solve(LinearStateSpaceProblem(A, B, x0, (0, T); C, noise)) + y = [sim.z[t + 1] + H * randn(L) for t in 1:T] + + return (; A, B, C, R, mu_0, Sigma_0, y) +end + +# ============================================================================= +# ForwardDiff wrapper — gradient of loglik w.r.t. vec(A) +# ============================================================================= + +function kalman_loglik_fd_bench(A_vec, B, C, mu_0, Sigma_0, R, y, N) + T_el = eltype(A_vec) + A = reshape(A_vec, N, N) + prob = LinearStateSpaceProblem( + A, _fd_promote(T_el, B), + zeros(T_el, N), (0, length(y)); + C = _fd_promote(T_el, C), + u0_prior_mean = _fd_promote(T_el, mu_0), + u0_prior_var = _fd_promote(T_el, Sigma_0), + observables_noise = _fd_promote(T_el, R), + observables = y + ) + sol = solve(prob, KalmanFilter()) + return sol.logpdf +end + +function fd_gradient_kalman!(A_vec, B, C, mu_0, Sigma_0, R, y, N) + return ForwardDiff.gradient( + a -> kalman_loglik_fd_bench(a, B, C, mu_0, Sigma_0, R, y, N), A_vec + ) +end + +# ============================================================================= +# Instantiate problems +# ============================================================================= + +const kf_fd_s = make_kalman_fd_benchmark(p_kf_fd_small) +const kf_fd_l = make_kalman_fd_benchmark(p_kf_fd_large) + +# ============================================================================= +# Warmup and benchmarks +# ============================================================================= + +# Warmup +fd_gradient_kalman!( + vec(copy(kf_fd_s.A)), kf_fd_s.B, kf_fd_s.C, + kf_fd_s.mu_0, kf_fd_s.Sigma_0, kf_fd_s.R, kf_fd_s.y, p_kf_fd_small.N +) +fd_gradient_kalman!( + vec(copy(kf_fd_l.A)), kf_fd_l.B, kf_fd_l.C, + kf_fd_l.mu_0, kf_fd_l.Sigma_0, kf_fd_l.R, kf_fd_l.y, p_kf_fd_large.N +) + +KALMAN_FD["gradient"]["small_mutable"] = @benchmarkable fd_gradient_kalman!( + $(vec(copy(kf_fd_s.A))), $(kf_fd_s.B), $(kf_fd_s.C), + $(kf_fd_s.mu_0), $(kf_fd_s.Sigma_0), $(kf_fd_s.R), $(kf_fd_s.y), + $(p_kf_fd_small.N) +) + +KALMAN_FD["gradient"]["large_mutable"] = @benchmarkable fd_gradient_kalman!( + $(vec(copy(kf_fd_l.A))), $(kf_fd_l.B), $(kf_fd_l.C), + $(kf_fd_l.mu_0), $(kf_fd_l.Sigma_0), $(kf_fd_l.R), $(kf_fd_l.y), + $(p_kf_fd_large.N) +) + +# ============================================================================= +# StaticArrays variant (small only — static types impractical for N=30) +# ============================================================================= + +KALMAN_FD["gradient"]["small_static"] = let + (; A, B, C, R, mu_0, Sigma_0, y) = kf_fd_s + N = p_kf_fd_small.N; M = p_kf_fd_small.M; K = p_kf_fd_small.K + + A_s = SMatrix{N, N}(A); B_s = SMatrix{N, K}(B); C_s = SMatrix{M, N}(C) + R_s = SMatrix{M, M}(R); mu_s = SVector{N}(mu_0); Sig_s = SMatrix{N, N}(Sigma_0) + y_s = [SVector{M}(yi) for yi in y] + + function _kf_loglik_static( + A_vec, B_s, C_s, mu_s, Sig_s, R_s, y_s, + ::Val{N_}, ::Val{M_}, ::Val{K_} + ) where {N_, M_, K_} + T_el = eltype(A_vec) + A_d = SMatrix{N_, N_}(reshape(A_vec, N_, N_)) + prob = LinearStateSpaceProblem( + A_d, SMatrix{N_, K_}(T_el.(B_s)), + SVector{N_}(zeros(T_el, N_)), (0, length(y_s)); + C = SMatrix{M_, N_}(T_el.(C_s)), + u0_prior_mean = SVector{N_}(T_el.(mu_s)), + u0_prior_var = SMatrix{N_, N_}(T_el.(Sig_s)), + observables_noise = SMatrix{M_, M_}(T_el.(R_s)), + observables = y_s + ) + sol = solve(prob, KalmanFilter()) + return sol.logpdf + end + + A_vec = collect(vec(Matrix(A))) + f = a -> _kf_loglik_static( + a, B_s, C_s, mu_s, Sig_s, R_s, y_s, + Val(N), Val(M), Val(K) + ) + # Warmup + ForwardDiff.gradient(f, A_vec) + + @benchmarkable ForwardDiff.gradient($f, $(copy(A_vec))) +end + +KALMAN_FD diff --git a/benchmark/forwarddiff_linear_likelihood.jl b/benchmark/forwarddiff_linear_likelihood.jl new file mode 100644 index 0000000..06f3f3a --- /dev/null +++ b/benchmark/forwarddiff_linear_likelihood.jl @@ -0,0 +1,145 @@ +# ForwardDiff AD benchmarks for DirectIteration (joint likelihood) +# Returns DI_FD BenchmarkGroup + +using ForwardDiff +using DifferenceEquations: init, solve!, StateSpaceWorkspace + +const DI_FD = BenchmarkGroup() +DI_FD["gradient"] = BenchmarkGroup() + +# ============================================================================= +# Type promotion helper +# ============================================================================= + +_fd_promote_di(::Type{T}, x::AbstractArray{T}) where {T} = x +_fd_promote_di(::Type{T}, x::AbstractArray) where {T} = T.(x) + +# ============================================================================= +# Problem sizes (same as enzyme_linear_likelihood.jl) +# ============================================================================= + +const p_di_fd_small = (; N = 5, M = 2, K = 2, L = 2, T = 10) +const p_di_fd_large = (; N = 30, M = 10, K = 10, L = 10, T = 100) + +# ============================================================================= +# Problem setup +# ============================================================================= + +function make_di_fd_benchmark(p; seed = 42) + (; N, M, K, L, T) = p + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + B = 0.1 * randn(N, K) + C = randn(M, N) + H = 0.1 * randn(M, L) + R = H * H' + u0 = zeros(N) + noise = [randn(K) for _ in 1:T] + + sim = solve(LinearStateSpaceProblem(A, B, u0, (0, T); C, noise)) + y = [sim.z[t + 1] + H * randn(L) for t in 1:T] + + return (; A, B, C, H, R, u0, noise, y) +end + +# ============================================================================= +# ForwardDiff wrapper — gradient of loglik w.r.t. vec(A) +# ============================================================================= + +function di_loglik_fd_bench(A_vec, B, C, u0, noise, y, H, N) + T_el = eltype(A_vec) + A = reshape(A_vec, N, N) + R = _fd_promote_di(T_el, H) * _fd_promote_di(T_el, H)' + prob = LinearStateSpaceProblem( + A, _fd_promote_di(T_el, B), + _fd_promote_di(T_el, u0), (0, length(y)); + C = _fd_promote_di(T_el, C), + observables_noise = R, + observables = y, noise = noise + ) + sol = solve(prob, DirectIteration()) + return sol.logpdf +end + +function fd_gradient_di!(A_vec, B, C, u0, noise, y, H, N) + return ForwardDiff.gradient( + a -> di_loglik_fd_bench(a, B, C, u0, noise, y, H, N), A_vec + ) +end + +# ============================================================================= +# Instantiate problems +# ============================================================================= + +const di_fd_s = make_di_fd_benchmark(p_di_fd_small) +const di_fd_l = make_di_fd_benchmark(p_di_fd_large) + +# ============================================================================= +# Warmup and benchmarks +# ============================================================================= + +fd_gradient_di!( + vec(copy(di_fd_s.A)), di_fd_s.B, di_fd_s.C, + di_fd_s.u0, di_fd_s.noise, di_fd_s.y, di_fd_s.H, p_di_fd_small.N +) +fd_gradient_di!( + vec(copy(di_fd_l.A)), di_fd_l.B, di_fd_l.C, + di_fd_l.u0, di_fd_l.noise, di_fd_l.y, di_fd_l.H, p_di_fd_large.N +) + +DI_FD["gradient"]["small_mutable"] = @benchmarkable fd_gradient_di!( + $(vec(copy(di_fd_s.A))), $(di_fd_s.B), $(di_fd_s.C), + $(di_fd_s.u0), $(di_fd_s.noise), $(di_fd_s.y), $(di_fd_s.H), + $(p_di_fd_small.N) +) + +DI_FD["gradient"]["large_mutable"] = @benchmarkable fd_gradient_di!( + $(vec(copy(di_fd_l.A))), $(di_fd_l.B), $(di_fd_l.C), + $(di_fd_l.u0), $(di_fd_l.noise), $(di_fd_l.y), $(di_fd_l.H), + $(p_di_fd_large.N) +) + +# ============================================================================= +# StaticArrays variant (small only) +# ============================================================================= + +DI_FD["gradient"]["small_static"] = let + (; A, B, C, H, u0, noise, y) = di_fd_s + N = p_di_fd_small.N; M = p_di_fd_small.M; K = p_di_fd_small.K; L = p_di_fd_small.L + + B_s = SMatrix{N, K}(B); C_s = SMatrix{M, N}(C); H_s = SMatrix{M, L}(H) + noise_s = [SVector{K}(n) for n in noise] + y_s = [SVector{M}(yi) for yi in y] + + function _di_loglik_static( + A_vec, B_s, C_s, H_s, noise_s, y_s, + ::Val{N_}, ::Val{M_}, ::Val{K_}, ::Val{L_} + ) where {N_, M_, K_, L_} + T_el = eltype(A_vec) + A_d = SMatrix{N_, N_}(reshape(A_vec, N_, N_)) + B_d = SMatrix{N_, K_}(T_el.(B_s)) + C_d = SMatrix{M_, N_}(T_el.(C_s)) + H_d = SMatrix{M_, L_}(T_el.(H_s)) + R_d = H_d * H_d' + u0_d = SVector{N_}(zeros(T_el, N_)) + prob = LinearStateSpaceProblem( + A_d, B_d, u0_d, (0, length(y_s)); + C = C_d, observables_noise = R_d, + observables = y_s, noise = noise_s + ) + sol = solve(prob, DirectIteration()) + return sol.logpdf + end + + A_vec = collect(vec(Matrix(A))) + f = a -> _di_loglik_static( + a, B_s, C_s, H_s, noise_s, y_s, + Val(N), Val(M), Val(K), Val(L) + ) + ForwardDiff.gradient(f, A_vec) + + @benchmarkable ForwardDiff.gradient($f, $(copy(A_vec))) +end + +DI_FD diff --git a/benchmark/forwarddiff_linear_simulation.jl b/benchmark/forwarddiff_linear_simulation.jl new file mode 100644 index 0000000..f00dd50 --- /dev/null +++ b/benchmark/forwarddiff_linear_simulation.jl @@ -0,0 +1,128 @@ +# ForwardDiff AD benchmarks for Linear DirectIteration simulation (no observations/likelihood) +# Returns SIM_FD BenchmarkGroup + +using ForwardDiff +using DifferenceEquations: init, solve!, StateSpaceWorkspace + +const SIM_FD = BenchmarkGroup() +SIM_FD["gradient"] = BenchmarkGroup() + +# ============================================================================= +# Type promotion helper +# ============================================================================= + +_fd_promote_sim(::Type{T}, x::AbstractArray{T}) where {T} = x +_fd_promote_sim(::Type{T}, x::AbstractArray) where {T} = T.(x) + +# ============================================================================= +# Problem sizes (same as enzyme_linear_simulation.jl) +# ============================================================================= + +const p_sim_fd_small = (; N = 5, M = 3, K = 2, T = 10) +const p_sim_fd_large = (; N = 30, M = 10, K = 10, T = 100) + +# ============================================================================= +# Problem setup +# ============================================================================= + +function make_sim_fd_benchmark(p; seed = 42) + (; N, M, K, T) = p + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + B = 0.1 * randn(N, K) + C = randn(M, N) + u0 = zeros(N) + noise = [randn(K) for _ in 1:T] + + return (; A, B, C, u0, noise) +end + +# ============================================================================= +# ForwardDiff wrapper — gradient of sum(u[end]) w.r.t. vec(A) +# ============================================================================= + +function sim_scalar_fd_bench(A_vec, B, C, u0, noise, N) + T_el = eltype(A_vec) + A = reshape(A_vec, N, N) + prob = LinearStateSpaceProblem( + A, _fd_promote_sim(T_el, B), + _fd_promote_sim(T_el, u0), (0, length(noise)); + C = _fd_promote_sim(T_el, C), noise = noise + ) + sol = solve(prob, DirectIteration()) + return sum(sol.u[end]) +end + +function fd_gradient_sim!(A_vec, B, C, u0, noise, N) + return ForwardDiff.gradient( + a -> sim_scalar_fd_bench(a, B, C, u0, noise, N), A_vec + ) +end + +# ============================================================================= +# Instantiate problems +# ============================================================================= + +const sim_fd_s = make_sim_fd_benchmark(p_sim_fd_small) +const sim_fd_l = make_sim_fd_benchmark(p_sim_fd_large) + +# ============================================================================= +# Warmup and benchmarks +# ============================================================================= + +fd_gradient_sim!( + vec(copy(sim_fd_s.A)), sim_fd_s.B, sim_fd_s.C, + sim_fd_s.u0, sim_fd_s.noise, p_sim_fd_small.N +) +fd_gradient_sim!( + vec(copy(sim_fd_l.A)), sim_fd_l.B, sim_fd_l.C, + sim_fd_l.u0, sim_fd_l.noise, p_sim_fd_large.N +) + +SIM_FD["gradient"]["small_mutable"] = @benchmarkable fd_gradient_sim!( + $(vec(copy(sim_fd_s.A))), $(sim_fd_s.B), $(sim_fd_s.C), + $(sim_fd_s.u0), $(sim_fd_s.noise), $(p_sim_fd_small.N) +) + +SIM_FD["gradient"]["large_mutable"] = @benchmarkable fd_gradient_sim!( + $(vec(copy(sim_fd_l.A))), $(sim_fd_l.B), $(sim_fd_l.C), + $(sim_fd_l.u0), $(sim_fd_l.noise), $(p_sim_fd_large.N) +) + +# ============================================================================= +# StaticArrays variant (small only) +# ============================================================================= + +SIM_FD["gradient"]["small_static"] = let + (; A, B, C, u0, noise) = sim_fd_s + N = p_sim_fd_small.N; M = p_sim_fd_small.M; K = p_sim_fd_small.K + + B_s = SMatrix{N, K}(B); C_s = SMatrix{M, N}(C) + noise_s = [SVector{K}(n) for n in noise] + + function _sim_scalar_static( + A_vec, B_s, C_s, noise_s, + ::Val{N_}, ::Val{M_}, ::Val{K_} + ) where {N_, M_, K_} + T_el = eltype(A_vec) + A_d = SMatrix{N_, N_}(reshape(A_vec, N_, N_)) + B_d = SMatrix{N_, K_}(T_el.(B_s)) + C_d = SMatrix{M_, N_}(T_el.(C_s)) + u0_d = SVector{N_}(zeros(T_el, N_)) + prob = LinearStateSpaceProblem( + A_d, B_d, u0_d, (0, length(noise_s)); + C = C_d, noise = noise_s + ) + sol = solve(prob, DirectIteration()) + return sum(sol.u[end]) + end + + A_vec = collect(vec(Matrix(A))) + f = a -> _sim_scalar_static(a, B_s, C_s, noise_s, Val(N), Val(M), Val(K)) + ForwardDiff.gradient(f, A_vec) + + @benchmarkable ForwardDiff.gradient($f, $(copy(A_vec))) +end + +SIM_FD diff --git a/benchmark/gradient_comparison.jl b/benchmark/gradient_comparison.jl new file mode 100644 index 0000000..21fe88d --- /dev/null +++ b/benchmark/gradient_comparison.jl @@ -0,0 +1,508 @@ +# Apples-to-apples gradient benchmark: ForwardDiff vs Enzyme BatchDuplicated vs Enzyme Reverse +# All methods compute the SAME quantity: full gradient of loglik w.r.t. vec(A) (N² components). +# +# Returns GRAD_CMP BenchmarkGroup + +using ForwardDiff +using Enzyme: make_zero, make_zero!, BatchDuplicated +using DifferenceEquations: init, solve!, StateSpaceWorkspace, fill_zero!! + +const GRAD_CMP = BenchmarkGroup() +GRAD_CMP["kalman"] = BenchmarkGroup() +GRAD_CMP["di_likelihood"] = BenchmarkGroup() + +# ============================================================================= +# Type promotion helper (ForwardDiff path) +# ============================================================================= + +_gc_promote_bench(::Type{T}, x::AbstractArray{T}) where {T} = x +_gc_promote_bench(::Type{T}, x::AbstractArray) where {T} = T.(x) + +# ============================================================================= +# Problem sizes +# ============================================================================= + +const p_gc_small = (; N = 5, M = 2, K = 2, L = 2, T = 10) +const p_gc_large = (; N = 30, M = 10, K = 10, L = 10, T = 100) +const BATCH_SIZE = 10 # chunk size for both ForwardDiff and Enzyme BatchDuplicated + +# ============================================================================= +# Kalman setup +# ============================================================================= + +function make_gc_kalman(p; seed = 42) + (; N, M, K, L, T) = p + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + B = 0.1 * randn(N, K) + C = randn(M, N) + H = 0.1 * randn(M, L) + R = H * H' + mu_0 = zeros(N) + Sigma_0 = Matrix{Float64}(I, N, N) + + x0 = randn(N) + noise = [randn(K) for _ in 1:T] + sim = solve(LinearStateSpaceProblem(A, B, x0, (0, T); C, noise)) + y = [sim.z[t + 1] + H * randn(L) for t in 1:T] + + # Enzyme workspace + prob = LinearStateSpaceProblem( + A, B, zeros(N), (0, T); C, + u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) + ws = init(prob, KalmanFilter()) + + # Reverse shadows (single copy) + rv_dA = make_zero(A); rv_dB = make_zero(B); rv_dC = make_zero(C) + rv_dmu0 = make_zero(mu_0); rv_dSig0 = make_zero(Sigma_0); rv_dR = make_zero(R) + rv_dy = [make_zero(y[1]) for _ in 1:T] + rv_dsol = make_zero(ws.output); rv_dcache = make_zero(ws.cache) + + return (; + A, B, C, R, mu_0, Sigma_0, y, + sol_out = ws.output, cache = ws.cache, + rv_dA, rv_dB, rv_dC, rv_dmu0, rv_dSig0, rv_dR, rv_dy, rv_dsol, rv_dcache, + ) +end + +# ============================================================================= +# Kalman wrapper functions +# ============================================================================= + +# Enzyme inner function (shared by forward & reverse) +function _kf_loglik_gc!(A, B, C, mu_0, Sigma_0, R, y, sol_out, cache) + prob = LinearStateSpaceProblem( + A, B, zeros(eltype(A), size(A, 1)), (0, length(y)); C, + u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol_out, cache) + return solve!(ws).logpdf +end + +# ForwardDiff wrapper +function _kf_loglik_fd_gc(A_vec, B, C, mu_0, Sigma_0, R, y, N) + T_el = eltype(A_vec) + A = reshape(A_vec, N, N) + prob = LinearStateSpaceProblem( + A, _gc_promote_bench(T_el, B), + zeros(T_el, N), (0, length(y)); + C = _gc_promote_bench(T_el, C), + u0_prior_mean = _gc_promote_bench(T_el, mu_0), + u0_prior_var = _gc_promote_bench(T_el, Sigma_0), + observables_noise = _gc_promote_bench(T_el, R), + observables = y + ) + sol = solve(prob, KalmanFilter()) + return sol.logpdf +end + +function bench_forwarddiff_kf!(A_vec, B, C, mu_0, Sigma_0, R, y, N) + return ForwardDiff.gradient( + a -> _kf_loglik_fd_gc(a, B, C, mu_0, Sigma_0, R, y, N), A_vec + ) +end + +# Enzyme BatchDuplicated forward — full gradient +function bench_enzyme_batched_fwd_kf!( + grad_out, A, B, C, mu_0, Sigma_0, R, y, + sol_out, cache, + dAs, dBs, dCs, dmu0s, dSig0s, dRs, dys, dsols, dcaches + ) + chunk_size = length(dAs) + N_params = length(vec(A)) + for chunk_start in 1:chunk_size:N_params + chunk_end = min(chunk_start + chunk_size - 1, N_params) + actual = chunk_end - chunk_start + 1 + + for k in 1:chunk_size + fill_zero!!(dAs[k]); fill_zero!!(dBs[k]); fill_zero!!(dCs[k]) + fill_zero!!(dmu0s[k]); fill_zero!!(dSig0s[k]); fill_zero!!(dRs[k]) + for t in eachindex(dys[k]) + dys[k][t] = fill_zero!!(dys[k][t]) + end + make_zero!(dsols[k]); make_zero!(dcaches[k]) + end + for k in 1:actual + dAs[k][chunk_start + k - 1] = 1.0 + end + + result = autodiff( + Forward, _kf_loglik_gc!, + BatchDuplicated(A, dAs), + BatchDuplicated(B, dBs), + BatchDuplicated(C, dCs), + BatchDuplicated(mu_0, dmu0s), + BatchDuplicated(Sigma_0, dSig0s), + BatchDuplicated(R, dRs), + BatchDuplicated(y, dys), + BatchDuplicated(sol_out, dsols), + BatchDuplicated(cache, dcaches) + ) + + derivs = values(result[1]) + for k in 1:actual + grad_out[chunk_start + k - 1] = derivs[k] + end + end + return grad_out +end + +# Enzyme Reverse — full gradient, extract dA +function bench_enzyme_reverse_kf!( + A, B, C, mu_0, Sigma_0, R, y, + sol_out, cache, dA, dB, dC, dmu_0, dSigma_0, dR, dy, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + fill_zero!!(dA); fill_zero!!(dB); fill_zero!!(dC) + fill_zero!!(dmu_0); fill_zero!!(dSigma_0); fill_zero!!(dR) + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + + autodiff( + Reverse, _kf_loglik_gc!, Active, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(mu_0, dmu_0), Duplicated(Sigma_0, dSigma_0), + Duplicated(R, dR), Duplicated(y, dy), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return vec(dA) +end + +# ============================================================================= +# Kalman benchmarks +# ============================================================================= + +const gc_kf_s = make_gc_kalman(p_gc_small) +const gc_kf_l = make_gc_kalman(p_gc_large) + +# Warmup +bench_forwarddiff_kf!( + vec(copy(gc_kf_s.A)), gc_kf_s.B, gc_kf_s.C, + gc_kf_s.mu_0, gc_kf_s.Sigma_0, gc_kf_s.R, gc_kf_s.y, p_gc_small.N +) +# BatchDuplicated forward is always slower than ForwardDiff for this codebase: +# the shadow-copy overhead for all arguments (sol, cache, etc.) dominates. +# Kept for reference but not benchmarked. +# bench_enzyme_batched_fwd_kf!(zeros(p_gc_small.N^2), +# gc_kf_s.A, gc_kf_s.B, gc_kf_s.C, gc_kf_s.mu_0, gc_kf_s.Sigma_0, gc_kf_s.R, gc_kf_s.y, +# gc_kf_s.sol_out, gc_kf_s.cache, +# gc_kf_s.bd_dAs, gc_kf_s.bd_dBs, gc_kf_s.bd_dCs, gc_kf_s.bd_dmu0s, gc_kf_s.bd_dSig0s, +# gc_kf_s.bd_dRs, gc_kf_s.bd_dys, gc_kf_s.bd_dsols, gc_kf_s.bd_dcaches) +bench_enzyme_reverse_kf!( + gc_kf_s.A, gc_kf_s.B, gc_kf_s.C, + gc_kf_s.mu_0, gc_kf_s.Sigma_0, gc_kf_s.R, gc_kf_s.y, + gc_kf_s.sol_out, gc_kf_s.cache, + gc_kf_s.rv_dA, gc_kf_s.rv_dB, gc_kf_s.rv_dC, gc_kf_s.rv_dmu0, gc_kf_s.rv_dSig0, + gc_kf_s.rv_dR, gc_kf_s.rv_dy, gc_kf_s.rv_dsol, gc_kf_s.rv_dcache +) + +bench_forwarddiff_kf!( + vec(copy(gc_kf_l.A)), gc_kf_l.B, gc_kf_l.C, + gc_kf_l.mu_0, gc_kf_l.Sigma_0, gc_kf_l.R, gc_kf_l.y, p_gc_large.N +) +# bench_enzyme_batched_fwd_kf!(zeros(p_gc_large.N^2), +# gc_kf_l.A, gc_kf_l.B, gc_kf_l.C, gc_kf_l.mu_0, gc_kf_l.Sigma_0, gc_kf_l.R, gc_kf_l.y, +# gc_kf_l.sol_out, gc_kf_l.cache, +# gc_kf_l.bd_dAs, gc_kf_l.bd_dBs, gc_kf_l.bd_dCs, gc_kf_l.bd_dmu0s, gc_kf_l.bd_dSig0s, +# gc_kf_l.bd_dRs, gc_kf_l.bd_dys, gc_kf_l.bd_dsols, gc_kf_l.bd_dcaches) +bench_enzyme_reverse_kf!( + gc_kf_l.A, gc_kf_l.B, gc_kf_l.C, + gc_kf_l.mu_0, gc_kf_l.Sigma_0, gc_kf_l.R, gc_kf_l.y, + gc_kf_l.sol_out, gc_kf_l.cache, + gc_kf_l.rv_dA, gc_kf_l.rv_dB, gc_kf_l.rv_dC, gc_kf_l.rv_dmu0, gc_kf_l.rv_dSig0, + gc_kf_l.rv_dR, gc_kf_l.rv_dy, gc_kf_l.rv_dsol, gc_kf_l.rv_dcache +) + +# --- Kalman ForwardDiff --- +GRAD_CMP["kalman"]["forwarddiff_small"] = @benchmarkable bench_forwarddiff_kf!( + $(vec(copy(gc_kf_s.A))), $(gc_kf_s.B), $(gc_kf_s.C), + $(gc_kf_s.mu_0), $(gc_kf_s.Sigma_0), $(gc_kf_s.R), $(gc_kf_s.y), $(p_gc_small.N) +) + +GRAD_CMP["kalman"]["forwarddiff_large"] = @benchmarkable bench_forwarddiff_kf!( + $(vec(copy(gc_kf_l.A))), $(gc_kf_l.B), $(gc_kf_l.C), + $(gc_kf_l.mu_0), $(gc_kf_l.Sigma_0), $(gc_kf_l.R), $(gc_kf_l.y), $(p_gc_large.N) +) + +# --- Kalman Enzyme BatchDuplicated Forward (commented out — always slower than ForwardDiff) --- +# GRAD_CMP["kalman"]["enzyme_batched_fwd_small"] = @benchmarkable bench_enzyme_batched_fwd_kf!( +# $(zeros(p_gc_small.N^2)), +# $(gc_kf_s.A), $(gc_kf_s.B), $(gc_kf_s.C), $(gc_kf_s.mu_0), $(gc_kf_s.Sigma_0), +# $(gc_kf_s.R), $(gc_kf_s.y), $(gc_kf_s.sol_out), $(gc_kf_s.cache), +# $(gc_kf_s.bd_dAs), $(gc_kf_s.bd_dBs), $(gc_kf_s.bd_dCs), $(gc_kf_s.bd_dmu0s), +# $(gc_kf_s.bd_dSig0s), $(gc_kf_s.bd_dRs), $(gc_kf_s.bd_dys), +# $(gc_kf_s.bd_dsols), $(gc_kf_s.bd_dcaches)) +# +# GRAD_CMP["kalman"]["enzyme_batched_fwd_large"] = @benchmarkable bench_enzyme_batched_fwd_kf!( +# $(zeros(p_gc_large.N^2)), +# $(gc_kf_l.A), $(gc_kf_l.B), $(gc_kf_l.C), $(gc_kf_l.mu_0), $(gc_kf_l.Sigma_0), +# $(gc_kf_l.R), $(gc_kf_l.y), $(gc_kf_l.sol_out), $(gc_kf_l.cache), +# $(gc_kf_l.bd_dAs), $(gc_kf_l.bd_dBs), $(gc_kf_l.bd_dCs), $(gc_kf_l.bd_dmu0s), +# $(gc_kf_l.bd_dSig0s), $(gc_kf_l.bd_dRs), $(gc_kf_l.bd_dys), +# $(gc_kf_l.bd_dsols), $(gc_kf_l.bd_dcaches)) + +# --- Kalman Enzyme Reverse --- +GRAD_CMP["kalman"]["enzyme_reverse_small"] = @benchmarkable bench_enzyme_reverse_kf!( + $(gc_kf_s.A), $(gc_kf_s.B), $(gc_kf_s.C), + $(gc_kf_s.mu_0), $(gc_kf_s.Sigma_0), $(gc_kf_s.R), $(gc_kf_s.y), + $(gc_kf_s.sol_out), $(gc_kf_s.cache), + $(gc_kf_s.rv_dA), $(gc_kf_s.rv_dB), $(gc_kf_s.rv_dC), $(gc_kf_s.rv_dmu0), + $(gc_kf_s.rv_dSig0), $(gc_kf_s.rv_dR), $(gc_kf_s.rv_dy), + $(gc_kf_s.rv_dsol), $(gc_kf_s.rv_dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +GRAD_CMP["kalman"]["enzyme_reverse_large"] = @benchmarkable bench_enzyme_reverse_kf!( + $(gc_kf_l.A), $(gc_kf_l.B), $(gc_kf_l.C), + $(gc_kf_l.mu_0), $(gc_kf_l.Sigma_0), $(gc_kf_l.R), $(gc_kf_l.y), + $(gc_kf_l.sol_out), $(gc_kf_l.cache), + $(gc_kf_l.rv_dA), $(gc_kf_l.rv_dB), $(gc_kf_l.rv_dC), $(gc_kf_l.rv_dmu0), + $(gc_kf_l.rv_dSig0), $(gc_kf_l.rv_dR), $(gc_kf_l.rv_dy), + $(gc_kf_l.rv_dsol), $(gc_kf_l.rv_dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# ============================================================================= +# DirectIteration likelihood setup +# ============================================================================= + +function make_gc_di(p; seed = 42) + (; N, M, K, L, T) = p + Random.seed!(seed) + A_raw = randn(N, N) + A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + B = 0.1 * randn(N, K) + C = randn(M, N) + H = 0.1 * randn(M, L) + R = H * H' + u0 = zeros(N) + noise = [randn(K) for _ in 1:T] + + sim = solve(LinearStateSpaceProblem(A, B, u0, (0, T); C, noise)) + y = [sim.z[t + 1] + H * randn(L) for t in 1:T] + + prob = LinearStateSpaceProblem( + A, B, u0, (0, T); C, + observables_noise = R, observables = y, noise + ) + ws = init(prob, DirectIteration()) + + rv_dA = make_zero(A); rv_dB = make_zero(B); rv_dC = make_zero(C) + rv_du0 = make_zero(u0); rv_dH = make_zero(H) + rv_dnoise = [make_zero(noise[1]) for _ in 1:T] + rv_dy = [make_zero(y[1]) for _ in 1:T] + rv_dsol = make_zero(ws.output); rv_dcache = make_zero(ws.cache) + + return (; + A, B, C, H, R, u0, noise, y, + sol_out = ws.output, cache = ws.cache, + rv_dA, rv_dB, rv_dC, rv_du0, rv_dH, rv_dnoise, rv_dy, rv_dsol, rv_dcache, + ) +end + +# ============================================================================= +# DI wrapper functions +# ============================================================================= + +function _di_loglik_gc!(A, B, C, u0, noise, y, H, sol_out, cache) + R = H * H' + prob = LinearStateSpaceProblem( + A, B, u0, (0, length(y)); + C, observables_noise = R, observables = y, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + return solve!(ws).logpdf +end + +function _di_loglik_fd_gc(A_vec, B, C, u0, noise, y, H, N) + T_el = eltype(A_vec) + A = reshape(A_vec, N, N) + H_d = _gc_promote_bench(T_el, H) + R = H_d * H_d' + prob = LinearStateSpaceProblem( + A, _gc_promote_bench(T_el, B), + _gc_promote_bench(T_el, u0), (0, length(y)); + C = _gc_promote_bench(T_el, C), + observables_noise = R, + observables = y, noise = noise + ) + sol = solve(prob, DirectIteration()) + return sol.logpdf +end + +function bench_forwarddiff_di!(A_vec, B, C, u0, noise, y, H, N) + return ForwardDiff.gradient( + a -> _di_loglik_fd_gc(a, B, C, u0, noise, y, H, N), A_vec + ) +end + +function bench_enzyme_batched_fwd_di!( + grad_out, A, B, C, u0, noise, y, H, + sol_out, cache, + dAs, dBs, dCs, du0s, dnoises, dys, dHs, dsols, dcaches + ) + chunk_size = length(dAs) + N_params = length(vec(A)) + for chunk_start in 1:chunk_size:N_params + chunk_end = min(chunk_start + chunk_size - 1, N_params) + actual = chunk_end - chunk_start + 1 + + for k in 1:chunk_size + fill_zero!!(dAs[k]); fill_zero!!(dBs[k]); fill_zero!!(dCs[k]) + fill_zero!!(du0s[k]); fill_zero!!(dHs[k]) + for t in eachindex(dnoises[k]) + dnoises[k][t] = fill_zero!!(dnoises[k][t]) + end + for t in eachindex(dys[k]) + dys[k][t] = fill_zero!!(dys[k][t]) + end + make_zero!(dsols[k]); make_zero!(dcaches[k]) + end + for k in 1:actual + dAs[k][chunk_start + k - 1] = 1.0 + end + + result = autodiff( + Forward, _di_loglik_gc!, + BatchDuplicated(A, dAs), + BatchDuplicated(B, dBs), + BatchDuplicated(C, dCs), + BatchDuplicated(u0, du0s), + BatchDuplicated(noise, dnoises), + BatchDuplicated(y, dys), + BatchDuplicated(H, dHs), + BatchDuplicated(sol_out, dsols), + BatchDuplicated(cache, dcaches) + ) + + derivs = values(result[1]) + for k in 1:actual + grad_out[chunk_start + k - 1] = derivs[k] + end + end + return grad_out +end + +function bench_enzyme_reverse_di!( + A, B, C, u0, noise, y, H, + sol_out, cache, dA, dB, dC, du0, dnoise, dy, dH, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + fill_zero!!(dA); fill_zero!!(dB); fill_zero!!(dC) + fill_zero!!(du0); fill_zero!!(dH) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + + autodiff( + Reverse, _di_loglik_gc!, Active, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(y, dy), Duplicated(H, dH), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return vec(dA) +end + +# ============================================================================= +# DI benchmarks +# ============================================================================= + +const gc_di_s = make_gc_di(p_gc_small) +const gc_di_l = make_gc_di(p_gc_large) + +# Warmup +bench_forwarddiff_di!( + vec(copy(gc_di_s.A)), gc_di_s.B, gc_di_s.C, + gc_di_s.u0, gc_di_s.noise, gc_di_s.y, gc_di_s.H, p_gc_small.N +) +# bench_enzyme_batched_fwd_di!(zeros(p_gc_small.N^2), +# gc_di_s.A, gc_di_s.B, gc_di_s.C, gc_di_s.u0, gc_di_s.noise, gc_di_s.y, gc_di_s.H, +# gc_di_s.sol_out, gc_di_s.cache, +# gc_di_s.bd_dAs, gc_di_s.bd_dBs, gc_di_s.bd_dCs, gc_di_s.bd_du0s, +# gc_di_s.bd_dnoises, gc_di_s.bd_dys, gc_di_s.bd_dHs, +# gc_di_s.bd_dsols, gc_di_s.bd_dcaches) +bench_enzyme_reverse_di!( + gc_di_s.A, gc_di_s.B, gc_di_s.C, + gc_di_s.u0, gc_di_s.noise, gc_di_s.y, gc_di_s.H, + gc_di_s.sol_out, gc_di_s.cache, + gc_di_s.rv_dA, gc_di_s.rv_dB, gc_di_s.rv_dC, gc_di_s.rv_du0, + gc_di_s.rv_dnoise, gc_di_s.rv_dy, gc_di_s.rv_dH, + gc_di_s.rv_dsol, gc_di_s.rv_dcache +) + +bench_forwarddiff_di!( + vec(copy(gc_di_l.A)), gc_di_l.B, gc_di_l.C, + gc_di_l.u0, gc_di_l.noise, gc_di_l.y, gc_di_l.H, p_gc_large.N +) +# bench_enzyme_batched_fwd_di!(zeros(p_gc_large.N^2), +# gc_di_l.A, gc_di_l.B, gc_di_l.C, gc_di_l.u0, gc_di_l.noise, gc_di_l.y, gc_di_l.H, +# gc_di_l.sol_out, gc_di_l.cache, +# gc_di_l.bd_dAs, gc_di_l.bd_dBs, gc_di_l.bd_dCs, gc_di_l.bd_du0s, +# gc_di_l.bd_dnoises, gc_di_l.bd_dys, gc_di_l.bd_dHs, +# gc_di_l.bd_dsols, gc_di_l.bd_dcaches) +bench_enzyme_reverse_di!( + gc_di_l.A, gc_di_l.B, gc_di_l.C, + gc_di_l.u0, gc_di_l.noise, gc_di_l.y, gc_di_l.H, + gc_di_l.sol_out, gc_di_l.cache, + gc_di_l.rv_dA, gc_di_l.rv_dB, gc_di_l.rv_dC, gc_di_l.rv_du0, + gc_di_l.rv_dnoise, gc_di_l.rv_dy, gc_di_l.rv_dH, + gc_di_l.rv_dsol, gc_di_l.rv_dcache +) + +# --- DI ForwardDiff --- +GRAD_CMP["di_likelihood"]["forwarddiff_small"] = @benchmarkable bench_forwarddiff_di!( + $(vec(copy(gc_di_s.A))), $(gc_di_s.B), $(gc_di_s.C), + $(gc_di_s.u0), $(gc_di_s.noise), $(gc_di_s.y), $(gc_di_s.H), $(p_gc_small.N) +) + +GRAD_CMP["di_likelihood"]["forwarddiff_large"] = @benchmarkable bench_forwarddiff_di!( + $(vec(copy(gc_di_l.A))), $(gc_di_l.B), $(gc_di_l.C), + $(gc_di_l.u0), $(gc_di_l.noise), $(gc_di_l.y), $(gc_di_l.H), $(p_gc_large.N) +) + +# --- DI Enzyme BatchDuplicated Forward (commented out — always slower than ForwardDiff) --- +# GRAD_CMP["di_likelihood"]["enzyme_batched_fwd_small"] = @benchmarkable bench_enzyme_batched_fwd_di!( +# $(zeros(p_gc_small.N^2)), +# $(gc_di_s.A), $(gc_di_s.B), $(gc_di_s.C), $(gc_di_s.u0), +# $(gc_di_s.noise), $(gc_di_s.y), $(gc_di_s.H), +# $(gc_di_s.sol_out), $(gc_di_s.cache), +# $(gc_di_s.bd_dAs), $(gc_di_s.bd_dBs), $(gc_di_s.bd_dCs), $(gc_di_s.bd_du0s), +# $(gc_di_s.bd_dnoises), $(gc_di_s.bd_dys), $(gc_di_s.bd_dHs), +# $(gc_di_s.bd_dsols), $(gc_di_s.bd_dcaches)) +# +# GRAD_CMP["di_likelihood"]["enzyme_batched_fwd_large"] = @benchmarkable bench_enzyme_batched_fwd_di!( +# $(zeros(p_gc_large.N^2)), +# $(gc_di_l.A), $(gc_di_l.B), $(gc_di_l.C), $(gc_di_l.u0), +# $(gc_di_l.noise), $(gc_di_l.y), $(gc_di_l.H), +# $(gc_di_l.sol_out), $(gc_di_l.cache), +# $(gc_di_l.bd_dAs), $(gc_di_l.bd_dBs), $(gc_di_l.bd_dCs), $(gc_di_l.bd_du0s), +# $(gc_di_l.bd_dnoises), $(gc_di_l.bd_dys), $(gc_di_l.bd_dHs), +# $(gc_di_l.bd_dsols), $(gc_di_l.bd_dcaches)) + +# --- DI Enzyme Reverse --- +GRAD_CMP["di_likelihood"]["enzyme_reverse_small"] = @benchmarkable bench_enzyme_reverse_di!( + $(gc_di_s.A), $(gc_di_s.B), $(gc_di_s.C), + $(gc_di_s.u0), $(gc_di_s.noise), $(gc_di_s.y), $(gc_di_s.H), + $(gc_di_s.sol_out), $(gc_di_s.cache), + $(gc_di_s.rv_dA), $(gc_di_s.rv_dB), $(gc_di_s.rv_dC), $(gc_di_s.rv_du0), + $(gc_di_s.rv_dnoise), $(gc_di_s.rv_dy), $(gc_di_s.rv_dH), + $(gc_di_s.rv_dsol), $(gc_di_s.rv_dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +GRAD_CMP["di_likelihood"]["enzyme_reverse_large"] = @benchmarkable bench_enzyme_reverse_di!( + $(gc_di_l.A), $(gc_di_l.B), $(gc_di_l.C), + $(gc_di_l.u0), $(gc_di_l.noise), $(gc_di_l.y), $(gc_di_l.H), + $(gc_di_l.sol_out), $(gc_di_l.cache), + $(gc_di_l.rv_dA), $(gc_di_l.rv_dB), $(gc_di_l.rv_dC), $(gc_di_l.rv_du0), + $(gc_di_l.rv_dnoise), $(gc_di_l.rv_dy), $(gc_di_l.rv_dH), + $(gc_di_l.rv_dsol), $(gc_di_l.rv_dcache) +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +GRAD_CMP diff --git a/benchmark/linear.jl b/benchmark/linear.jl deleted file mode 100644 index 9885070..0000000 --- a/benchmark/linear.jl +++ /dev/null @@ -1,330 +0,0 @@ -#Benchmarking of RBC and FVGQ variants -using DifferenceEquations, BenchmarkTools -using DelimitedFiles, Distributions, Zygote, LinearAlgebra - -# for benchmarking construction itself -function make_problem_1(A, B, C, u0, noise, observables, D; kwargs...) - prob = LinearStateSpaceProblem( - A, B, u0, (0, size(observables, 2)); C, - observables_noise = D, - noise, observables, kwargs... - ) - return prob.A[1, 1] + prob.B[1, 1] -end -function make_problem_kalman(A, B, C, u0_prior_var, observables, D; kwargs...) - prob = LinearStateSpaceProblem( - A, B, zeros(size(A, 1)), (0, size(observables, 2)); C, - u0_prior_var, u0_prior_mean = zeros(size(A, 1)), - observables_noise = D, noise = nothing, observables, - kwargs... - ) - return prob.A[1, 1] + prob.B[1, 1] -end - -# for benchmarking likelihoods -function joint_likelihood_1(A, B, C, u0, noise, observables, D; kwargs...) - prob = LinearStateSpaceProblem( - A, B, u0, (0, size(observables, 2)); C, - observables_noise = D, - noise, observables, kwargs... - ) - return solve(prob).logpdf -end -function kalman_likelihood(A, B, C, u0_prior_var, observables, D; kwargs...) - prob = LinearStateSpaceProblem( - A, B, zeros(size(A, 1)), (0, size(observables, 2)); C, - u0_prior_var, u0_prior_mean = zeros(size(A, 1)), - observables_noise = D, noise = nothing, observables, - kwargs... - ) - return solve(prob).logpdf -end - -function simulate_model_no_noise_1(A, B, C, u0, observables, D; kwargs...) - prob = LinearStateSpaceProblem( - A, B, u0, (0, size(observables, 2)); C, - observables_noise = D, - observables, kwargs... - ) - sol = solve(prob) - return sol.retcode -end - -function simulate_model_no_observations_1(A, B, C, u0, T; kwargs...) - prob = LinearStateSpaceProblem(A, B, u0, (0, T); C, kwargs...) - sol = solve(prob) - return sol.retcode -end - -# Matrices from RBC -const A_rbc = [ - 0.9568351489231076 6.209371005755285; - 3.0153731819288737e-18 0.20000000000000007 -] -const B_rbc = reshape([0.0; -0.01], 2, 1) # make sure B is a matrix -const C_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] -const D_rbc = abs2.([0.1, 0.1]) -const u0_rbc = zeros(2) -const u0_prior_var_rbc = diagm(ones(length(u0_rbc))) - -const observables_rbc = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/RBC_observables.csv" - ), ',' -)' |> collect -const noise_rbc = readdlm( - joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), - ',' -)' |> - collect -const T_rbc = size(observables_rbc, 2) -# Matrices from FVGQ -# Load FVGQ data for checks -const A_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A.csv"), ',') -const B_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_B.csv"), ',') -const C_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C.csv"), ',') -const D_FVGQ = abs2.(ones(6) * 1.0e-3) - -const observables_FVGQ = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/FVGQ20_observables.csv" - ), ',' -)' |> - collect - -const noise_FVGQ = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/FVGQ20_noise.csv" - ), - ',' -)' |> collect -const u0_FVGQ = zeros(size(A_FVGQ, 1)) -const u0_prior_var_FVGQ = diagm(ones(length(u0_FVGQ))) -const T_FVGQ = size(observables_FVGQ, 2) -# executing gradients once to avoid compilation time in benchmarking - -make_problem_1(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, D_rbc) -gradient( - (args...) -> make_problem_1(args..., observables_rbc, D_rbc), A_rbc, B_rbc, C_rbc, - u0_rbc, - noise_rbc -) - -make_problem_kalman(A_rbc, B_rbc, C_rbc, u0_prior_var_rbc, observables_rbc, D_rbc) -gradient( - (args...) -> make_problem_kalman(args..., observables_rbc, D_rbc), A_rbc, B_rbc, - C_rbc, - u0_prior_var_rbc -) - -kalman_likelihood(A_rbc, B_rbc, C_rbc, u0_prior_var_rbc, observables_rbc, D_rbc) -gradient( - (args...) -> kalman_likelihood(args..., observables_rbc, D_rbc), A_rbc, B_rbc, - C_rbc, - u0_prior_var_rbc -) -joint_likelihood_1(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, D_rbc) -gradient( - (args...) -> joint_likelihood_1(args..., observables_rbc, D_rbc), A_rbc, B_rbc, - C_rbc, - u0_rbc, noise_rbc -) - -kalman_likelihood(A_FVGQ, B_FVGQ, C_FVGQ, u0_prior_var_FVGQ, observables_FVGQ, D_FVGQ) -gradient( - (args...) -> kalman_likelihood(args..., observables_FVGQ, D_FVGQ), A_FVGQ, B_FVGQ, - C_FVGQ, - u0_prior_var_FVGQ -) -joint_likelihood_1(A_FVGQ, B_FVGQ, C_FVGQ, u0_FVGQ, noise_FVGQ, observables_FVGQ, D_FVGQ) -gradient( - (args...) -> joint_likelihood_1(args..., observables_FVGQ, D_FVGQ), A_FVGQ, B_FVGQ, - C_FVGQ, - u0_FVGQ, noise_FVGQ -) - -####### Benchmarks - -const LINEAR = BenchmarkGroup() - -const LINEAR["rbc"] = BenchmarkGroup() - -const LINEAR["rbc"]["make_problem_1"] = @benchmarkable make_problem_1( - $A_rbc, $B_rbc, - $C_rbc, - $u0_rbc, $noise_rbc, - $observables_rbc, - $D_rbc -) -const LINEAR["rbc"]["make_problem_1_gradient"] = @benchmarkable gradient( - (args...) -> make_problem_1( - args..., - $observables_rbc, - $D_rbc - ), - $A_rbc, $B_rbc, - $C_rbc, - $u0_rbc, - $noise_rbc -) - -const LINEAR["rbc"]["simulate_model_no_noise_1"] = @benchmarkable simulate_model_no_noise_1( - $A_rbc, - $B_rbc, - $C_rbc, - $u0_rbc, - $observables_rbc, - $D_rbc -) -const LINEAR["rbc"]["simulate_model_no_observations_1"] = @benchmarkable simulate_model_no_observations_1( - $A_rbc, - $B_rbc, - $C_rbc, - $u0_rbc, - T_rbc -) -const LINEAR["rbc"]["joint_1"] = @benchmarkable joint_likelihood_1( - $A_rbc, $B_rbc, $C_rbc, - $u0_rbc, - $noise_rbc, - $observables_rbc, - $D_rbc -) -const LINEAR["rbc"]["joint_1_gradient"] = @benchmarkable gradient( - (args...) -> joint_likelihood_1( - args..., - $observables_rbc, - $D_rbc - ), - $A_rbc, $B_rbc, $C_rbc, - $u0_rbc, - $noise_rbc -) -const LINEAR["rbc"]["make_problem_kalman"] = @benchmarkable make_problem_kalman( - $A_rbc, - $B_rbc, - $C_rbc, - $u0_prior_var_rbc, - $observables_rbc, - $D_rbc -) -const LINEAR["rbc"]["make_problem_kalman_gradient"] = @benchmarkable gradient( - (args...) -> make_problem_kalman( - args..., - $observables_rbc, - $D_rbc - ), - $A_rbc, - $B_rbc, - $C_rbc, - $u0_prior_var_rbc -) -const LINEAR["rbc"]["kalman"] = @benchmarkable kalman_likelihood( - $A_rbc, $B_rbc, $C_rbc, - $u0_prior_var_rbc, - $observables_rbc, - $D_rbc -) -const LINEAR["rbc"]["kalman_gradient"] = @benchmarkable gradient( - (args...) -> kalman_likelihood( - args..., - $observables_rbc, - $D_rbc - ), - $A_rbc, $B_rbc, $C_rbc, - $u0_prior_var_rbc -) - -# FVGQ -const LINEAR["FVGQ"] = BenchmarkGroup() -const LINEAR["FVGQ"]["make_problem_1"] = @benchmarkable make_problem_1( - $A_FVGQ, $B_FVGQ, - $C_FVGQ, - $u0_FVGQ, - $noise_FVGQ, - $observables_FVGQ, - $D_FVGQ -) -const LINEAR["FVGQ"]["make_problem_1_gradient"] = @benchmarkable gradient( - (args...) -> make_problem_1( - args..., - $observables_FVGQ, - $D_FVGQ - ), - $A_FVGQ, $B_FVGQ, - $C_FVGQ, - $u0_FVGQ, - $noise_FVGQ -) -const LINEAR["FVGQ"]["simulate_model_no_noise_1"] = @benchmarkable simulate_model_no_noise_1( - $A_FVGQ, - $B_FVGQ, - $C_FVGQ, - $u0_FVGQ, - $observables_FVGQ, - $D_FVGQ -) -const LINEAR["FVGQ"]["simulate_model_no_observations_1"] = @benchmarkable simulate_model_no_observations_1( - $A_FVGQ, - $B_FVGQ, - $C_FVGQ, - $u0_FVGQ, - $T_FVGQ -) -const LINEAR["FVGQ"]["joint_1"] = @benchmarkable joint_likelihood_1( - $A_FVGQ, $B_FVGQ, - $C_FVGQ, - $u0_FVGQ, $noise_FVGQ, - $observables_FVGQ, - $D_FVGQ -) -const LINEAR["FVGQ"]["joint_1_gradient"] = @benchmarkable gradient( - (args...) -> joint_likelihood_1( - args..., - $observables_FVGQ, - $D_FVGQ - ), - $A_FVGQ, $B_FVGQ, - $C_FVGQ, - $u0_FVGQ, $noise_FVGQ -) -const LINEAR["FVGQ"]["make_problem_kalman"] = @benchmarkable make_problem_kalman( - $A_FVGQ, - $B_FVGQ, - $C_FVGQ, - $u0_prior_var_FVGQ, - $observables_FVGQ, - $D_FVGQ -) -const LINEAR["FVGQ"]["make_problem_kalman_gradient"] = @benchmarkable gradient( - (args...) -> make_problem_kalman( - args..., - $observables_FVGQ, - $D_FVGQ - ), - $A_FVGQ, - $B_FVGQ, - $C_FVGQ, - $u0_prior_var_FVGQ -) -const LINEAR["FVGQ"]["kalman"] = @benchmarkable kalman_likelihood( - $A_FVGQ, $B_FVGQ, $C_FVGQ, - $u0_prior_var_FVGQ, - $observables_FVGQ, - $D_FVGQ -) -const LINEAR["FVGQ"]["kalman_gradient"] = @benchmarkable gradient( - (args...) -> kalman_likelihood( - args..., - $observables_FVGQ, - $D_FVGQ - ), - $A_FVGQ, $B_FVGQ, $C_FVGQ, - $u0_prior_var_FVGQ -) - -# return for the test suite -LINEAR diff --git a/benchmark/quadratic.jl b/benchmark/quadratic.jl deleted file mode 100644 index e443e62..0000000 --- a/benchmark/quadratic.jl +++ /dev/null @@ -1,332 +0,0 @@ -#Benchmarking of RBC and FVGQ variants -using DifferenceEquations, BenchmarkTools, LinearAlgebra -using DelimitedFiles, Distributions, Zygote -# General likelihood calculation -function make_problem_2( - A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, observables, D; - kwargs... - ) - prob = QuadraticStateSpaceProblem( - A_0, A_1, A_2, B, u0, (0, size(observables, 2)); C_0, - C_1, - C_2, observables_noise = D, - noise, - observables, kwargs... - ) - return prob.A_1[1, 1] + prob.B[1, 1] -end - -function joint_likelihood_2( - A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, observables, D; - kwargs... - ) - prob = QuadraticStateSpaceProblem( - A_0, A_1, A_2, B, u0, (0, size(observables, 2)); C_0, - C_1, - C_2, observables_noise = D, - noise, - observables, kwargs... - ) - return solve(prob).logpdf -end - -function simulate_model_no_noise_2( - A_0, A_1, A_2, B, C_0, C_1, C_2, u0, observables, D; - kwargs... - ) - prob = QuadraticStateSpaceProblem( - A_0, A_1, A_2, B, u0, (0, size(observables, 2)); C_0, - C_1, - C_2, observables_noise = D, - observables, kwargs... - ) - sol = solve(prob) - return sol.retcode -end - -function simulate_model_no_observations_2(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, T; kwargs...) - prob = QuadraticStateSpaceProblem( - A_0, A_1, A_2, B, u0, (0, T); C_0, C_1, C_2, - kwargs... - ) - sol = solve(prob) - return sol.retcode -end - -const QUADRATIC = BenchmarkGroup() - -# Matrices from RBC -const A_0_rbc = [-7.824904812740593e-5, 0.0] -const A_1_rbc = [ - 0.9568351489231076 6.209371005755285; - 3.0153731819288737e-18 0.20000000000000007 -] -const A_2_rbc = cat( - [-0.00019761505863889124 0.03375055315837927; 0.0 0.0], - [0.03375055315837913 3.128758481817603; 0.0 0.0]; dims = 3 -) -const B_2_rbc = reshape([0.0; -0.01], 2, 1) # make sure B is a matrix -const C_0_rbc = [7.824904812740593e-5, 0.0] -const C_1_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] -const C_2_rbc = cat( - [-0.00018554166974717046 0.0025652363153049716; 0.0 0.0], - [0.002565236315304951 0.3132705036896446; 0.0 0.0]; dims = 3 -) -const D_2_rbc = [0.1, 0.1] -const u0_2_rbc = zeros(2) - -const observables_2_rbc = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/RBC_observables.csv" - ), ',' -)' |> - collect -const noise_2_rbc = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/RBC_noise.csv" - ), - ',' -)' |> collect -const T_2_rbc = size(observables_2_rbc, 2) -# Matrices from FVGQ -# Load FVGQ data for checks -A_0_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A_0.csv"), ',') -const A_0_FVGQ = vec(A_0_raw) -const A_1_FVGQ = readdlm( - joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A_1.csv"), - ',' -) -A_2_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A_2.csv"), ',') -const A_2_FVGQ = reshape(A_2_raw, length(A_0_FVGQ), length(A_0_FVGQ), length(A_0_FVGQ)) -const B_2_FVGQ = readdlm( - joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_B.csv"), - ',' -) -C_0_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C_0.csv"), ',') -const C_0_FVGQ = vec(C_0_raw) -const C_1_FVGQ = readdlm( - joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C_1.csv"), - ',' -) -C_2_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C_2.csv"), ',') -const C_2_FVGQ = reshape(C_2_raw, length(C_0_FVGQ), length(A_0_FVGQ), length(A_0_FVGQ)) -# D_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "FVGQ_D.csv"); header = false))) -D_2_FVGQ = ones(6) * 1.0e-3 -u0_2_FVGQ = zeros(size(A_1_FVGQ, 1)) -const observables_2_FVGQ = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/FVGQ20_observables.csv" - ), ',' -)' |> - collect - -const noise_2_FVGQ = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/FVGQ20_noise.csv" - ), - ',' -)' |> collect -const T_2_FVGQ = size(observables_2_FVGQ, 2) - -# RBC sized specific tests -# Verifying code prior to benchmark -# executing gradients once to avoid compilation time in benchmarking -make_problem_2( - A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc, - noise_2_rbc, - observables_2_rbc, D_2_rbc -) -gradient( - (args...) -> make_problem_2(args..., observables_2_rbc, D_2_rbc), A_0_rbc, A_1_rbc, - A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc, noise_2_rbc -) -make_problem_2( - A_0_FVGQ, A_1_FVGQ, A_2_FVGQ, B_2_FVGQ, C_0_FVGQ, C_1_FVGQ, C_2_FVGQ, - u0_2_FVGQ, - noise_2_FVGQ, observables_2_FVGQ, D_2_FVGQ -) -gradient( - (args...) -> make_problem_2(args..., observables_2_FVGQ, D_2_FVGQ), A_0_FVGQ, - A_1_FVGQ, - A_2_FVGQ, B_2_FVGQ, C_0_FVGQ, C_1_FVGQ, C_2_FVGQ, u0_2_FVGQ, noise_2_FVGQ -) - -joint_likelihood_2( - A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc, - noise_2_rbc, observables_2_rbc, D_2_rbc -) -gradient( - (args...) -> joint_likelihood_2(args..., observables_2_rbc, D_2_rbc), A_0_rbc, - A_1_rbc, - A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc, noise_2_rbc -) -joint_likelihood_2( - A_0_FVGQ, A_1_FVGQ, A_2_FVGQ, B_2_FVGQ, C_0_FVGQ, C_1_FVGQ, C_2_FVGQ, - u0_2_FVGQ, - noise_2_FVGQ, observables_2_FVGQ, D_2_FVGQ -) -gradient( - (args...) -> joint_likelihood_2(args..., observables_2_FVGQ, D_2_FVGQ), A_0_FVGQ, - A_1_FVGQ, - A_2_FVGQ, B_2_FVGQ, C_0_FVGQ, C_1_FVGQ, C_2_FVGQ, u0_2_FVGQ, noise_2_FVGQ -) - -const QUADRATIC["rbc"] = BenchmarkGroup() -const QUADRATIC["rbc"]["make_problem_2"] = @benchmarkable make_problem_2( - $A_0_rbc, $A_1_rbc, - $A_2_rbc, $B_2_rbc, - $C_0_rbc, $C_1_rbc, - $C_2_rbc, - $u0_2_rbc, - $noise_2_rbc, - $observables_2_rbc, - $D_2_rbc -) -const QUADRATIC["rbc"]["make_problem_2_gradient"] = @benchmarkable gradient( - (args...) -> make_problem_2( - args..., - $observables_2_rbc, - $D_2_rbc - ), - $A_0_rbc, - $A_1_rbc, - $A_2_rbc, - $B_2_rbc, - $C_0_rbc, - $C_1_rbc, - $C_2_rbc, - $u0_2_rbc, - $noise_2_rbc -) - -const QUADRATIC["rbc"]["simulate_model_no_noise_2"] = @benchmarkable simulate_model_no_noise_2( - $A_0_rbc, - $A_1_rbc, - $A_2_rbc, - $B_2_rbc, - $C_0_rbc, - $C_1_rbc, - $C_2_rbc, - $u0_2_rbc, - $observables_2_rbc, - $D_2_rbc -) -const QUADRATIC["rbc"]["simulate_model_no_observations_2"] = @benchmarkable simulate_model_no_observations_2( - $A_0_rbc, - $A_1_rbc, - $A_2_rbc, - $B_2_rbc, - $C_0_rbc, - $C_1_rbc, - $C_2_rbc, - $u0_2_rbc, - $T_2_rbc -) -const QUADRATIC["rbc"]["joint_2"] = @benchmarkable joint_likelihood_2( - $A_0_rbc, $A_1_rbc, - $A_2_rbc, - $B_2_rbc, $C_0_rbc, - $C_1_rbc, - $C_2_rbc, $u0_2_rbc, - $noise_2_rbc, - $observables_2_rbc, - $D_2_rbc -) -const QUADRATIC["rbc"]["joint_2_gradient"] = @benchmarkable gradient( - (args...) -> joint_likelihood_2( - args..., - $observables_2_rbc, - $D_2_rbc - ), - $A_0_rbc, $A_1_rbc, - $A_2_rbc, - $B_2_rbc, $C_0_rbc, - $C_1_rbc, - $C_2_rbc, $u0_2_rbc, - $noise_2_rbc -) - -# FVGQ sized specific test -const QUADRATIC["FVGQ"] = BenchmarkGroup() -const QUADRATIC["FVGQ"]["make_problem_2"] = @benchmarkable make_problem_2( - $A_0_FVGQ, - $A_1_FVGQ, - $A_2_FVGQ, - $B_2_FVGQ, - $C_0_FVGQ, - $C_1_FVGQ, - $C_2_FVGQ, - $u0_2_FVGQ, - $noise_2_FVGQ, - $observables_2_FVGQ, - $D_2_FVGQ -) -const QUADRATIC["FVGQ"]["make_problem_2_gradient"] = @benchmarkable gradient( - (args...) -> make_problem_2( - args..., - $observables_2_FVGQ, - $D_2_FVGQ - ), - $A_0_FVGQ, - $A_1_FVGQ, - $A_2_FVGQ, - $B_2_FVGQ, - $C_0_FVGQ, - $C_1_FVGQ, - $C_2_FVGQ, - $u0_2_FVGQ, - $noise_2_FVGQ -) - -const QUADRATIC["FVGQ"]["simulate_model_no_noise_2"] = @benchmarkable simulate_model_no_noise_2( - $A_0_FVGQ, - $A_1_FVGQ, - $A_2_FVGQ, - $B_2_FVGQ, - $C_0_FVGQ, - $C_1_FVGQ, - $C_2_FVGQ, - $u0_2_FVGQ, - $observables_2_FVGQ, - $D_2_FVGQ -) -const QUADRATIC["FVGQ"]["simulate_model_no_observations_2"] = @benchmarkable simulate_model_no_observations_2( - $A_0_FVGQ, - $A_1_FVGQ, - $A_2_FVGQ, - $B_2_FVGQ, - $C_0_FVGQ, - $C_1_FVGQ, - $C_2_FVGQ, - $u0_2_FVGQ, - $T_2_FVGQ -) - -const QUADRATIC["FVGQ"]["joint_2"] = @benchmarkable joint_likelihood_2( - $A_0_FVGQ, $A_1_FVGQ, - $A_2_FVGQ, $B_2_FVGQ, - $C_0_FVGQ, $C_1_FVGQ, - $C_2_FVGQ, - $u0_2_FVGQ, - $noise_2_FVGQ, - $observables_2_FVGQ, - $D_2_FVGQ -) -const QUADRATIC["FVGQ"]["joint_2_gradient"] = @benchmarkable gradient( - (args...) -> joint_likelihood_2( - args..., - $observables_2_FVGQ, - $D_2_FVGQ - ), - $A_0_FVGQ, $A_1_FVGQ, - $A_2_FVGQ, $B_2_FVGQ, - $C_0_FVGQ, $C_1_FVGQ, - $C_2_FVGQ, $u0_2_FVGQ, - $noise_2_FVGQ -) -# return for the test suite -QUADRATIC diff --git a/benchmark/static_arrays.jl b/benchmark/static_arrays.jl new file mode 100644 index 0000000..f19dea6 --- /dev/null +++ b/benchmark/static_arrays.jl @@ -0,0 +1,764 @@ +# StaticArrays primal performance via workspace pattern (init/solve!) +# Vector{SVector} workspace works because the solver uses reassignment: +# u[t] = _transition!!(u[t], ...) — bang-bang returns new SVector, outer = replaces element +# +# Returns SA_BENCH BenchmarkGroup + +using StaticArrays +using Enzyme: make_zero, make_zero!, remake_zero! +using DifferenceEquations: init, solve!, mul!!, muladd!!, fill_zero!!, StateSpaceWorkspace + +const SA_BENCH = BenchmarkGroup() +SA_BENCH["linear"] = BenchmarkGroup() +SA_BENCH["generic"] = BenchmarkGroup() +SA_BENCH["kalman"] = BenchmarkGroup() + +function bench_solve!(ws) + solve!(ws) + return nothing +end + +# --- Linear DirectIteration N=2, T=20 --- + +const A_sa_2 = @SMatrix [0.9 0.1; 0.0 0.8] +const B_sa_2 = @SMatrix [0.0; 0.1;;] +const C_sa_2 = @SMatrix [1.0 0.0; 0.0 1.0] +const u0_sa_2 = @SVector [0.5, 0.3] +const noise_sa_2 = [SVector{1}(randn()) for _ in 1:20] + +const ws_ls2 = init( + LinearStateSpaceProblem( + A_sa_2, B_sa_2, u0_sa_2, (0, 20); + C = C_sa_2, noise = noise_sa_2 + ), DirectIteration() +) +SA_BENCH["linear"]["static_2x2"] = @benchmarkable bench_solve!($ws_ls2) + +const ws_lm2 = init( + LinearStateSpaceProblem( + Matrix(A_sa_2), Matrix(B_sa_2), Vector(u0_sa_2), (0, 20); + C = Matrix(C_sa_2), noise = [Vector(n) for n in noise_sa_2] + ), DirectIteration() +) +SA_BENCH["linear"]["mutable_2x2"] = @benchmarkable bench_solve!($ws_lm2) + +# --- Linear DirectIteration N=5, T=50 --- + +Random.seed!(123) +const A_sa_5_raw = randn(5, 5) +const A_sa_5 = SMatrix{5, 5}(0.5 * A_sa_5_raw / maximum(abs.(eigvals(A_sa_5_raw)))) +const B_sa_5 = SMatrix{5, 2}(0.1 * randn(5, 2)) +const C_sa_5 = SMatrix{3, 5}(randn(3, 5)) +const u0_sa_5 = SVector{5}(zeros(5)) +const noise_sa_5 = [SVector{2}(randn(2)) for _ in 1:50] + +const ws_ls5 = init( + LinearStateSpaceProblem( + A_sa_5, B_sa_5, u0_sa_5, (0, 50); + C = C_sa_5, noise = noise_sa_5 + ), DirectIteration() +) +SA_BENCH["linear"]["static_5x5"] = @benchmarkable bench_solve!($ws_ls5) + +const ws_lm5 = init( + LinearStateSpaceProblem( + Matrix(A_sa_5), Matrix(B_sa_5), Vector(u0_sa_5), (0, 50); + C = Matrix(C_sa_5), noise = [Vector(n) for n in noise_sa_5] + ), DirectIteration() +) +SA_BENCH["linear"]["mutable_5x5"] = @benchmarkable bench_solve!($ws_lm5) + +# --- Generic !! callbacks --- + +@inline function f_lss_sa!!(x_p, x, w, p, t) + x_p = mul!!(x_p, p.A, x) + return muladd!!(x_p, p.B, w) +end + +@inline function g_lss_sa!!(y, x, p, t) + return mul!!(y, p.C, x) +end + +# --- Generic N=2, T=20 --- + +const p_gen_s2 = (; A = A_sa_2, B = B_sa_2, C = C_sa_2) +const ws_gs2 = init( + StateSpaceProblem( + f_lss_sa!!, g_lss_sa!!, u0_sa_2, (0, 20), p_gen_s2; + n_shocks = 1, n_obs = 2, noise = noise_sa_2 + ), DirectIteration() +) +SA_BENCH["generic"]["static_2x2"] = @benchmarkable bench_solve!($ws_gs2) + +const p_gen_m2 = (; A = Matrix(A_sa_2), B = Matrix(B_sa_2), C = Matrix(C_sa_2)) +const ws_gm2 = init( + StateSpaceProblem( + f_lss_sa!!, g_lss_sa!!, Vector(u0_sa_2), (0, 20), p_gen_m2; + n_shocks = 1, n_obs = 2, noise = [Vector(n) for n in noise_sa_2] + ), DirectIteration() +) +SA_BENCH["generic"]["mutable_2x2"] = @benchmarkable bench_solve!($ws_gm2) + +# --- Generic N=5, T=50 --- + +const p_gen_s5 = (; A = A_sa_5, B = B_sa_5, C = C_sa_5) +const ws_gs5 = init( + StateSpaceProblem( + f_lss_sa!!, g_lss_sa!!, u0_sa_5, (0, 50), p_gen_s5; + n_shocks = 2, n_obs = 3, noise = noise_sa_5 + ), DirectIteration() +) +SA_BENCH["generic"]["static_5x5"] = @benchmarkable bench_solve!($ws_gs5) + +const p_gen_m5 = (; A = Matrix(A_sa_5), B = Matrix(B_sa_5), C = Matrix(C_sa_5)) +const ws_gm5 = init( + StateSpaceProblem( + f_lss_sa!!, g_lss_sa!!, Vector(u0_sa_5), (0, 50), p_gen_m5; + n_shocks = 2, n_obs = 3, noise = [Vector(n) for n in noise_sa_5] + ), DirectIteration() +) +SA_BENCH["generic"]["mutable_5x5"] = @benchmarkable bench_solve!($ws_gm5) + +# --- Kalman filter N=3, M=2, T=10 --- + +Random.seed!(789) +const A_kf_3_raw = randn(3, 3) +const A_kf_3 = SMatrix{3, 3}(0.5 * A_kf_3_raw / maximum(abs.(eigvals(A_kf_3_raw)))) +const B_kf_3 = SMatrix{3, 2}(0.1 * randn(3, 2)) +const C_kf_3 = SMatrix{2, 3}(randn(2, 3)) +const R_kf_3 = SMatrix{2, 2}(0.01 * I(2)) +const mu0_kf_3 = SVector{3}(zeros(3)) +const Sig0_kf_3 = SMatrix{3, 3}(1.0 * I(3)) + +# Generate observations for Kalman +const noise_kf_3 = [SVector{2}(randn(2)) for _ in 1:10] +const sim_kf_3 = solve( + LinearStateSpaceProblem( + A_kf_3, B_kf_3, mu0_kf_3, (0, 10); + C = C_kf_3, noise = noise_kf_3 + ) +) +const y_kf_3 = [sim_kf_3.z[t + 1] + SVector{2}(0.1 * randn(2)) for t in 1:10] + +const ws_ks3 = init( + LinearStateSpaceProblem( + A_kf_3, B_kf_3, mu0_kf_3, (0, 10); + C = C_kf_3, u0_prior_mean = mu0_kf_3, u0_prior_var = Sig0_kf_3, + observables_noise = R_kf_3, observables = y_kf_3 + ), KalmanFilter() +) +SA_BENCH["kalman"]["static_3x3"] = @benchmarkable bench_solve!($ws_ks3) + +const ws_km3 = init( + LinearStateSpaceProblem( + Matrix(A_kf_3), Matrix(B_kf_3), Vector(mu0_kf_3), (0, 10); + C = Matrix(C_kf_3), u0_prior_mean = Vector(mu0_kf_3), u0_prior_var = Matrix(Sig0_kf_3), + observables_noise = Matrix(R_kf_3), observables = [Vector(y) for y in y_kf_3] + ), KalmanFilter() +) +SA_BENCH["kalman"]["mutable_3x3"] = @benchmarkable bench_solve!($ws_km3) + +# --- Kalman filter N=5, M=3, T=20 --- + +Random.seed!(101) +const A_kf_5_raw = randn(5, 5) +const A_kf_5 = SMatrix{5, 5}(0.5 * A_kf_5_raw / maximum(abs.(eigvals(A_kf_5_raw)))) +const B_kf_5 = SMatrix{5, 2}(0.1 * randn(5, 2)) +const C_kf_5 = SMatrix{3, 5}(randn(3, 5)) +const R_kf_5 = SMatrix{3, 3}(0.01 * I(3)) +const mu0_kf_5 = SVector{5}(zeros(5)) +const Sig0_kf_5 = SMatrix{5, 5}(1.0 * I(5)) + +const noise_kf_5 = [SVector{2}(randn(2)) for _ in 1:20] +const sim_kf_5 = solve( + LinearStateSpaceProblem( + A_kf_5, B_kf_5, mu0_kf_5, (0, 20); + C = C_kf_5, noise = noise_kf_5 + ) +) +const y_kf_5 = [sim_kf_5.z[t + 1] + SVector{3}(0.1 * randn(3)) for t in 1:20] + +const ws_ks5 = init( + LinearStateSpaceProblem( + A_kf_5, B_kf_5, mu0_kf_5, (0, 20); + C = C_kf_5, u0_prior_mean = mu0_kf_5, u0_prior_var = Sig0_kf_5, + observables_noise = R_kf_5, observables = y_kf_5 + ), KalmanFilter() +) +SA_BENCH["kalman"]["static_5x5"] = @benchmarkable bench_solve!($ws_ks5) + +const ws_km5 = init( + LinearStateSpaceProblem( + Matrix(A_kf_5), Matrix(B_kf_5), Vector(mu0_kf_5), (0, 20); + C = Matrix(C_kf_5), u0_prior_mean = Vector(mu0_kf_5), u0_prior_var = Matrix(Sig0_kf_5), + observables_noise = Matrix(R_kf_5), observables = [Vector(y) for y in y_kf_5] + ), KalmanFilter() +) +SA_BENCH["kalman"]["mutable_5x5"] = @benchmarkable bench_solve!($ws_km5) + +# --- Quadratic PrunedQuadraticStateSpaceProblem (pruned, using new types) --- + +SA_BENCH["quadratic"] = BenchmarkGroup() + +# N=2, K=1, M=2, T=10 +Random.seed!(42) +const A_2_q = 0.01 * randn(2, 2, 2) +const C_2_q = 0.01 * randn(2, 2, 2) +const noise_q = [randn() for _ in 1:10] + +const As1 = @SMatrix [0.3 0.1; -0.1 0.3] +const As0 = @SVector [0.001, -0.001] +const Bs = @SMatrix [0.1; 0.0;;] +const Cs0 = @SVector [0.001, -0.001] +const Cs1 = @SMatrix [1.0 0.0; 0.0 1.0] +const u0s = @SVector zeros(2) +const noise_s = [SVector{1}(n) for n in noise_q] + +# Static 2x2 (pruned quadratic) +const prob_qs = PrunedQuadraticStateSpaceProblem( + As0, As1, A_2_q, Bs, u0s, (0, 10); + C_0 = Cs0, C_1 = Cs1, C_2 = C_2_q, noise = noise_s +) +const ws_qs = init(prob_qs, DirectIteration()) +SA_BENCH["quadratic"]["static_2x2"] = @benchmarkable bench_solve!($ws_qs) + +# Mutable 2x2 (pruned quadratic) +const prob_qm = PrunedQuadraticStateSpaceProblem( + Vector(As0), Matrix(As1), copy(A_2_q), Matrix(Bs), + Vector(u0s), (0, 10); + C_0 = Vector(Cs0), C_1 = Matrix(Cs1), C_2 = copy(C_2_q), + noise = [Vector(n) for n in noise_s] +) +const ws_qm = init(prob_qm, DirectIteration()) +SA_BENCH["quadratic"]["mutable_2x2"] = @benchmarkable bench_solve!($ws_qm) + +# ============================================================================= +# AD benchmarks for Kalman filter (static and mutable) +# ============================================================================= + +SA_BENCH["kalman"]["forward"] = BenchmarkGroup() +SA_BENCH["kalman"]["reverse"] = BenchmarkGroup() + +function kalman_fwd_sa!(A, B, C, mu_0, Sigma_0, R, y, sol_out, cache) + prob = LinearStateSpaceProblem( + A, B, zero(mu_0), (0, length(y)); C, + u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol_out, cache) + solve!(ws) + return (sol_out.u[end], sol_out.P[end]) +end + +function kalman_rev_sa!(A, B, C, mu_0, Sigma_0, R, y, sol_out, cache) + prob = LinearStateSpaceProblem( + A, B, zero(mu_0), (0, length(y)); C, + u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol_out, cache) + return solve!(ws).logpdf +end + +function forward_kalman_sa!( + A, B, C, mu_0, Sigma_0, R, y, sol_out, cache, + dA, dB, dC, dmu_0, dSigma_0, dR, dy, dsol_out, dcache + ) + remake_zero!(dsol_out); remake_zero!(dcache) + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC) + dmu_0 = fill_zero!!(dmu_0); dSigma_0 = fill_zero!!(dSigma_0); dR = fill_zero!!(dR) + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + if ismutable(dA) + dA[1, 1] = 1.0 + else + dA = setindex(dA, 1.0, 1, 1) + end + autodiff( + Forward, kalman_fwd_sa!, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(mu_0, dmu_0), Duplicated(Sigma_0, dSigma_0), + Duplicated(R, dR), Duplicated(y, dy), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +function reverse_kalman_sa!( + A, B, C, mu_0, Sigma_0, R, y, sol_out, cache, + dA, dB, dC, dmu_0, dSigma_0, dR, dy, dsol_out, dcache + ) + remake_zero!(dsol_out); remake_zero!(dcache) + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC) + dmu_0 = fill_zero!!(dmu_0); dSigma_0 = fill_zero!!(dSigma_0); dR = fill_zero!!(dR) + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + autodiff( + Reverse, kalman_rev_sa!, Active, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(mu_0, dmu_0), Duplicated(Sigma_0, dSigma_0), + Duplicated(R, dR), Duplicated(y, dy), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# --- Kalman 3x3 static AD shadows --- + +const dA_kf3s = make_zero(A_kf_3); const dB_kf3s = make_zero(B_kf_3) +const dC_kf3s = make_zero(C_kf_3); const dmu0_kf3s = make_zero(mu0_kf_3) +const dSig0_kf3s = make_zero(Sig0_kf_3); const dR_kf3s = make_zero(R_kf_3) +const dy_kf3s = [make_zero(y_kf_3[1]) for _ in 1:10] +const dsol_kf3s = make_zero(ws_ks3.output); const dcache_kf3s = make_zero(ws_ks3.cache) + +# --- Kalman 3x3 mutable AD shadows --- + +const A_kf3m = Matrix(A_kf_3); const B_kf3m = Matrix(B_kf_3) +const C_kf3m = Matrix(C_kf_3); const mu0_kf3m = Vector(mu0_kf_3) +const Sig0_kf3m = Matrix(Sig0_kf_3); const R_kf3m = Matrix(R_kf_3) +const y_kf3m = [Vector(y) for y in y_kf_3] +const dA_kf3m = make_zero(A_kf3m); const dB_kf3m = make_zero(B_kf3m) +const dC_kf3m = make_zero(C_kf3m); const dmu0_kf3m = make_zero(mu0_kf3m) +const dSig0_kf3m = make_zero(Sig0_kf3m); const dR_kf3m = make_zero(R_kf3m) +const dy_kf3m = [make_zero(y_kf3m[1]) for _ in 1:10] +const dsol_kf3m = make_zero(ws_km3.output); const dcache_kf3m = make_zero(ws_km3.cache) + +# --- Kalman 5x5 static AD shadows --- + +const dA_kf5s = make_zero(A_kf_5); const dB_kf5s = make_zero(B_kf_5) +const dC_kf5s = make_zero(C_kf_5); const dmu0_kf5s = make_zero(mu0_kf_5) +const dSig0_kf5s = make_zero(Sig0_kf_5); const dR_kf5s = make_zero(R_kf_5) +const dy_kf5s = [make_zero(y_kf_5[1]) for _ in 1:20] +const dsol_kf5s = make_zero(ws_ks5.output); const dcache_kf5s = make_zero(ws_ks5.cache) + +# --- Kalman 5x5 mutable AD shadows --- + +const A_kf5m = Matrix(A_kf_5); const B_kf5m = Matrix(B_kf_5) +const C_kf5m = Matrix(C_kf_5); const mu0_kf5m = Vector(mu0_kf_5) +const Sig0_kf5m = Matrix(Sig0_kf_5); const R_kf5m = Matrix(R_kf_5) +const y_kf5m = [Vector(y) for y in y_kf_5] +const dA_kf5m = make_zero(A_kf5m); const dB_kf5m = make_zero(B_kf5m) +const dC_kf5m = make_zero(C_kf5m); const dmu0_kf5m = make_zero(mu0_kf5m) +const dSig0_kf5m = make_zero(Sig0_kf5m); const dR_kf5m = make_zero(R_kf5m) +const dy_kf5m = [make_zero(y_kf5m[1]) for _ in 1:20] +const dsol_kf5m = make_zero(ws_km5.output); const dcache_kf5m = make_zero(ws_km5.cache) + +# --- Kalman AD warmups --- + +forward_kalman_sa!( + A_kf_3, B_kf_3, C_kf_3, mu0_kf_3, Sig0_kf_3, R_kf_3, y_kf_3, + ws_ks3.output, ws_ks3.cache, + dA_kf3s, dB_kf3s, dC_kf3s, dmu0_kf3s, dSig0_kf3s, dR_kf3s, dy_kf3s, + dsol_kf3s, dcache_kf3s +) + +forward_kalman_sa!( + A_kf3m, B_kf3m, C_kf3m, mu0_kf3m, Sig0_kf3m, R_kf3m, y_kf3m, + ws_km3.output, ws_km3.cache, + dA_kf3m, dB_kf3m, dC_kf3m, dmu0_kf3m, dSig0_kf3m, dR_kf3m, dy_kf3m, + dsol_kf3m, dcache_kf3m +) + +reverse_kalman_sa!( + A_kf_3, B_kf_3, C_kf_3, mu0_kf_3, Sig0_kf_3, R_kf_3, y_kf_3, + ws_ks3.output, ws_ks3.cache, + dA_kf3s, dB_kf3s, dC_kf3s, dmu0_kf3s, dSig0_kf3s, dR_kf3s, dy_kf3s, + dsol_kf3s, dcache_kf3s +) + +reverse_kalman_sa!( + A_kf3m, B_kf3m, C_kf3m, mu0_kf3m, Sig0_kf3m, R_kf3m, y_kf3m, + ws_km3.output, ws_km3.cache, + dA_kf3m, dB_kf3m, dC_kf3m, dmu0_kf3m, dSig0_kf3m, dR_kf3m, dy_kf3m, + dsol_kf3m, dcache_kf3m +) + +forward_kalman_sa!( + A_kf_5, B_kf_5, C_kf_5, mu0_kf_5, Sig0_kf_5, R_kf_5, y_kf_5, + ws_ks5.output, ws_ks5.cache, + dA_kf5s, dB_kf5s, dC_kf5s, dmu0_kf5s, dSig0_kf5s, dR_kf5s, dy_kf5s, + dsol_kf5s, dcache_kf5s +) + +forward_kalman_sa!( + A_kf5m, B_kf5m, C_kf5m, mu0_kf5m, Sig0_kf5m, R_kf5m, y_kf5m, + ws_km5.output, ws_km5.cache, + dA_kf5m, dB_kf5m, dC_kf5m, dmu0_kf5m, dSig0_kf5m, dR_kf5m, dy_kf5m, + dsol_kf5m, dcache_kf5m +) + +reverse_kalman_sa!( + A_kf_5, B_kf_5, C_kf_5, mu0_kf_5, Sig0_kf_5, R_kf_5, y_kf_5, + ws_ks5.output, ws_ks5.cache, + dA_kf5s, dB_kf5s, dC_kf5s, dmu0_kf5s, dSig0_kf5s, dR_kf5s, dy_kf5s, + dsol_kf5s, dcache_kf5s +) + +reverse_kalman_sa!( + A_kf5m, B_kf5m, C_kf5m, mu0_kf5m, Sig0_kf5m, R_kf5m, y_kf5m, + ws_km5.output, ws_km5.cache, + dA_kf5m, dB_kf5m, dC_kf5m, dmu0_kf5m, dSig0_kf5m, dR_kf5m, dy_kf5m, + dsol_kf5m, dcache_kf5m +) + +# --- Kalman AD benchmarkables --- + +SA_BENCH["kalman"]["forward"]["static_3x3"] = @benchmarkable forward_kalman_sa!( + $A_kf_3, $B_kf_3, $C_kf_3, $mu0_kf_3, $Sig0_kf_3, $R_kf_3, $y_kf_3, + $(ws_ks3.output), $(ws_ks3.cache), + $dA_kf3s, $dB_kf3s, $dC_kf3s, $dmu0_kf3s, $dSig0_kf3s, $dR_kf3s, $dy_kf3s, + $dsol_kf3s, $dcache_kf3s +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["kalman"]["forward"]["mutable_3x3"] = @benchmarkable forward_kalman_sa!( + $A_kf3m, $B_kf3m, $C_kf3m, $mu0_kf3m, $Sig0_kf3m, $R_kf3m, $y_kf3m, + $(ws_km3.output), $(ws_km3.cache), + $dA_kf3m, $dB_kf3m, $dC_kf3m, $dmu0_kf3m, $dSig0_kf3m, $dR_kf3m, $dy_kf3m, + $dsol_kf3m, $dcache_kf3m +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["kalman"]["reverse"]["static_3x3"] = @benchmarkable reverse_kalman_sa!( + $A_kf_3, $B_kf_3, $C_kf_3, $mu0_kf_3, $Sig0_kf_3, $R_kf_3, $y_kf_3, + $(ws_ks3.output), $(ws_ks3.cache), + $dA_kf3s, $dB_kf3s, $dC_kf3s, $dmu0_kf3s, $dSig0_kf3s, $dR_kf3s, $dy_kf3s, + $dsol_kf3s, $dcache_kf3s +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["kalman"]["reverse"]["mutable_3x3"] = @benchmarkable reverse_kalman_sa!( + $A_kf3m, $B_kf3m, $C_kf3m, $mu0_kf3m, $Sig0_kf3m, $R_kf3m, $y_kf3m, + $(ws_km3.output), $(ws_km3.cache), + $dA_kf3m, $dB_kf3m, $dC_kf3m, $dmu0_kf3m, $dSig0_kf3m, $dR_kf3m, $dy_kf3m, + $dsol_kf3m, $dcache_kf3m +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["kalman"]["forward"]["static_5x5"] = @benchmarkable forward_kalman_sa!( + $A_kf_5, $B_kf_5, $C_kf_5, $mu0_kf_5, $Sig0_kf_5, $R_kf_5, $y_kf_5, + $(ws_ks5.output), $(ws_ks5.cache), + $dA_kf5s, $dB_kf5s, $dC_kf5s, $dmu0_kf5s, $dSig0_kf5s, $dR_kf5s, $dy_kf5s, + $dsol_kf5s, $dcache_kf5s +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["kalman"]["forward"]["mutable_5x5"] = @benchmarkable forward_kalman_sa!( + $A_kf5m, $B_kf5m, $C_kf5m, $mu0_kf5m, $Sig0_kf5m, $R_kf5m, $y_kf5m, + $(ws_km5.output), $(ws_km5.cache), + $dA_kf5m, $dB_kf5m, $dC_kf5m, $dmu0_kf5m, $dSig0_kf5m, $dR_kf5m, $dy_kf5m, + $dsol_kf5m, $dcache_kf5m +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["kalman"]["reverse"]["static_5x5"] = @benchmarkable reverse_kalman_sa!( + $A_kf_5, $B_kf_5, $C_kf_5, $mu0_kf_5, $Sig0_kf_5, $R_kf_5, $y_kf_5, + $(ws_ks5.output), $(ws_ks5.cache), + $dA_kf5s, $dB_kf5s, $dC_kf5s, $dmu0_kf5s, $dSig0_kf5s, $dR_kf5s, $dy_kf5s, + $dsol_kf5s, $dcache_kf5s +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["kalman"]["reverse"]["mutable_5x5"] = @benchmarkable reverse_kalman_sa!( + $A_kf5m, $B_kf5m, $C_kf5m, $mu0_kf5m, $Sig0_kf5m, $R_kf5m, $y_kf5m, + $(ws_km5.output), $(ws_km5.cache), + $dA_kf5m, $dB_kf5m, $dC_kf5m, $dmu0_kf5m, $dSig0_kf5m, $dR_kf5m, $dy_kf5m, + $dsol_kf5m, $dcache_kf5m +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# ============================================================================= +# AD benchmarks for Linear 2x2 (static and mutable) +# ============================================================================= + +SA_BENCH["linear"]["forward"] = BenchmarkGroup() +SA_BENCH["linear"]["reverse"] = BenchmarkGroup() + +function sim_fwd_sa!(A, B, C, u0, noise, sol_out, cache) + prob = LinearStateSpaceProblem(A, B, u0, (0, length(noise)); C, noise) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + solve!(ws) + return (sol_out.u[end], sol_out.z[end]) +end + +function sim_rev_sa!(A, B, C, u0, noise, sol_out, cache)::Float64 + prob = LinearStateSpaceProblem(A, B, u0, (0, length(noise)); C, noise) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + return sum(solve!(ws).u[end]) +end + +function forward_sa!( + A, B, C, u0, noise, sol_out, cache, + dA, dB, dC, du0, dnoise, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC); du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + if ismutable(dA) + dA[1, 1] = 1.0 + else + dA = setindex(dA, 1.0, 1, 1) + end + autodiff( + Forward, sim_fwd_sa!, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +function reverse_sa!( + A, B, C, u0, noise, sol_out, cache, + dA, dB, dC, du0, dnoise, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC); du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + autodiff( + Reverse, sim_rev_sa!, Active, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# --- Static 2x2 AD shadows --- + +const dA_s2 = make_zero(A_sa_2) +const dB_s2 = make_zero(B_sa_2) +const dC_s2 = make_zero(C_sa_2) +const du0_s2 = make_zero(u0_sa_2) +const dnoise_s2 = [make_zero(noise_sa_2[1]) for _ in 1:20] +const dsol_s2 = make_zero(ws_ls2.output) +const dcache_s2 = make_zero(ws_ls2.cache) + +# --- Mutable 2x2 AD shadows --- + +const A_m2 = Matrix(A_sa_2) +const B_m2 = Matrix(B_sa_2) +const C_m2 = Matrix(C_sa_2) +const u0_m2 = Vector(u0_sa_2) +const noise_m2 = [Vector(n) for n in noise_sa_2] +const dA_m2 = make_zero(A_m2) +const dB_m2 = make_zero(B_m2) +const dC_m2 = make_zero(C_m2) +const du0_m2 = make_zero(u0_m2) +const dnoise_m2 = [make_zero(noise_m2[1]) for _ in 1:20] +const dsol_m2 = make_zero(ws_lm2.output) +const dcache_m2 = make_zero(ws_lm2.cache) + +# --- Warmups --- + +forward_sa!( + A_sa_2, B_sa_2, C_sa_2, u0_sa_2, noise_sa_2, + ws_ls2.output, ws_ls2.cache, + dA_s2, dB_s2, dC_s2, du0_s2, dnoise_s2, dsol_s2, dcache_s2 +) + +forward_sa!( + A_m2, B_m2, C_m2, u0_m2, noise_m2, + ws_lm2.output, ws_lm2.cache, + dA_m2, dB_m2, dC_m2, du0_m2, dnoise_m2, dsol_m2, dcache_m2 +) + +reverse_sa!( + A_sa_2, B_sa_2, C_sa_2, u0_sa_2, noise_sa_2, + ws_ls2.output, ws_ls2.cache, + dA_s2, dB_s2, dC_s2, du0_s2, dnoise_s2, dsol_s2, dcache_s2 +) + +reverse_sa!( + A_m2, B_m2, C_m2, u0_m2, noise_m2, + ws_lm2.output, ws_lm2.cache, + dA_m2, dB_m2, dC_m2, du0_m2, dnoise_m2, dsol_m2, dcache_m2 +) + +# --- Benchmarkables --- + +SA_BENCH["linear"]["forward"]["static_2x2"] = @benchmarkable forward_sa!( + $A_sa_2, $B_sa_2, $C_sa_2, $u0_sa_2, $noise_sa_2, + $(ws_ls2.output), $(ws_ls2.cache), + $dA_s2, $dB_s2, $dC_s2, $du0_s2, $dnoise_s2, $dsol_s2, $dcache_s2 +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["linear"]["forward"]["mutable_2x2"] = @benchmarkable forward_sa!( + $A_m2, $B_m2, $C_m2, $u0_m2, $noise_m2, + $(ws_lm2.output), $(ws_lm2.cache), + $dA_m2, $dB_m2, $dC_m2, $du0_m2, $dnoise_m2, $dsol_m2, $dcache_m2 +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["linear"]["reverse"]["static_2x2"] = @benchmarkable reverse_sa!( + $A_sa_2, $B_sa_2, $C_sa_2, $u0_sa_2, $noise_sa_2, + $(ws_ls2.output), $(ws_ls2.cache), + $dA_s2, $dB_s2, $dC_s2, $du0_s2, $dnoise_s2, $dsol_s2, $dcache_s2 +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["linear"]["reverse"]["mutable_2x2"] = @benchmarkable reverse_sa!( + $A_m2, $B_m2, $C_m2, $u0_m2, $noise_m2, + $(ws_lm2.output), $(ws_lm2.cache), + $dA_m2, $dB_m2, $dC_m2, $du0_m2, $dnoise_m2, $dsol_m2, $dcache_m2 +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +# ============================================================================= +# AD benchmarks for Quadratic 2x2 (static and mutable) — PrunedQuadraticStateSpaceProblem +# ============================================================================= + +SA_BENCH["quadratic"]["forward"] = BenchmarkGroup() +SA_BENCH["quadratic"]["reverse"] = BenchmarkGroup() + +# --- Inner wrappers: construct prob inside (correct Enzyme pattern) --- + +function quad_fwd_sa!(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache) + prob = PrunedQuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + solve!(ws) + return (sol_out.u[end], sol_out.z[end]) +end + +function quad_rev_sa!(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache)::Float64 + prob = PrunedQuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + return sum(solve!(ws).u[end]) +end + +# --- Outer bench functions: zero shadows, call autodiff --- + +function forward_quad_sa!( + A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache, + dA_0, dA_1, dA_2, dB, dC_0, dC_1, dC_2, du0, dnoise, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + dA_0 = fill_zero!!(dA_0); dA_1 = fill_zero!!(dA_1); make_zero!(dA_2) + dB = fill_zero!!(dB); dC_0 = fill_zero!!(dC_0); dC_1 = fill_zero!!(dC_1) + make_zero!(dC_2); du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + if ismutable(dA_1) + dA_1[1, 1] = 1.0 + else + dA_1 = setindex(dA_1, 1.0, 1, 1) + end + + autodiff( + Forward, quad_fwd_sa!, + Duplicated(A_0, dA_0), Duplicated(A_1, dA_1), Duplicated(A_2, dA_2), + Duplicated(B, dB), Duplicated(C_0, dC_0), Duplicated(C_1, dC_1), + Duplicated(C_2, dC_2), Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +function reverse_quad_sa!( + A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol_out, cache, + dA_0, dA_1, dA_2, dB, dC_0, dC_1, dC_2, du0, dnoise, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + dA_0 = fill_zero!!(dA_0); dA_1 = fill_zero!!(dA_1); make_zero!(dA_2) + dB = fill_zero!!(dB); dC_0 = fill_zero!!(dC_0); dC_1 = fill_zero!!(dC_1) + make_zero!(dC_2); du0 = fill_zero!!(du0) + @inbounds for i in eachindex(dnoise) + dnoise[i] = fill_zero!!(dnoise[i]) + end + + autodiff( + Reverse, quad_rev_sa!, Active, + Duplicated(A_0, dA_0), Duplicated(A_1, dA_1), Duplicated(A_2, dA_2), + Duplicated(B, dB), Duplicated(C_0, dC_0), Duplicated(C_1, dC_1), + Duplicated(C_2, dC_2), Duplicated(u0, du0), Duplicated(noise, dnoise), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return nothing +end + +# --- Static quadratic AD shadows --- + +const dAs0 = make_zero(As0); const dAs1 = make_zero(As1) +const dA_2_qs = make_zero(A_2_q); const dBs = make_zero(Bs) +const dCs0 = make_zero(Cs0); const dCs1 = make_zero(Cs1) +const dC_2_qs = make_zero(C_2_q); const du0s = make_zero(u0s) +const dnoise_qs = [make_zero(noise_s[1]) for _ in 1:10] +const dsol_qs = make_zero(ws_qs.output); const dcache_qs = make_zero(ws_qs.cache) + +# --- Mutable quadratic AD shadows --- + +const A_0_qm_ad = Vector(As0); const A_1_qm_ad = Matrix(As1) +const A_2_qm_ad = copy(A_2_q); const B_qm_ad = Matrix(Bs) +const C_0_qm_ad = Vector(Cs0); const C_1_qm_ad = Matrix(Cs1) +const C_2_qm_ad = copy(C_2_q); const u0_qm_ad = Vector(u0s) +const noise_qm_ad = [Vector(n) for n in noise_s] +const dA_0_qm = make_zero(A_0_qm_ad); const dA_1_qm = make_zero(A_1_qm_ad) +const dA_2_qm = make_zero(A_2_qm_ad); const dB_qm_ad = make_zero(B_qm_ad) +const dC_0_qm = make_zero(C_0_qm_ad); const dC_1_qm = make_zero(C_1_qm_ad) +const dC_2_qm = make_zero(C_2_qm_ad); const du0_qm_ad = make_zero(u0_qm_ad) +const dnoise_qm_ad = [make_zero(noise_qm_ad[1]) for _ in 1:10] +const dsol_qm = make_zero(ws_qm.output); const dcache_qm = make_zero(ws_qm.cache) + +# --- Quadratic warmups --- + +forward_quad_sa!( + As0, As1, A_2_q, Bs, Cs0, Cs1, C_2_q, + u0s, noise_s, ws_qs.output, ws_qs.cache, + dAs0, dAs1, dA_2_qs, dBs, dCs0, dCs1, dC_2_qs, + du0s, dnoise_qs, dsol_qs, dcache_qs +) + +forward_quad_sa!( + A_0_qm_ad, A_1_qm_ad, A_2_qm_ad, B_qm_ad, C_0_qm_ad, C_1_qm_ad, C_2_qm_ad, + u0_qm_ad, noise_qm_ad, ws_qm.output, ws_qm.cache, + dA_0_qm, dA_1_qm, dA_2_qm, dB_qm_ad, dC_0_qm, dC_1_qm, dC_2_qm, + du0_qm_ad, dnoise_qm_ad, dsol_qm, dcache_qm +) + +reverse_quad_sa!( + As0, As1, A_2_q, Bs, Cs0, Cs1, C_2_q, + u0s, noise_s, ws_qs.output, ws_qs.cache, + dAs0, dAs1, dA_2_qs, dBs, dCs0, dCs1, dC_2_qs, + du0s, dnoise_qs, dsol_qs, dcache_qs +) + +reverse_quad_sa!( + A_0_qm_ad, A_1_qm_ad, A_2_qm_ad, B_qm_ad, C_0_qm_ad, C_1_qm_ad, C_2_qm_ad, + u0_qm_ad, noise_qm_ad, ws_qm.output, ws_qm.cache, + dA_0_qm, dA_1_qm, dA_2_qm, dB_qm_ad, dC_0_qm, dC_1_qm, dC_2_qm, + du0_qm_ad, dnoise_qm_ad, dsol_qm, dcache_qm +) + +# --- Quadratic benchmarkables --- + +SA_BENCH["quadratic"]["forward"]["static_2x2"] = @benchmarkable forward_quad_sa!( + $As0, $As1, $A_2_q, $Bs, $Cs0, $Cs1, $C_2_q, + $u0s, $noise_s, $(ws_qs.output), $(ws_qs.cache), + $dAs0, $dAs1, $dA_2_qs, $dBs, $dCs0, $dCs1, $dC_2_qs, + $du0s, $dnoise_qs, $dsol_qs, $dcache_qs +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["quadratic"]["forward"]["mutable_2x2"] = @benchmarkable forward_quad_sa!( + $A_0_qm_ad, $A_1_qm_ad, $A_2_qm_ad, $B_qm_ad, $C_0_qm_ad, $C_1_qm_ad, $C_2_qm_ad, + $u0_qm_ad, $noise_qm_ad, $(ws_qm.output), $(ws_qm.cache), + $dA_0_qm, $dA_1_qm, $dA_2_qm, $dB_qm_ad, $dC_0_qm, $dC_1_qm, $dC_2_qm, + $du0_qm_ad, $dnoise_qm_ad, $dsol_qm, $dcache_qm +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["quadratic"]["reverse"]["static_2x2"] = @benchmarkable reverse_quad_sa!( + $As0, $As1, $A_2_q, $Bs, $Cs0, $Cs1, $C_2_q, + $u0s, $noise_s, $(ws_qs.output), $(ws_qs.cache), + $dAs0, $dAs1, $dA_2_qs, $dBs, $dCs0, $dCs1, $dC_2_qs, + $du0s, $dnoise_qs, $dsol_qs, $dcache_qs +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH["quadratic"]["reverse"]["mutable_2x2"] = @benchmarkable reverse_quad_sa!( + $A_0_qm_ad, $A_1_qm_ad, $A_2_qm_ad, $B_qm_ad, $C_0_qm_ad, $C_1_qm_ad, $C_2_qm_ad, + $u0_qm_ad, $noise_qm_ad, $(ws_qm.output), $(ws_qm.cache), + $dA_0_qm, $dA_1_qm, $dA_2_qm, $dB_qm_ad, $dC_0_qm, $dC_1_qm, $dC_2_qm, + $du0_qm_ad, $dnoise_qm_ad, $dsol_qm, $dcache_qm +) teardown = (GC.enable(true); GC.gc(); GC.enable(false)) + +SA_BENCH diff --git a/development.md b/development.md deleted file mode 100644 index 78bccc1..0000000 --- a/development.md +++ /dev/null @@ -1,164 +0,0 @@ -# Development and Benchmarking - -## Setup - -One time setup: - - 1. Setup your environment for [VS Code](https://julia.quantecon.org/software_engineering/tools_editors.html), [github](https://julia.quantecon.org/software_engineering/version_control.html) and [unit testing](https://julia.quantecon.org/software_engineering/testing.html). - 2. First start up a Julia repl in vscode this project - 3. Activate the global environment with `] activate` instead of the project environment - 4. Add in global packages for debugging and benchmarking - -``` -] add BenchmarkTools Infiltrator TestEnv PkgBenchmark -``` - - 5. Activate the benchmarking project - -``` -] activate benchmark -``` - - 6. Connect it the current version of the DifferenceEquations package, - -``` -] dev . -``` - - 7. Instantiate all benchmarking dependencies, - -``` -] instantiate -``` - -## Editing and Debugging Code - -If you open this folder in VS Code, the `Project.toml` at the root is activated rather than the one in the unit tests. - - - The `] test` should work without any chances, - - But to step through individual unit tests which may have test-only dependencies, you can use the `TestEnv` package. To do this, whenever starting the REPL do - -```julia -using TestEnv; -TestEnv.activate(); -``` - -At that point, you should be able to edit as if the `test/Project.toml` package was activated. For example, `include("test/runtests.jl")` should be roughly equivalent to `]test`. - -A useful trick for debugging is with `Infiltrator.jl`. Put in a `@exfiltrate` in the code, (e.g. inside of a DSSM function) and it pushes all local variables into a global associated with the module. - -# Benchmarking - -This assumes you are running as a package in VS Code. If not, then you will need to activate project files more carefulluy. - -Or start julia in the `DifferenceEquations/benchmark` folder with the `--project` CLI argument. - -### Running the Full Benchmarks - -Always start with the benchmarks activated, i.e. `] activate benchmark` -A few utilities - -```julia -using DifferenceEquations, PkgBenchmark -function save_benchmark(results_file = "baseline") - data = benchmarkpkg(DifferenceEquations; - resultfile = joinpath(pkgdir(DifferenceEquations), "benchmark/$results_file.json")) - export_markdown( - joinpath(pkgdir(DifferenceEquations), "benchmark/trial_$results_file.md"), data) -end -function generate_judgement(new_results, old_results = "baseline", judge_file = "judge") - return export_markdown( - joinpath(pkgdir(DifferenceEquations), "benchmark/$judge_file.md"), - judge( - PkgBenchmark.readresults(joinpath(pkgdir(DifferenceEquations), - "benchmark/$new_results.json")), - PkgBenchmark.readresults(joinpath(pkgdir(DifferenceEquations), - "benchmark/$old_results.json")))) -end -``` - -In your terminal - -```julia -save_benchmark("test") # default is "baseline" - -# Or manually: -# data = benchmarkpkg(DifferenceEquations; resultfile = joinpath(pkgdir(DifferenceEquations),"benchmark/baseline.json")) -# export_markdown(joinpath(pkgdir(DifferenceEquations),"benchmark/trial.md"), data) # can export as markdown -``` - -To compare against different parameters or after modifications, load the existing baseline and use the `judge` function to compare - -```julia -generate_judgement("test") # defaults to generate_judgement("test", "baseline", "judge") -# Or manually -# data = PkgBenchmark.readresults(joinpath(pkgdir(DifferenceEquations),"benchmark/baseline.json")) -# data_2 = benchmarkpkg(DifferenceEquations, BenchmarkConfig( -# env = Dict("JULIA_NUM_THREADS" => 4, "OPENBLAS_NUM_THREADS" => 1), -# juliacmd = `julia -O3`)) -# export_markdown(joinpath(pkgdir(DifferenceEquations),"benchmark/judge.md"), judge(data_2, data)) -``` - -### Running Portions of the Benchmarks During Development - -Rather than the whole PkgBenchmark, you can run the individual benchmarks by either first loading them all up - -```julia -using DifferenceEquations -include(joinpath(pkgdir(DifferenceEquations), "benchmark/benchmarks.jl")) -``` - -And then running individual ones - -To use: - - - To run part of the benchmarks, you can refer to the global `SUITE`. For example, - -```julia -run(SUITE["linear"]["rbc"]["joint_1"], verbose = true) -``` - - - Or to get specific statistics such as the median (and using postfix) - -```julia -SUITE["linear"]["rbc"]["joint_1"] |> run |> median -``` - -To compare between changes, save the results and judge the difference (e.g. median). - -For example, with a subset of the suite. Run it and then save the results - -```julia -output_path_old = joinpath(pkgdir(DifferenceEquations), "benchmark/rbc_first_order.json") -BenchmarkTools.save(output_path_old, run(SUITE["linear"]["rbc"]["joint_1"], verbose = true)) -``` - -Now you can reload that stored benchmarking later and compare, - -```julia -# Make code change and rerun... -results_new = run(SUITE["linear"]["rbc"]["joint_1"], verbose = true) - -#Load to compare to the old one -output_path_old = joinpath(pkgdir(DifferenceEquations), "benchmark/rbc_first_order.json") -results_old = BenchmarkTools.load(output_path_old)[1] - -judge_results = judge(median(results_new), median(results_old)) # compare the median/etc. -``` - -# Generating Documentation - -Activate the docs directory and then ensure it is using your local version - -``` -] activate docs -dev .. -``` - -After that step, only `] activate docs` is required. To generate documentation locally - -```julia -include("docs/make.jl") -``` - -To visualize the generated documents during development on vscode, consider running the `> Live Preview: Start Server` and navigating to the `docs/build` directory. diff --git a/docs/Project.toml b/docs/Project.toml index 22f82f6..cc90d08 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,26 +1,26 @@ [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferenceEquations = "e0ca9c66-1f9e-11ec-127a-1304ce62169c" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] -ChainRulesCore = "1" DataFrames = "1" DiffEqBase = "6.145, 7" DifferenceEquations = "1.1" Distributions = "0.25" Documenter = "1" -Optimization = "3, 5" -OptimizationOptimJL = "0.1, 0.4" +DocumenterInterLinks = "1" +Enzyme = "0.13" Plots = "1" -Zygote = "0.6, 0.7" +StaticArrays = "1" diff --git a/docs/make.jl b/docs/make.jl index 9001905..4f60992 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,16 +1,25 @@ using Documenter, DifferenceEquations +using DocumenterInterLinks include("pages.jl") +links = InterLinks( + "SciMLBase" => "https://docs.sciml.ai/SciMLBase/stable/", +) + makedocs( sitename = "DifferenceEquations.jl", authors = "Various Authors", - clean = true, doctest = false, linkcheck = false, - warnonly = [:example_block], + clean = true, + doctest = false, + linkcheck = true, + checkdocs = :exports, + warnonly = [:missing_docs, :linkcheck], modules = [DifferenceEquations], + plugins = [links], format = Documenter.HTML( assets = ["assets/favicon.ico"], - canonical = "https://DifferenceEquations.sciml.ai/stable/" + canonical = "https://docs.sciml.ai/DifferenceEquations/stable/" ), pages = pages ) diff --git a/docs/pages.jl b/docs/pages.jl index 30b7bb6..347ba41 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,8 +1,24 @@ pages = [ - "DifferenceEquations.jl: Discrete-Time State Space Solution Methods" => "index.md", - "Examples" => [ - "examples/linear_state_space_examples.md", - "examples/quadratic_state_space_examples.md", - "examples/general_state_space_examples.md", + "Home" => "index.md", + "Getting Started" => "getting_started.md", + "Tutorials" => [ + "Linear Simulation" => "tutorials/linear_simulation.md", + "Likelihood & Kalman Filter" => "tutorials/linear_likelihood.md", + "Conditional Likelihood" => "tutorials/conditional_likelihood.md", + "Quadratic Models" => "tutorials/quadratic.md", + "Generic Callbacks" => "tutorials/generic_callbacks.md", + ], + "Basics" => [ + "Problem Types" => "basics/problem_types.md", + "Solvers" => "basics/solvers.md", + "Solutions" => "basics/solutions.md", + "Workspace API" => "basics/workspace.md", + "FAQ" => "basics/faq.md", + ], + "Advanced" => [ + "Enzyme AD" => "advanced/enzyme_ad.md", + "ForwardDiff AD" => "advanced/forwarddiff_ad.md", + "StaticArrays" => "advanced/static_arrays.md", + "Internals" => "advanced/internals.md", ], ] diff --git a/docs/src/advanced/enzyme_ad.md b/docs/src/advanced/enzyme_ad.md new file mode 100644 index 0000000..09b49cc --- /dev/null +++ b/docs/src/advanced/enzyme_ad.md @@ -0,0 +1,168 @@ +# Enzyme AD + +DifferenceEquations.jl is fully differentiable with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) in both reverse and forward mode. All examples below use the workspace-based `init`/`solve!` pattern with [`StateSpaceWorkspace`](@ref), which gives Enzyme the pre-allocated buffers it needs. + +## The Core Pattern + +Every Enzyme example in this package follows the same recipe: + +1. **Flat-argument wrapper function.** Construct the `LinearStateSpaceProblem` *inside* the function from plain matrix/vector arguments. This keeps the Enzyme call site simple and avoids closing over mutable state. + +2. **Pre-allocate with `init`.** Call `init(prob, alg)` once to obtain a workspace whose `.output` (solution) and `.cache` fields are correctly sized buffers. Then pass those buffers into the wrapper via `StateSpaceWorkspace(prob, alg, sol, cache)` followed by `solve!(ws).logpdf`. + +3. **All arguments `Duplicated`.** Because every argument flows into the *same* `LinearStateSpaceProblem` struct, Enzyme treats the whole struct as active. If even one field is `Const` while others are `Duplicated`, Enzyme may silently produce wrong gradients. The safe rule: **mark every argument `Duplicated`**. + +4. **Zero-initialized shadows for `sol`/`cache`.** Shadow copies for the solution and cache buffers must be created with `Enzyme.make_zero(deepcopy(...))`. A plain `deepcopy` copies the primal values into the shadow, which can produce `NaN` gradients. `make_zero` recursively zeroes every numeric field while preserving the nested structure. + +## Differentiating Joint Likelihood + +The joint likelihood conditions on a fixed noise sequence and accumulates the observation log-likelihood along the trajectory via [`DirectIteration`](@ref). + +```@example enzyme +using DifferenceEquations, LinearAlgebra, Enzyme, Random + +N, K, M = 2, 1, 2 +A = [0.8 0.1; -0.1 0.7] +B = [0.1; 0.0;;] +C = [1.0 0.0; 0.0 1.0] +D = Diagonal([0.01, 0.01]) # diagonal covariance; use Symmetric(H * H') for non-diagonal +u0 = zeros(N) + +Random.seed!(42) +noise = [randn(K) for _ in 1:5] +sim = solve(LinearStateSpaceProblem(A, B, u0, (0, 5); C, noise)) +obs = [sim.z[t + 1] + 0.1 * randn(M) for t in 1:5] + +# Likelihood function: all matrix args as separate parameters +function di_loglik(A, B, C, u0, noise, obs, R, sol, cache)::Float64 + prob = LinearStateSpaceProblem(A, B, u0, (0, length(obs)); + C, observables_noise = R, observables = obs, noise) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + return solve!(ws).logpdf +end + +# Pre-allocate buffers +prob0 = LinearStateSpaceProblem(A, B, u0, (0, length(obs)); + C, observables_noise = D, observables = obs, noise) +ws0 = init(prob0, DirectIteration()) + +# Compute gradient wrt A +dA = zero(A) +autodiff(Reverse, di_loglik, + Duplicated(copy(A), dA), + Duplicated(copy(B), zero(B)), + Duplicated(copy(C), zero(C)), + Duplicated(copy(u0), zero(u0)), + Duplicated(deepcopy(noise), [zeros(K) for _ in noise]), + Duplicated(deepcopy(obs), [zeros(M) for _ in obs]), + Duplicated(copy(D), zero(D)), + Duplicated(deepcopy(ws0.output), Enzyme.make_zero(deepcopy(ws0.output))), + Duplicated(deepcopy(ws0.cache), Enzyme.make_zero(deepcopy(ws0.cache)))) +dA # gradient of logpdf with respect to A +``` + +## Differentiating the Kalman Filter + +The [`KalmanFilter`](@ref) computes the marginal log-likelihood by integrating out the latent noise analytically. The same all-`Duplicated` pattern applies. + +```@example enzyme +# Kalman filter likelihood +function kf_loglik(A, B, C, mu0, Sigma0, R, obs, sol, cache)::Float64 + prob = LinearStateSpaceProblem(A, B, zeros(eltype(A), size(A,1)), (0, length(obs)); + C, u0_prior_mean = mu0, u0_prior_var = Sigma0, + observables_noise = R, observables = obs) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol, cache) + return solve!(ws).logpdf +end + +mu0 = zeros(N) +Sigma0 = Matrix(1.0 * I(N)) +prob_kf = LinearStateSpaceProblem(A, B, zeros(N), (0, length(obs)); + C, u0_prior_mean = mu0, u0_prior_var = Sigma0, + observables_noise = D, observables = obs) +ws_kf = init(prob_kf, KalmanFilter()) + +dA_kf = zero(A) +autodiff(Reverse, kf_loglik, + Duplicated(copy(A), dA_kf), + Duplicated(copy(B), zero(B)), + Duplicated(copy(C), zero(C)), + Duplicated(copy(mu0), zero(mu0)), + Duplicated(copy(Sigma0), zero(Sigma0)), + Duplicated(copy(D), zero(D)), + Duplicated(deepcopy(obs), [zeros(M) for _ in obs]), + Duplicated(deepcopy(ws_kf.output), Enzyme.make_zero(deepcopy(ws_kf.output))), + Duplicated(deepcopy(ws_kf.cache), Enzyme.make_zero(deepcopy(ws_kf.cache)))) +dA_kf # gradient of Kalman logpdf with respect to A +``` + +## Integration with Optimization.jl + +The differentiable Kalman likelihood composes naturally with [Optimization.jl](https://github.com/SciML/Optimization.jl) for maximum-likelihood estimation. Because the all-`Duplicated` requirement cannot be expressed through `AutoEnzyme()`, we supply an explicit `grad` function that calls `Enzyme.autodiff` directly. + +```@example enzyme +using Optimization, OptimizationOptimJL + +# Simulate data from a known model +Random.seed!(42) +T_opt = 200 +B_opt = [0.0; 0.001;;] +C_opt = [0.09 0.67; 1.00 0.00] +D_opt = Diagonal([0.01, 0.01]) +prob_data = LinearStateSpaceProblem([0.95 6.2; 0.0 0.2], B_opt, zeros(2), (0, T_opt); + C = C_opt, observables_noise = D_opt) +sol_data = solve(prob_data) +obs_data = sol_data.z[2:end] + +# Pre-allocate Kalman workspace +mu0_opt = zeros(2) +Sigma0_opt = Matrix(1e-2 * I(2)) +prob_base = LinearStateSpaceProblem([0.95 6.2; 0.0 0.2], B_opt, zeros(2), + (0, length(obs_data)); C = C_opt, observables = obs_data, + observables_noise = D_opt, u0_prior_mean = mu0_opt, u0_prior_var = Sigma0_opt) +ws_opt = init(prob_base, KalmanFilter()) + +# Objective and gradient using the flat-argument pattern +function neg_loglik(beta, p) + A = [beta[1] 6.2; 0.0 0.2] + return -kf_loglik(A, p.B, p.C, p.mu0, p.Sigma0, p.D, p.obs, + deepcopy(p.sol), deepcopy(p.cache)) +end + +function neg_loglik_grad!(g, beta, p) + A = [beta[1] 6.2; 0.0 0.2] + dA = zero(A) + autodiff(Reverse, kf_loglik, + Duplicated(A, dA), + Duplicated(copy(p.B), zero(p.B)), + Duplicated(copy(p.C), zero(p.C)), + Duplicated(copy(p.mu0), zero(p.mu0)), + Duplicated(copy(p.Sigma0), zero(p.Sigma0)), + Duplicated(copy(p.D), zero(p.D)), + Duplicated(deepcopy(p.obs), [zeros(2) for _ in p.obs]), + Duplicated(deepcopy(p.sol), Enzyme.make_zero(deepcopy(p.sol))), + Duplicated(deepcopy(p.cache), Enzyme.make_zero(deepcopy(p.cache)))) + g[1] = -dA[1, 1] +end + +params = (; B = B_opt, C = C_opt, D = D_opt, obs = obs_data, + mu0 = mu0_opt, Sigma0 = Sigma0_opt, sol = ws_opt.output, cache = ws_opt.cache) + +optf = OptimizationFunction(neg_loglik; grad = neg_loglik_grad!) +optprob = OptimizationProblem(optf, [0.90], params) +optsol = solve(optprob, LBFGS()) +optsol.u # estimated beta (true value: 0.95) +``` + +## Quadratic and Generic Models + +The same all-`Duplicated` pattern works for [`QuadraticStateSpaceProblem`](@ref), [`PrunedQuadraticStateSpaceProblem`](@ref), and [`StateSpaceProblem`](@ref). Replace the constructor and add the extra arguments (`A_0`, `A_1`, `A_2`, `C_0`, `C_1`, `C_2` for quadratic; callback functions for generic) as separate `Duplicated` parameters. See the [Quadratic Models](@ref) tutorial for an Enzyme example with quadratic problems. + +## Important Notes + +- All arguments to the likelihood function that flow into the problem struct must be `Duplicated`, not `Const`. This is because Enzyme tracks activity at the struct level. +- Shadow copies for `sol` and `cache` buffers must be zero-initialized using `Enzyme.make_zero(deepcopy(...))`. Using plain `deepcopy` produces `NaN` gradients. +- The `Optimization.jl` integration requires an explicit `grad` function because `AutoEnzyme()` cannot directly handle the all-Duplicated requirement. The gradient function calls `Enzyme.autodiff` manually. +- Avoid calling `GC.gc()` inside functions differentiated by Enzyme -- this can cause segfaults when combined with `BenchmarkTools`. +- See the [Workspace API](@ref) page for details on `init`, `solve!`, and `StateSpaceWorkspace`. +- For small models (N ≤ 5), [ForwardDiff AD](@ref) offers a simpler alternative with comparable performance and no `Duplicated` bookkeeping. diff --git a/docs/src/advanced/forwarddiff_ad.md b/docs/src/advanced/forwarddiff_ad.md new file mode 100644 index 0000000..0880289 --- /dev/null +++ b/docs/src/advanced/forwarddiff_ad.md @@ -0,0 +1,205 @@ +# ForwardDiff AD + +DifferenceEquations.jl works with [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) out of the box for computing gradients and Jacobians. ForwardDiff requires no shadow arrays, no activity annotations, and no workspace pre-allocation -- just wrap your function in `ForwardDiff.gradient`. + +!!! tip "When to use ForwardDiff vs Enzyme" + + | Scenario | Recommendation | + |----------|---------------| + | Small models (N ≤ 5 states), few parameters | **ForwardDiff** -- same speed as Enzyme reverse, zero setup cost | + | Large models (N ≥ 10 states) | **Enzyme reverse** -- scales as O(1) backward passes vs O(N²) forward passes | + | Many parameters (e.g., noise perturbation over T periods) | **Enzyme reverse** -- ForwardDiff cost scales with parameter count | + | Quick prototyping | **ForwardDiff** -- simpler API, no `Duplicated` bookkeeping | + | Production estimation loops | **Enzyme reverse** -- lower memory, pre-allocated workspace | + + For `DirectIteration` problems where you differentiate with respect to the noise sequence, the effective parameter dimension is `K × T` (shocks × periods), not just `N²`. Even for small state dimensions, long horizons make Enzyme reverse the better choice. + +## The Core Pattern + +ForwardDiff propagates dual numbers through the computation. The key requirement is that all arrays must have a consistent element type (either `Float64` or `Dual{...}`). The pattern is: + +1. **Write a scalar function of a vector.** ForwardDiff.gradient takes `f: ℝⁿ → ℝ`. +2. **Promote all arrays inside the function.** When `ForwardDiff.gradient` calls your function with a `Vector{Dual{...}}`, convert all other matrices to the same `Dual` element type so that caches are allocated correctly. +3. **Use the public `solve()` API.** Unlike Enzyme, ForwardDiff creates fresh caches each call (with the correct `Dual` element type), so the simple `solve(prob, alg)` path works directly. + +```julia +_promote(::Type{T}, x::AbstractArray{T}) where {T} = x +_promote(::Type{T}, x::AbstractArray) where {T} = T.(x) +``` + +## Differentiating Joint Likelihood + +```@example forwarddiff +using DifferenceEquations, LinearAlgebra, ForwardDiff, Random + +N, K, M = 2, 1, 2 +A = [0.8 0.1; -0.1 0.7] +B = [0.1; 0.0;;] +C = [1.0 0.0; 0.0 1.0] +H = [0.1 0.0; 0.0 0.1] +u0 = zeros(N) + +Random.seed!(42) +noise = [randn(K) for _ in 1:5] +sim = solve(LinearStateSpaceProblem(A, B, u0, (0, 5); C, noise)) +obs = [sim.z[t + 1] + 0.1 * randn(M) for t in 1:5] + +# Type-promotion helper +_promote(::Type{T}, x::AbstractArray{T}) where {T} = x +_promote(::Type{T}, x::AbstractArray) where {T} = T.(x) + +# Gradient of joint loglik w.r.t. vec(A) +function di_loglik(A_vec, B, C, u0, noise, obs, H) + T_el = eltype(A_vec) + A = reshape(A_vec, 2, 2) + H_d = _promote(T_el, H) + R = H_d * H_d' + prob = LinearStateSpaceProblem( + A, _promote(T_el, B), _promote(T_el, u0), (0, length(obs)); + C = _promote(T_el, C), observables_noise = R, + observables = obs, noise = noise) + sol = solve(prob, DirectIteration()) + return sol.logpdf +end + +grad_A = ForwardDiff.gradient( + a -> di_loglik(a, B, C, u0, noise, obs, H), vec(copy(A))) +``` + +## Differentiating the Kalman Filter + +The [`KalmanFilter`](@ref) marginal log-likelihood works the same way. + +```@example forwarddiff +# General Kalman loglik that promotes all inputs consistently +function kf_loglik(A, B, C, mu0, Sigma0, R, obs) + T_el = promote_type(eltype(A), eltype(B), eltype(C), + eltype(mu0), eltype(Sigma0), eltype(R)) + N_st = size(A, 1) + prob = LinearStateSpaceProblem( + _promote(T_el, A), _promote(T_el, B), + zeros(T_el, N_st), (0, length(obs)); + C = _promote(T_el, C), + u0_prior_mean = _promote(T_el, mu0), + u0_prior_var = _promote(T_el, Sigma0), + observables_noise = _promote(T_el, R), + observables = obs) + sol = solve(prob, KalmanFilter()) + return sol.logpdf +end + +mu0 = zeros(N) +Sigma0 = Matrix(1.0 * I(N)) +R = [0.01 0.0; 0.0 0.01] + +grad_kf = ForwardDiff.gradient( + a -> kf_loglik(reshape(a, N, N), B, C, mu0, Sigma0, R, obs), vec(copy(A))) +``` + +## Differentiating with Respect to Multiple Parameters + +Because `kf_loglik` promotes all inputs via `promote_type`, you can differentiate with respect to any parameter. + +```@example forwarddiff +# Gradient w.r.t. observation matrix C +grad_C = ForwardDiff.gradient( + c_vec -> kf_loglik(A, B, reshape(c_vec, M, N), mu0, Sigma0, R, obs), + vec(copy(C))) +``` + +```@example forwarddiff +# Gradient w.r.t. prior mean +grad_mu0 = ForwardDiff.gradient( + m -> kf_loglik(A, B, C, m, Sigma0, R, obs), copy(mu0)) +``` + +## Integration with Optimization.jl + +ForwardDiff integrates with [Optimization.jl](https://github.com/SciML/Optimization.jl) via `AutoForwardDiff()`, which is simpler than the Enzyme path (no manual `Duplicated` bookkeeping). + +```@example forwarddiff +using Optimization, OptimizationOptimJL + +# Simulate data from a known model +Random.seed!(42) +T_opt = 200 +B_opt = [0.0; 0.001;;] +C_opt = [0.09 0.67; 1.00 0.00] +R_opt = [0.01 0.0; 0.0 0.01] +prob_data = LinearStateSpaceProblem([0.95 6.2; 0.0 0.2], B_opt, zeros(2), (0, T_opt); + C = C_opt, observables_noise = R_opt) +sol_data = solve(prob_data) +obs_data = sol_data.z[2:end] + +# Objective: negative Kalman loglik as function of β = [A[1,1]] +mu0_opt = zeros(2) +Sigma0_opt = Matrix(1e-2 * I(2)) + +function neg_loglik(beta, p) + A = [beta[1] 6.2; 0.0 0.2] + return -kf_loglik(A, p.B, p.C, p.mu0, p.Sigma0, p.R, p.obs) +end + +params = (; B = B_opt, C = C_opt, R = R_opt, obs = obs_data, + mu0 = mu0_opt, Sigma0 = Sigma0_opt) + +optf = OptimizationFunction(neg_loglik, AutoForwardDiff()) +optprob = OptimizationProblem(optf, [0.90], params) +optsol = solve(optprob, LBFGS()) +optsol.u # estimated β (true value: 0.95) +``` + +## StaticArrays + +ForwardDiff also works with `SVector`/`SMatrix` inputs. Construct static arrays from the dual-typed input vector inside the function. + +```@example forwarddiff +using StaticArrays + +function kf_loglik_static(A_vec, B, C, mu0, Sigma0, R, obs, + ::Val{N_}, ::Val{M_}, ::Val{K_}) where {N_, M_, K_} + T_el = eltype(A_vec) + A = SMatrix{N_, N_}(reshape(A_vec, N_, N_)) + prob = LinearStateSpaceProblem( + A, SMatrix{N_, K_}(T_el.(B)), + SVector{N_}(zeros(T_el, N_)), (0, length(obs)); + C = SMatrix{M_, N_}(T_el.(C)), + u0_prior_mean = SVector{N_}(T_el.(mu0)), + u0_prior_var = SMatrix{N_, N_}(T_el.(Sigma0)), + observables_noise = SMatrix{M_, M_}(T_el.(R)), + observables = obs) + sol = solve(prob, KalmanFilter()) + return sol.logpdf +end + +obs_s = [SVector{M}(o) for o in obs] +grad_static = ForwardDiff.gradient( + a -> kf_loglik_static(a, SMatrix{N,K}(B), SMatrix{M,N}(C), + SVector{N}(mu0), SMatrix{N,N}(Sigma0), SMatrix{M,M}(R), + obs_s, Val(N), Val(M), Val(K)), + collect(vec(Matrix(A)))) +``` + +!!! tip "Use `save_everystep=false` with StaticArrays" + + Combining ForwardDiff + StaticArrays with `save_everystep=false` gives the best + performance for small models. Without it, the allocation of T dual-number + SVector buffers dominates. With `save_everystep=false`, only 2 scratch slots are + used, yielding up to **7x speedup** for the Kalman filter and **3.4x** for + ConditionalLikelihood at N=5. + + ```julia + sol = solve(prob, KalmanFilter(); save_everystep=false) + ``` + +## Quadratic and Generic Models + +ForwardDiff works with all problem types: [`QuadraticStateSpaceProblem`](@ref), [`PrunedQuadraticStateSpaceProblem`](@ref), and [`StateSpaceProblem`](@ref). The same pattern applies — promote all arrays to the `Dual` element type inside the gradient function and call `solve(prob, DirectIteration())`. + +## Important Notes + +- **Type promotion is required.** All arrays flowing into the problem must have the same element type. Use `promote_type` across all inputs (as in `kf_loglik` above) or the `_promote` helper to convert `Float64` arrays to the `Dual` type. +- **Fresh allocation each call.** ForwardDiff creates new caches with `Dual` element types via `solve()`. This is unavoidable (unlike Enzyme, which reuses `Float64` caches with separate shadow arrays). Use `save_everystep=false` to minimize these allocations from O(T) to O(1) when you only need `logpdf`. +- **Chunk size.** `ForwardDiff.gradient` defaults to a chunk size of ~10, processing 10 partial derivatives per forward pass. For parameter count > 10, it runs multiple passes. This makes ForwardDiff cost scale linearly with the number of parameters being differentiated. +- **Observations stay `Float64`.** The `observables` (data) are not differentiated and can remain `Vector{Vector{Float64}}`. The solver's internal buffers are allocated with the `Dual` element type, so when `Float64` observations are copied in, the dual partials are zero — which is correct since observations are data, not parameters being differentiated. +- **DirectIteration noise sensitivity.** When differentiating `DirectIteration` w.r.t. the noise sequence, the parameter dimension is `K × T` (shocks × periods). Even for small state-space models, long time series make ForwardDiff expensive and Enzyme reverse the better choice. diff --git a/docs/src/advanced/internals.md b/docs/src/advanced/internals.md new file mode 100644 index 0000000..bae3bca --- /dev/null +++ b/docs/src/advanced/internals.md @@ -0,0 +1,57 @@ +# Internals + +This page documents the internal architecture of DifferenceEquations.jl. It is intended for developers who want to understand the package internals or extend the package with new problem types or algorithms. + +## Architecture + +The solving pipeline follows these stages: + +1. **Problem construction**: The user creates a problem (e.g., `LinearStateSpaceProblem`) that encodes the model dynamics, parameters, and data. +2. **Algorithm dispatch**: `solve(prob)` or `solve(prob, alg)` selects the algorithm. If no algorithm is provided, the default is chosen based on the problem type and its fields. +3. **Workspace allocation**: `init(prob, alg)` allocates the solution output via `alloc_sol` and scratch workspace via `alloc_cache`, then wraps them in a `StateSpaceWorkspace`. +4. **Solve**: `solve!(ws)` runs the algorithm, which fully overwrites all solution and cache arrays during the time loop. +5. **Solution**: A `StateSpaceSolution` is returned containing the state trajectory, observations, noise, log-likelihood, and other results. + +## Bang-Bang Operators + +DifferenceEquations.jl uses a "bang-bang" (`!!`) convention for internal operators. These functions behave differently depending on whether their arguments are mutable or immutable: + +- **Mutable arrays** (`Vector`, `Matrix`): The operator mutates the destination in place and returns it. +- **Immutable arrays** (`SVector`, `SMatrix`): The operator creates and returns a new value, since mutation is not possible. + +This dual behavior allows the same algorithm code to work with both standard arrays and StaticArrays without any branching or specialization at the call site. + +The main bang-bang operators are: + +| Operator | Description | +|----------|-------------| +| `mul!!(C, A, B)` | Matrix multiply `A * B`, storing in `C` | +| `copyto!!(dest, src)` | Copy contents of `src` into `dest` | +| `assign!!(dest, i, val)` | Assign `val` to position `i` in `dest` | +| `cholesky!!(F, A)` | Compute the Cholesky factorization of `A` | +| `ldiv!!(Y, F, B)` | Solve `F \ B`, storing in `Y` | +| `transpose!!(dest, src)` | Transpose `src` into `dest` | + +## Cache System + +Each combination of problem type and algorithm defines two allocation functions: + +- **`alloc_sol(prob, alg, T)`**: Allocates the output structure that will hold the solution (state trajectory, observations, noise, covariances, etc.). Returns a named tuple or struct of pre-allocated arrays. +- **`alloc_cache(prob, alg, T)`**: Allocates scratch workspace needed during the solve (temporary vectors, matrices for intermediate computations, etc.). Returns a named tuple or struct of pre-allocated buffers. + +The solver loop fully overwrites all solution and cache arrays on each call, so no explicit zeroing step is needed between calls. For Enzyme AD, shadow copies should be zero-initialized via `Enzyme.make_zero(deepcopy(...))` at creation time (see [Enzyme AD](@ref)). + +## Adding a New Problem Type + +To add a new problem type to DifferenceEquations.jl, you need to implement the following methods: + +| Method | Signature | Description | +|--------|-----------|-------------| +| `_noise_matrix(prob)` | `prob → Matrix` | Return the noise input matrix (e.g., `B` for linear models) | +| `_init_model_state!!(prob, cache)` | `prob, cache → cache` | Initialize any model-specific cache state before the time loop | +| `_transition!!(x_next, x, w, prob, cache, t)` | `x_next, x, w, prob, cache, t → x_next` | Compute the next state given current state `x` and noise `w` at time `t` | +| `_observation!!(y, x, prob, cache, t)` | `y, x, prob, cache, t → y` | Compute the observation given state `x` at time `t` | +| `alloc_sol(prob, alg, T)` | `prob, alg, Int → NamedTuple` | Allocate the solution output arrays for `T` time steps | +| `alloc_cache(prob, alg, T)` | `prob, alg, Int → NamedTuple` | Allocate scratch workspace for `T` time steps | + +All transition and observation methods should follow the bang-bang convention: mutate the first argument if it is mutable, otherwise return a new value. This ensures compatibility with both standard arrays and StaticArrays. diff --git a/docs/src/advanced/static_arrays.md b/docs/src/advanced/static_arrays.md new file mode 100644 index 0000000..dfb8dca --- /dev/null +++ b/docs/src/advanced/static_arrays.md @@ -0,0 +1,66 @@ +# StaticArrays + +For small state-space models (typically 2--5 states), using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) can significantly improve performance by eliminating heap allocations and enabling compiler optimizations such as loop unrolling. + +## Example + +```@example static +using DifferenceEquations, StaticArrays, LinearAlgebra +A = @SMatrix [0.95 6.2; 0.0 0.2] +B = @SMatrix [0.0; 0.01;;] +C = @SMatrix [0.09 0.67; 1.00 0.00] +u0 = @SVector zeros(2) +prob = LinearStateSpaceProblem(A, B, u0, (0, 10); C) +sol = solve(prob) +sol.u[end] +``` + +## When to Use + +StaticArrays are most beneficial when: + +- **State dimensions are small**: The performance advantage is greatest for matrices up to roughly 10x10. Beyond that, the compile-time overhead and code size can outweigh the benefits. +- **Sizes are known at compile time**: StaticArrays encode their dimensions as type parameters, so the sizes must be fixed constants rather than runtime values. +- **You need stack allocation**: StaticArrays are stored on the stack rather than the heap, eliminating GC pressure entirely for small models. + +For larger models or models where dimensions vary at runtime, use standard `Array` types instead. + +## Bang-Bang Operators + +The package internally uses "bang-bang" operators (e.g., `mul!!`, `copyto!!`, `assign!!`) that handle both mutable and immutable arrays transparently. When you pass `SMatrix` and `SVector` types, these operators return new immutable values rather than mutating in place. When you pass standard `Matrix` and `Vector` types, they mutate in place and return the result. This means you do not need to change any solver code to switch between static and dynamic arrays -- simply change the array types in your problem definition. + +See [Internals](@ref) for the full list of bang-bang operators and their behavior. + +## Supported Problem Types + +StaticArrays work with all problem types: + +- [`LinearStateSpaceProblem`](@ref) with `DirectIteration` — simulation and log-likelihood +- [`LinearStateSpaceProblem`](@ref) with `KalmanFilter` — filtering, smoothing, and log-likelihood +- [`QuadraticStateSpaceProblem`](@ref) and [`PrunedQuadraticStateSpaceProblem`](@ref) — second-order perturbation models +- [`StateSpaceProblem`](@ref) — generic callbacks using bang-bang operators + +## Kalman Filter Example + +```@example static +using DifferenceEquations, StaticArrays, LinearAlgebra, Random +Random.seed!(42) +A = SMatrix{2,2}(0.8*I(2)) +B = SMatrix{2,1}([0.1; 0.05]) +C = SMatrix{2,2}(1.0*I(2)) +R = SMatrix{2,2}(0.01*I(2)) +mu0 = @SVector zeros(2) +Sig0 = SMatrix{2,2}(1.0*I(2)) +y = [SVector{2}(randn(2)) for _ in 1:10] +prob = LinearStateSpaceProblem(A, B, mu0, (0, 10); C, + u0_prior_mean=mu0, u0_prior_var=Sig0, + observables_noise=R, observables=y) +sol = solve(prob, KalmanFilter()) +sol.logpdf +``` + +## AD Performance Note + +Enzyme reverse-mode AD benefits significantly from StaticArrays at small dimensions (N ≤ 5), with 5--7x speedups over mutable arrays for both the primal and AD passes. + +ForwardDiff with StaticArrays does not improve AD performance for this package. The overhead of constructing `SMatrix{N,N,Dual{...}}` temporaries outweighs the benefit. See [ForwardDiff AD](@ref) for details. diff --git a/docs/src/basics/faq.md b/docs/src/basics/faq.md new file mode 100644 index 0000000..3b6a404 --- /dev/null +++ b/docs/src/basics/faq.md @@ -0,0 +1,33 @@ +# FAQ + +## When should I use the Kalman filter vs. joint likelihood? + +- **Kalman filter**: Use for linear Gaussian models when you want the marginal likelihood, integrating out the latent noise sequence. This is the standard approach for maximum likelihood estimation (MLE) of parameters. +- **Joint likelihood**: Use when conditioning on a specific noise realization. This is useful for Bayesian methods where the noise sequence is sampled as part of inference (e.g., particle MCMC, HMC on latent variables). + +## Why does Enzyme require all arguments to be Duplicated? + +Enzyme tracks activity at the struct level. When constructing a `LinearStateSpaceProblem`, all matrix arguments (e.g., `A`, `B`, `C`) flow into a single struct. If any argument is active (i.e., being differentiated), Enzyme needs shadow copies for all arguments in the struct. Passing some arguments as `Const` while others are `Duplicated` triggers an `EnzymeRuntimeActivityError`. The solution is to mark all arguments as `Duplicated`. + +## What is the observables timing convention? + +The `tspan` `(0, T)` produces `T+1` states: ``u_0, u_1, \ldots, u_T``. Observations ``z_n`` correspond to state ``u_n``. The `observables` keyword expects `T` vectors corresponding to ``z_1, z_2, \ldots, z_T`` (skipping ``z_0``). So when passing simulated data, use `sol.z[2:end]`. + +## What does `observables_noise` represent? + +The `observables_noise` keyword specifies the observation noise **covariance matrix** (entries are variances and covariances, not standard deviations). It must be an `AbstractMatrix` — use `Diagonal([σ₁², σ₂², …])` for diagonal noise or a full `Matrix`/`Symmetric(H * H')` for correlated noise. + +Its behavior depends on context: + +- **During simulation** (when `observables` is not provided): used to generate synthetic measurement noise added to the clean observations `sol.z`. +- **During likelihood computation** (when `observables` is provided): used as the observation noise covariance in the log-likelihood calculation. + +## How do I differentiate with respect to only some parameters? + +- **Enzyme**: All arguments to the likelihood function must be marked `Duplicated` (see [Enzyme AD](@ref)). However, you only need to *read* the shadow of the parameter you care about — the other shadows are computed but can be discarded. There is no performance cost to ignoring unused shadows. +- **ForwardDiff**: Only the parameter passed through `ForwardDiff.gradient`'s vector argument is differentiated. Other parameters are captured as constants in the closure, so no derivatives are computed for them. + +## How do I choose between QuadraticStateSpaceProblem and PrunedQuadraticStateSpaceProblem? + +- **`PrunedQuadraticStateSpaceProblem`**: Use for second-order perturbation solutions of DSGE models. The pruning prevents explosive dynamics by applying the quadratic term to a separate linear-part state rather than the full nonlinear state. +- **`QuadraticStateSpaceProblem`**: Use if you specifically need the unpruned quadratic form (e.g., for comparison or when the system is known to be stable). diff --git a/docs/src/basics/problem_types.md b/docs/src/basics/problem_types.md new file mode 100644 index 0000000..d596d40 --- /dev/null +++ b/docs/src/basics/problem_types.md @@ -0,0 +1,87 @@ +# Problem Types + +DifferenceEquations.jl provides a hierarchy of problem types for defining discrete-time state-space models. All concrete problem types inherit from `AbstractStateSpaceProblem` and share a common interface for specifying dynamics, observations, and noise. + +## Abstract Type + +```@docs +AbstractStateSpaceProblem +``` + +## LinearStateSpaceProblem + +```@docs +LinearStateSpaceProblem +``` + +## QuadraticStateSpaceProblem + +```@docs +QuadraticStateSpaceProblem +``` + +## PrunedQuadraticStateSpaceProblem + +```@docs +PrunedQuadraticStateSpaceProblem +``` + +## StateSpaceProblem + +```@docs +StateSpaceProblem +``` + +## Common Keyword Arguments + +The following keywords are shared by all problem constructors: + +| Keyword | Description | Default | +|---------|-------------|---------| +| `observables_noise` | Observation noise covariance matrix (`AbstractMatrix`, e.g. `Diagonal(d)` or `Symmetric(H * H')`) | `nothing` | +| `observables` | Observed data as `Vector{Vector{T}}` | `nothing` | +| `noise` | Fixed noise as `Vector{Vector{T}}` | `nothing` (drawn randomly) | +| `syms` | State variable names as a `Tuple` of `Symbol`s, e.g. `(:x, :y)` | `nothing` | +| `obs_syms` | Observation variable names as a `Tuple` of `Symbol`s | `nothing` | + +### Linear-only keywords + +These are accepted only by [`LinearStateSpaceProblem`](@ref): + +| Keyword | Description | Default | +|---------|-------------|---------| +| `C` | Observation matrix | `nothing` | +| `u0_prior_mean` | Prior mean for Kalman filtering | `nothing` | +| `u0_prior_var` | Prior covariance for Kalman filtering | `nothing` | + +### Quadratic-only keywords + +[`QuadraticStateSpaceProblem`](@ref) and [`PrunedQuadraticStateSpaceProblem`](@ref) accept `C_0`, `C_1`, `C_2` instead of `C`. + +### Generic-only keywords + +[`StateSpaceProblem`](@ref) requires the additional positional/keyword arguments `n_shocks` and `n_obs` to specify dimensions. + +### Dual role of `observables_noise` + +The `observables_noise` keyword has a dual role: +- **During simulation** (when `observables` is not provided): observation noise with this covariance is added to the simulated observations `sol.z`. +- **During likelihood computation** (when `observables` is provided): it defines the observation noise covariance used in the log-likelihood calculation. + +!!! note + + `observables_noise` must be an `AbstractMatrix`. For diagonal noise, use `Diagonal([σ₁², σ₂², …])` where the entries are **variances** (not standard deviations). For a general covariance, use a full `Matrix` or `Symmetric(H * H')`. + +## Remaking Problems + +Use `remake` to create a modified copy of a problem, changing specific fields while keeping everything else. This is useful for parameter sweeps and optimization loops. + +```@example remake_example +using DifferenceEquations, LinearAlgebra +A = [0.95 6.2; 0.0 0.2] +B = [0.0; 0.01;;] +prob = LinearStateSpaceProblem(A, B, zeros(2), (0, 5)) +prob2 = remake(prob; u0 = [0.1, 0.2]) +sol2 = solve(prob2) +sol2.u[1] # new initial condition +``` diff --git a/docs/src/basics/solutions.md b/docs/src/basics/solutions.md new file mode 100644 index 0000000..cd55135 --- /dev/null +++ b/docs/src/basics/solutions.md @@ -0,0 +1,60 @@ +# Solutions + +```@docs +StateSpaceSolution +``` + +## Fields + +| Field | Type | Description | +|-------|------|-------------| +| `u` | `Vector{Vector{T}}` | State trajectory | +| `t` | Range | Time values | +| `z` | `Vector{Vector{T}}` or `nothing` | Observations | +| `W` | `Vector{Vector{T}}` or `nothing` | Noise sequence (DirectIteration only) | +| `P` | `Vector{Matrix{T}}` or `nothing` | Posterior covariances (KalmanFilter only) | +| `logpdf` | `Real` | Log-likelihood (0.0 if no observables; may be a `Dual` number under ForwardDiff) | +| `retcode` | `ReturnCode.T` | `ReturnCode.Success` (errors are thrown as exceptions, not encoded in the return code) | +| `prob` | Problem | Original problem | +| `alg` | Algorithm | Algorithm used | + +## Symbolic Indexing + +If `syms` or `obs_syms` were provided when constructing the problem, the solution supports symbolic indexing: + +```julia +prob = LinearStateSpaceProblem(A, B, u0, (0, 10); C, syms=(:x, :y), obs_syms=(:obs1, :obs2)) +sol = solve(prob) + +# Access state variables by name +sol[:x] # vector of :x values across all time steps +sol[:obs1] # vector of :obs1 observations across all time steps +``` + +## Standard Indexing + +Solutions support standard Julia indexing to access states at specific time steps: + +```julia +sol = solve(prob) + +sol[1] # state at t=0 (initial condition) +sol[end] # state at the final time step +sol.u[3] # state at the third time index +sol.z[2] # observation at the second time index (if C was provided) +``` + +## DataFrame Conversion + +The state trajectory can be converted to a DataFrame. Column names come from `syms` if provided. Note that only the state variables (not observations) appear in the DataFrame. + +```@example solutions_df +using DifferenceEquations, LinearAlgebra, DataFrames +A = [0.95 6.2; 0.0 0.2] +B = [0.0; 0.01;;] +C = [0.09 0.67; 1.00 0.00] +prob = LinearStateSpaceProblem(A, B, zeros(2), (0, 5); C, + syms = (:capital, :productivity), obs_syms = (:output, :investment)) +sol = solve(prob) +DataFrame(sol) +``` diff --git a/docs/src/basics/solvers.md b/docs/src/basics/solvers.md new file mode 100644 index 0000000..9b2fd17 --- /dev/null +++ b/docs/src/basics/solvers.md @@ -0,0 +1,67 @@ +# Solvers + +Solving a state-space problem is as simple as calling `solve(prob)`, which automatically selects an appropriate algorithm. You can also pass an algorithm explicitly via `solve(prob, alg)`. + +```@docs +DirectIteration +``` + +```@docs +KalmanFilter +``` + +```@docs +ConditionalLikelihood +``` + +## Default Algorithm Selection + +When no algorithm is specified, `solve(prob)` selects the algorithm based on the problem type and its fields: + +- **`DirectIteration`** is the default for all problem types. It simulates the state-space model forward in time, generating states and observations directly. If `observables` are provided, it computes the joint log-likelihood of the observed data given the noise sequence. + +- **`KalmanFilter`** is auto-selected for `LinearStateSpaceProblem` when all of the following conditions hold: + - `u0_prior_var` is an `AbstractMatrix` (prior covariance is specified) + - `noise` is `nothing` (noise is not fixed) + - `observables` is an `AbstractVector` (observed data is provided) + - `observables_noise` is an `AbstractMatrix` (observation noise covariance is specified) + - `A`, `B`, and `C` are all `AbstractMatrix` (not `nothing`) + + The Kalman filter computes the filtered state estimates and the marginal log-likelihood of the observations, integrating over the unknown noise sequence. + +- **`ConditionalLikelihood`** is never auto-selected. You must pass it explicitly via `solve(prob, ConditionalLikelihood())`. Use it for fully-observed state-space models (AR, VAR, nonlinear) where the state is directly observed and you want the prediction error decomposition log-likelihood. Works with all problem types. + +!!! warning + + If any of the KalmanFilter conditions are not met, `DirectIteration` is silently selected instead. For example, forgetting to pass `C` or `u0_prior_var` will produce a `DirectIteration` solve with `logpdf = 0.0` rather than the expected Kalman filter result. + +## `save_everystep` Keyword + +All algorithms support `save_everystep=false`, which stores only the initial and final states instead of the full trajectory: + +```julia +sol = solve(prob; save_everystep=false) # 2-element sol.u +sol = solve(prob, ConditionalLikelihood(); save_everystep=false) +sol = solve(prob, KalmanFilter(); save_everystep=false) +``` + +When `save_everystep=false`: +- `sol.u` contains `[u_initial, u_final]` (2 entries instead of T+1) +- `sol.z` contains `[z_initial, z_final]` (if observations are present) +- `sol.P` contains `[P_initial, P_final]` (KalmanFilter only) +- `sol.logpdf` is **identical** — computed on the fly, not from stored trajectory + +This is useful when you only need the final state or the log-likelihood (e.g., in optimization loops). It dramatically reduces memory allocation, which benefits ForwardDiff gradient computation: + +| Scenario | Typical speedup | Allocation reduction | +|----------|----------------|---------------------| +| ForwardDiff + StaticArrays (KF, N=5) | **7x** | 4,288 → 175 | +| ForwardDiff + StaticArrays (CL, N=5) | **3.4x** | 805 → 190 | +| ForwardDiff + mutable (KF, N=30) | **1.5x** | 342k → 8k | + +The workspace API also supports it: + +```julia +ws = init(prob, alg; save_everystep=false) +sol = solve!(ws) # reads save_everystep from workspace +``` diff --git a/docs/src/basics/workspace.md b/docs/src/basics/workspace.md new file mode 100644 index 0000000..e16784b --- /dev/null +++ b/docs/src/basics/workspace.md @@ -0,0 +1,81 @@ +# Workspace API + +The workspace API provides a pre-allocated, reusable solving pattern via `init` and `solve!`. This avoids repeated memory allocation when solving the same type of problem many times, and is required for compatibility with Enzyme.jl reverse-mode AD. + +```@docs +StateSpaceWorkspace +``` + +## Creating and Using a Workspace + +```@docs +DifferenceEquations.init +DifferenceEquations.solve! +``` + +## Basic Usage + +```@example workspace +using DifferenceEquations, LinearAlgebra, Random +A = [0.95 6.2; 0.0 0.2] +B = [0.0; 0.01;;] +C = [0.09 0.67; 1.00 0.00] +u0 = zeros(2) +prob = LinearStateSpaceProblem(A, B, u0, (0, 5); C) +ws = init(prob, DirectIteration()) +sol = solve!(ws) +sol.u[end] +``` + +## Cache Reuse + +Calling `solve!(ws)` again on the same workspace reuses all previously allocated buffers. The solver fully overwrites all output arrays on each call, so no manual reset is needed between calls. This makes the workspace pattern ideal for tight loops: + +```julia +ws = init(prob, DirectIteration()) +for i in 1:1000 + sol = solve!(ws) + # process sol... +end +``` + +You can also change the problem between calls for parameter sweeps using `remake`: + +```julia +ws = init(prob, DirectIteration()) +for a11 in [0.9, 0.95, 1.0] + ws.prob = remake(ws.prob; A = [a11 6.2; 0.0 0.2]) + sol = solve!(ws) + # process sol.logpdf... +end +``` + +## Endpoints-Only Mode (`save_everystep=false`) + +Pass `save_everystep=false` to `init` to allocate minimal 2-element buffers. The solver stores only the initial and final states, while still correctly accumulating `logpdf`: + +```@example workspace +ws_ep = init(prob, DirectIteration(); save_everystep=false) +sol_ep = solve!(ws_ep) +length(sol_ep.u) # 2: [u_initial, u_final] +``` + +This is especially useful for ForwardDiff gradient computation, where reducing the number of dual-number allocations from O(T) to O(1) gives significant speedups (up to 7x with StaticArrays): + +```julia +# ForwardDiff benefits from save_everystep=false +function neg_loglik(params) + prob = make_problem(params) + return -solve(prob, ConditionalLikelihood(); save_everystep=false).logpdf +end +ForwardDiff.gradient(neg_loglik, params0) +``` + +## When to Use + +The workspace API is useful in the following scenarios: + +- **Enzyme AD**: Enzyme requires pre-allocated buffers passed as `Duplicated` arguments. The workspace pattern via `init`/`solve!` is the recommended way to use Enzyme with DifferenceEquations.jl. See [Enzyme AD](@ref) for details. +- **Repeated solves in optimization loops**: When solving the same problem structure many times (e.g., during parameter estimation), the workspace avoids allocating new arrays on every iteration. +- **ForwardDiff with `save_everystep=false`**: Combining the workspace API with endpoints-only mode minimizes dual-number allocations, giving the best ForwardDiff performance. +- **Performance-critical code**: Eliminating allocations reduces GC pressure and improves performance, especially for small to medium-sized problems. diff --git a/docs/src/examples/general_state_space_examples.md b/docs/src/examples/general_state_space_examples.md deleted file mode 100644 index 056d3ab..0000000 --- a/docs/src/examples/general_state_space_examples.md +++ /dev/null @@ -1,21 +0,0 @@ -# General State Space Examples - -!!! note - - This is a placeholder for future support for general nonlinear state-space problems. The basic implementation is a relatively simple variation on the linear version, where you call back into AD for the `f,g,h` calls in the `rrule` definition. Because of the mixture of AD calls and rules, it may make sense to wait for `Enzyme.jl` to be ready. - -A future feature, if anyone is interested in writing it, is full support for - -```math -u_{n+1} = f(u_n,p,t_n) + g(u_n,p,t_n) w_{n+1} -``` - -for some functions $f$ and $g$, where $w_{n+1}$ are IID random shocks to the evolution equation. The $p$ is a vector of potentially differentiable parameters. - -In addition, there is an optional observation equation - -```math -z_n = h(u_n, p, t_n) + v_n -``` - -This could involve both the simulation and the calculation of the joint likelihood conditional on the noise, as in the other examples. diff --git a/docs/src/examples/linear_state_space_examples.md b/docs/src/examples/linear_state_space_examples.md deleted file mode 100644 index a267da8..0000000 --- a/docs/src/examples/linear_state_space_examples.md +++ /dev/null @@ -1,356 +0,0 @@ -# Linear State Space Examples - -This tutorial describes the support for linear and linear gaussian state space models. - -At this point, the package only supports linear time-invariant models without a separate `p` vector. The canonical form of the linear model is - -```math -u_{n+1} = A u_n + B w_{n+1} -``` - -with - -```math -z_n = C u_n + v_n -``` - -and optionally $v_n \sim N(0, D)$ and $w_{n+1} \sim N(0,I)$. If you pass noise into the solver, it no longer needs to be Gaussian. More generally, support could be added for $u_{n+1} = A(p,n) u_n + B(p,n) w_{n+1}$ where $p$ is a vector of differentiable parameters, and the $A$ and $B$ are potentially matrix-free operators. - -## Simulating a Linear (and Time-Invariant) State Space Model - -Creating a `LinearStateSpaceProblem` and simulating it for a simple, linear equation. - -```@example 1 -using DifferenceEquations, LinearAlgebra, Distributions, Random, Plots, DataFrames, Zygote -A = [0.95 6.2; - 0.0 0.2] -B = [0.0; 0.01;;] # matrix -C = [0.09 0.67; - 1.00 0.00] -D = [0.1, 0.1] # diagonal observation noise -u0 = zeros(2) -T = 10 - -prob = LinearStateSpaceProblem(A, B, u0, (0, T); C, observables_noise = D, syms = [:a, :b]) -sol = solve(prob) -``` - -The `u` vector of the simulated solution can be plotted using the standard recipes, including the use of the optional `syms`. -See the [SciML docs](https://diffeq.sciml.ai/latest/basics/plot/) for more options. - -```@example 1 -plot(sol) -``` - -By default, the solution provides an interface to access the simulated `u` via `sol.u`, - -```@example 1 -sol.u[2] -``` - -Or to get the first element of the last step - -```@example 1 -sol.u[end][1] #first element of last step -``` - -Finally, to extract the full vector - -```@example 1 -@show sol[:, 2]; # whole second vector -``` - -The results for all of `sol.u` can be loaded in a dataframe, where the column names will be the (optionally) provided symbols. - -```@example 1 -df = DataFrame(sol) -``` - -Other results, such as the simulated noise and observables, can be extracted from the solution - -```@example 1 -sol.z # observables -``` - -```@example 1 -sol.W # Simulated Noise -``` - -We can also solve the model by passing in fixed noise, which will be useful for joint likelihoods. First, let's extract the noise from the previous solution, then rerun the simulation but with a different initial value - -```@example 1 -noise = sol.W -u0_2 = [0.1, 0.0] -prob2 = LinearStateSpaceProblem( - A, B, u0_2, (0, T); C, observables_noise = D, syms = [:a, :b], noise) -sol2 = solve(prob2) -plot(sol2) -``` - -To construct an IRF we can take the model and perturb just the first element of the noise, - -```@example 1 -function irf(A, B, C, T = 20) - noise = Matrix([1.0; zeros(T - 1)]') - problem = LinearStateSpaceProblem(A, B, zeros(2), (0, T); C, noise, syms = [:a, :b]) - return solve(problem) -end -plot(irf(A, B, C)) -``` - -Let's find the 2nd observable at the end of the IRF. - -```@example 1 -function last_observable_irf(A, B, C) - sol = irf(A, B, C) - return sol.z[end][2] # return 2nd argument of last observable -end -last_observable_irf(A, B, C) -``` - -But everything in this package is differentiable. Let's differentiate the observable of the IRF with respect to all the parameters using `Zygote.jl`, - -```@example 1 -gradient(last_observable_irf, A, B, C) # calculates gradient wrt all arguments -``` - -Gradients of other model elements (e.g. `.u`) are also possible. With this in mind, let's find the gradient of the mean of the 1st element of the IRF of the solution with respect to a particular noise vector. - -```@example 1 -function mean_u_1(A, B, C, noise, u0, T) - problem = LinearStateSpaceProblem(A, B, u0, (0, T); noise, syms = [:a, :b]) - sol = solve(problem) - u = sol.u # see issue #75 workaround - # can have nontrivial functions and even non-mutating loops - return mean(u[i][1] for i in 1:T) -end -u0 = [0.0, 0.0] -noise = sol.W # from simulation above -mean_u_1(A, B, C, noise, u0, T) -# dropping a few arguments from derivative -gradient((noise, u0) -> mean_u_1(A, B, C, noise, u0, T), noise, u0) -``` - -## Simulating Ensembles and Fixing Noise - -If you pass in a distribution for the initial condition, it will draw an initial condition. Below, we will simulate from a deterministic evolution equation, without any observation noise. - -```@example 1 -using Distributions, DiffEqBase -u0 = MvNormal([1.0 0.1; 0.1 1.0]) # mean zero initial conditions -prob = LinearStateSpaceProblem(A, nothing, u0, (0, T); C) -sol = solve(prob) -plot(sol) -``` - -With this, we can simulate an ensemble of solutions from different initial conditions (and we will turn back on the noise). The `EnsembleSummary` calculates a set of quantiles by default. - -```@example 1 -T = 10 -trajectories = 50 -prob = LinearStateSpaceProblem(A, B, u0, (0, T); C) -sol = solve(EnsembleProblem(prob), DirectIteration(), EnsembleThreads(); trajectories) -summ = EnsembleSummary(sol) #calculate summarize statistics from the -plot(summ) # shows quantiles by default -``` - -## Observables and Marginal Likelihood using a Kalman Filter - -If you provide `observables` and provide a distribution for the `observables_noise` then the model can provide a calculation of the likelihood. - -The simplest case is if you use a gaussian prior and have gaussian observation noise. First, let's simulate some data with included observation noise. If passing in a matrix or vector, the `observables_noise` argument is intended to be the cholesky of the covariance matrix. At this point, only diagonal observation noise is allowed. - -```@example 1 -u0 = MvNormal([1.0 0.1; 0.1 1.0]) # draw from mean zero initial conditions -T = 10 -prob = LinearStateSpaceProblem(A, B, u0, (0, T); C, observables_noise = D, syms = [:a, :b]) -sol = solve(prob) -sol.z # simulated observables with observation noise -``` - -Next, we will find the log likelihood of these simulated observables using `u0` as a prior and with the true parameters. - -The new arguments we pass to the problem creation are `u0_prior_variance, u0_prior_mean,` and `observables`. The `u0` is ignored for the filtering problem, but must match the size. The `KalmanFilter()` argument to the `solve` is unnecessary since it can be selected automatically given the priors and observables. - -!!! note - - The timing convention is such that `observables` are expected to match the predictions starting at the second time period. As the likelihood of the first element `u0` comes from a prior, the `observables` start at the next element, and hence the observables and noise sequences should be 1 less than the tspan. - -```@example 1 -observables = hcat(sol.z...) # Observables required to be matrix. Issue #55 -observables = observables[:, 2:end] # see note above on likelihood and timing -noise = copy(sol.W) # save for later -u0_prior_mean = [0.0, 0.0] -# use covariance of distribution we drew from -u0_prior_var = cov(u0) - -prob = LinearStateSpaceProblem(A, B, u0, (0, size(observables, 2)); C, observables, - observables_noise = D, syms = [:a, :b], u0_prior_var, u0_prior_mean) -sol = solve(prob, KalmanFilter()) -# plot(sol) The `u` is the sequence of posterior means. -sol.logpdf -``` - -Hence, the `logpdf` provides the log likelihood marginalizing out the latent noise variables. - -As before, we can differentiate the kalman filter itself. - -```@example 1 -function kalman_likelihood(A, B, C, D, u0_prior_mean, u0_prior_var, observables) - prob = LinearStateSpaceProblem(A, B, u0, (0, size(observables, 2)); C, observables, - observables_noise = D, syms = [:a, :b], u0_prior_var, u0_prior_mean) - return solve(prob).logpdf -end -kalman_likelihood(A, B, C, D, u0_prior_mean, u0_prior_var, observables) -# Find the gradient wrt the A, B, C and priors variance. -gradient( - (A, - B, - C, - u0_prior_var) -> kalman_likelihood( - A, B, C, D, u0_prior_mean, u0_prior_var, observables), - A, - B, - C, - u0_prior_var) -``` - -!!! note - - Some gradients, such as those for `observables`, have not been implemented, so test carefully. This is a general theme with gradients and `Zygote.jl` in general. Your best friend in this process is the spectacular [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) package. See `test_rrule` usage in the [linear unit tests](https://github.com/SciML/DifferenceEquations.jl/blob/main/test/linear_gradients.jl). - -## Joint Likelihood with Noise - -A key application of these methods is to find the joint likelihood of the latent variables (i.e., the `noise`) and the model definition. - -The actual calculation of the likelihood is trivial in that case, and just requires iteration of the linear system while accumulating the likelihood given the observation noise. - -Crucially, the differentiability with respect to the high-dimensional noise vector enables gradient-based sampling and estimation methods that would otherwise be infeasible. - -```@example 1 -function joint_likelihood(A, B, C, D, u0, noise, observables) - prob = LinearStateSpaceProblem( - A, B, u0, (0, size(observables, 2)); C, observables, observables_noise = D, noise) - return solve(prob).logpdf -end -u0 = [0.0, 0.0] -joint_likelihood(A, B, C, D, u0, noise, observables) -``` - -And as always, this can be differentiated with respect to the state-space matrices and the noise. Choosing a few parameters, - -```@example 1 -gradient( - (A, u0, noise) -> joint_likelihood(A, B, C, D, u0, noise, observables), A, u0, noise) -``` - -## Composition of State Space Models and AD - -While the above gradients have been with respect to the full state space objects `A, B`, etc. those themselves could be generated through a separate procedure and the whole object differentiated. For example, let's repeat the above examples where we generate the `A` matrix from some sort of deep parameters. - -First, we will generate some observations with a `generate_model` proxy, which could be replaced with something more complicated but still differentiable - -```@example 1 -function generate_model(β) - A = [β 6.2; - 0.0 0.2] - B = Matrix([0.0 0.001]') # [0.0; 0.001;;] gives a zygote bug - C = [0.09 0.67; - 1.00 0.00] - D = [0.01, 0.01] - return (; A, B, C, D) -end - -function simulate_model(β, u0; T = 200) - mod = generate_model(β) - prob = LinearStateSpaceProblem( - mod.A, mod.B, u0, (0, T); mod.C, observables_noise = mod.D) - sol = solve(prob) # simulates - observables = hcat(sol.z...) - observables = observables[:, 2:end] # see note above on likelihood and timing - return observables, sol.W -end - -# Fix a "pseudo-true" and generate noise and observables -β = 0.95 -u0 = [0.0, 0.0] -observables, noise = simulate_model(β, u0) -``` - -Next, we will evaluate the marginal likelihood using the kalman filter for a particular `β` value, - -```@example 1 -function kalman_model_likelihood(β, u0_prior_mean, u0_prior_var, observables) - mod = generate_model(β) # generate model from structural parameters - prob = LinearStateSpaceProblem( - mod.A, mod.B, u0, (0, size(observables, 2)); mod.C, observables, - observables_noise = mod.D, u0_prior_var, u0_prior_mean) - return solve(prob).logpdf -end -u0_prior_mean = [0.0, 0.0] -u0_prior_var = [1e-10 0.0; - 0.0 1e-10] # starting with degenerate prior -kalman_model_likelihood(β, u0_prior_mean, u0_prior_var, observables) -``` - -Given the observation error, we would not expect the pseudo-true to exactly maximize the log likelihood. To show this, we can optimize it using the Optim package, specifically using a gradient-based optimization routine - -```@example 1 -using Optimization, OptimizationOptimJL -# Create a function to minimize only of β and use Zygote based gradients -function kalman_objective(β, p) - -kalman_model_likelihood(β, u0_prior_mean, u0_prior_var, observables) -end -kalman_objective(0.95, nothing) -gradient(β -> kalman_objective(β, nothing), β) # Verifying it can be differentiated - -optf = OptimizationFunction(kalman_objective, Optimization.AutoZygote()) -β0 = [0.91] # start off of the pseudotrue -optprob = OptimizationProblem(optf, β0) -optsol = solve(optprob, LBFGS()) # reverse-mode AD is overkill here -``` - -In this way, this package composes with others such as [DifferentiableStateSpaceModels.jl](https://github.com/HighDimensionalEconLab/DifferentiableStateSpaceModels.jl) which takes a set of structural parameters and an expected difference equation to generate a state-space model. - -Similarly, we can find the joint likelihood for a particular `β` value and noise. Here we will add in prior. Some form of prior or regularization is generally necessary for these sorts of nonlinear models. - -```@example 1 -function joint_model_posterior(β, u0, noise, observables, noise_prior, β_prior) - mod = generate_model(β) # generate model from structural parameters - prob = LinearStateSpaceProblem(mod.A, mod.B, u0, (0, size(observables, 2)); mod.C, - observables, observables_noise = mod.D, noise) - return solve(prob).logpdf + sum(logpdf.(noise_prior, noise)) + logpdf(β_prior, β) # posterior -end -u0 = [0.0, 0.0] -noise_prior = Normal(0.0, 1.0) -β_prior = Normal(β, 0.03) # prior local to the true value -joint_model_posterior(β, u0, noise, observables, noise_prior, β_prior) -``` - -Which we can turn into a differentiable objective by adding in a prior on the noise - -```@example 1 -function joint_model_objective(x, p) - -joint_model_posterior(x[1], u0, Matrix(x[2:end]'), observables, noise_prior, β_prior) -end # extract noise and parameeter from vector -x0 = vcat([0.95], noise[1, :]) # starting at the true noise -joint_model_objective(x0, nothing) -gradient(x -> joint_model_objective(x, nothing), x0) # Verifying it can be differentiated - -# optimize -optf = OptimizationFunction(joint_model_objective, Optimization.AutoZygote()) -optprob = OptimizationProblem(optf, x0) -optsol = solve(optprob, LBFGS()) -``` - -This "solves" the problem relatively quickly, despite the high-dimensionality. However, from a statistics perspective note that this last optimization process does not do especially well in recovering the pseudotrue if you increase the prior variance on the `β` parameter. Maximizing the posterior is usually the wrong thing to do in high-dimensions because the mode is not a typical set. - -## Caveats on Gradients and Performance - -A few notes on performance and gradients: - - 1. As this is using reverse-mode AD it will be efficient for fairly large systems as long as the ultimate value of your differentiable program. With a little extra work and unit tests, it could support structured matrices/etc. as well. - 2. Getting to much higher scales, where the `A,B,C,D` are so large that matrix-free operators are necessary, is feasible but will require generalizing those to LinearOperators. This would be reasonably easy for joint likelihood and feasible but possible for the Kalman filter. - 3. At this point, there is no support for forward-mode auto-differentiation. For smaller systems with a kalman filter, this should dominate the alternatives, and efficient forward-mode AD rules for the kalman filter exist (see the supplementary materials in the [Differentiable State Space Models](https://github.com/HighDimensionalEconLab/DifferentiableStateSpaceModels.jl) paper). However, it would be a significant amount of work to add end-to-end support and fulfill standard SciML interfaces, and perhaps waiting for [Enzyme](https://enzyme.mit.edu/julia/) or similar AD systems that provide both forward/reverse/mixed mode makes sense. - 4. Forward-mode AD is likely inappropriate for the joint-likelihood based models, since the dimensionality of the noise is always large. - 5. The gradient rules are written using [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) so in theory they will work with any supporting AD. In practice, though, Zygote is the most tested, and other systems have inconsistent support for Julia at this time. diff --git a/docs/src/examples/quadratic_state_space_examples.md b/docs/src/examples/quadratic_state_space_examples.md deleted file mode 100644 index 522bc1c..0000000 --- a/docs/src/examples/quadratic_state_space_examples.md +++ /dev/null @@ -1,115 +0,0 @@ -# Quadratic State Space Examples - -Second-order state-space models here have pruning as in [Andreasen, Fernandez-Villaverde, and Rubio-Ramirez (2017)](https://www.sas.upenn.edu/%7Ejesusfv/Pruning.pdf). - -At this point, the package only supports linear time-invariant models without a separate `p` vector. The canonical form is - -```math -u_{n+1} = A_0 + A_1 u_n + u_n^{\top} A_2 u_n + B w_{n+1} -``` - -with - -```math -z_n = C_0 + C_1 u_n + u_n^{\top} C_2 u_n + v_n -``` - -and optionally $v_n \sim N(0, D)$ and $w_{n+1} \sim N(0,I)$. If you pass noise into the solver, it no longer needs to be Gaussian. - -!!! note - - Quadratic state-space models do not have the full feature coverage as the linear models. In particular, the auto-differentiation rules are only currently implemented for the `logpdf` required for estimation, and the simulation doesn't have much flexibility on which model elements can be missing. - -## Simulating a Quadratic (and Time-Invariant) State Space Model - -Creating a `QuadraticStateSpaceModel` is similar to the Linear version described previously. - -```@example 2 -using DifferenceEquations, LinearAlgebra, Distributions, Random, Plots, DataFrames, Zygote, - DiffEqBase -A_0 = [-7.824904812740593e-5, 0.0] -A_1 = [0.95 6.2; - 0.0 0.2] -A_2 = cat([-0.0002 0.0334; 0.0 0.0], - [0.034 3.129; 0.0 0.0]; dims = 3) -B = [0.0; 0.01;;] # matrix -C_0 = [7.8e-5, 0.0] -C_1 = [0.09 0.67; - 1.00 0.00] -C_2 = cat([-0.00019 0.0026; 0.0 0.0], - [0.0026 0.313; 0.0 0.0]; dims = 3) -D = [0.01, 0.01] # diagonal observation noise -u0 = zeros(2) -T = 30 - -prob = QuadraticStateSpaceProblem( - A_0, A_1, A_2, B, u0, (0, T); C_0, C_1, C_2, observables_noise = D, syms = [:a, :b]) -sol = solve(prob) -``` - -As in the linear case, this model can be simulated and plotted - -```@example 2 -plot(sol) -``` - -And the observables and noise can be stored - -```@example 2 -observables = hcat(sol.z...) # Observables required to be matrix. Issue #55 -observables = observables[:, 2:end] # see note above on likelihood and timing -noise = sol.W -``` - -Ensembles work as well, - -```@example 2 -trajectories = 50 -u0_dist = MvNormal([1.0 0.1; 0.1 1.0]) # mean zero initial conditions -prob = QuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0_dist, (0, T); C_0, C_1, - C_2, observables_noise = D, syms = [:a, :b]) -ens_sol = solve(EnsembleProblem(prob), DirectIteration(), EnsembleThreads(); trajectories) -summ = EnsembleSummary(ens_sol) # calculate summarize statistics such as quantiles -plot(summ) -``` - -## Joint Likelihood with Noise - -To calculate the likelihood, the Kalman Filter is no longer applicable. However, we can still calculate the joint likelihood as we did in the linear examples. Using the simulated observables and noise, - -```@example 2 -function joint_likelihood_quad(A_0, A_1, A_2, B, C_0, C_1, C_2, D, u0, noise, observables) - prob = QuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, (0, size(observables, 2)); C_0, - C_1, C_2, observables, observables_noise = D, noise) - return solve(prob).logpdf -end -u0 = [0.0, 0.0] -joint_likelihood_quad(A_0, A_1, A_2, B, C_0, C_1, C_2, D, u0, noise, observables) -``` - -Which, in turn, can itself be differentiated. - -```@example 2 -gradient( - (A_0, - A_1, - A_2, - B, - C_0, - C_1, - C_2, - noise) -> joint_likelihood_quad( - A_0, A_1, A_2, B, C_0, C_1, C_2, D, u0, noise, observables), - A_0, - A_1, - A_2, - B, - C_0, - C_1, - C_2, - noise) -``` - -Note that this is not only calculating the gradient of the likelihood with respect to the underlying canonical representations for the quadratic state space form, but also the entire noise vector. - -As in the linear case, this likelihood calculation can be nested such that a separate differentiable function could generate the quadratic state space model, and the gradients could be over a smaller set of structural parameters. diff --git a/docs/src/getting_started.md b/docs/src/getting_started.md new file mode 100644 index 0000000..9a4269a --- /dev/null +++ b/docs/src/getting_started.md @@ -0,0 +1,91 @@ +# Getting Started + +This tutorial walks through the core workflow of DifferenceEquations.jl: defining a linear state-space model, simulating it, and computing likelihoods. + +## Creating a Linear State Space Model + +A [`LinearStateSpaceProblem`](@ref) represents a linear time-invariant state-space model: + +```math +u_{n+1} = A\, u_n + B\, w_{n+1}, \qquad z_n = C\, u_n + v_n +``` + +Define the model primitives, create a problem, and solve: + +```@example getting_started +using DifferenceEquations, LinearAlgebra, Random +A = [0.95 6.2; 0.0 0.2] +B = [0.0; 0.01;;] # 2×1 Matrix (Julia's ;; creates a column matrix) +C = [0.09 0.67; 1.00 0.00] +u0 = zeros(2) +T = 10 + +prob = LinearStateSpaceProblem(A, B, u0, (0, T); C) +sol = solve(prob) +sol.u[end] +``` + +## Computing Likelihood + +To compute log-likelihoods, provide `observables` (a `Vector{Vector}` of length `T`) and `observables_noise` (the observation noise covariance matrix — e.g., `Diagonal(d)` for diagonal noise or `Symmetric(H * H')` for a general covariance). + +!!! note "Timing convention" + + Observations correspond to ``z_1, z_2, \ldots, z_T`` -- that is, the states *after* the initial condition. Pass `T` observation vectors for a `tspan` of `(0, T)`. + +First, simulate some data to use as observables: + +```@example getting_started +Random.seed!(123) +D = Diagonal([0.1, 0.1]) # diagonal observation noise covariance matrix +prob_sim = LinearStateSpaceProblem(A, B, u0, (0, T); C, observables_noise = D) +sol_sim = solve(prob_sim) + +# Extract observations at times 1..T (skip the initial condition at t=0) +observables = sol_sim.z[2:end] +length(observables) # should be T +``` + +Compute the **joint** log-likelihood given fixed noise using [`DirectIteration`](@ref): + +```@example getting_started +prob_lik = LinearStateSpaceProblem(A, B, u0, (0, length(observables)); C, + observables = observables, + observables_noise = D, + noise = sol_sim.W) +sol_lik = solve(prob_lik) +sol_lik.logpdf # joint log-likelihood +``` + +For the **marginal** log-likelihood (integrating out the latent noise), use a [`KalmanFilter`](@ref) by additionally providing a Gaussian prior on `u0`: + +```@example getting_started +prob_kf = LinearStateSpaceProblem(A, B, u0, (0, length(observables)); C, + observables = observables, + observables_noise = D, + u0_prior_mean = zeros(2), + u0_prior_var = Matrix(1.0I, 2, 2)) +sol_kf = solve(prob_kf) # KalmanFilter is auto-selected +sol_kf.logpdf # marginal log-likelihood +``` + +## DataFrame Conversion + +Convert the state trajectory to a `DataFrame` for analysis. Column names come from `syms` if provided: + +```@example getting_started +using DataFrames +prob_df = LinearStateSpaceProblem(A, B, u0, (0, T); C, + syms = (:capital, :productivity)) +sol_df = solve(prob_df) +DataFrame(sol_df) +``` + +## Next Steps + + - [Linear Simulation](@ref) -- detailed simulation examples, symbolic indexing, fixed noise, and ensemble runs. + - [Likelihood & Kalman Filter](@ref) -- marginal and joint likelihood, gradient-based estimation. + - [Quadratic Models](@ref) -- second-order perturbation models. + - [Generic Callbacks](@ref) -- user-defined nonlinear transition and observation functions. + - [Workspace API](@ref) -- allocation-free repeated solves for performance-critical loops. + - [Enzyme AD](@ref) -- differentiating through solvers and filters with Enzyme.jl. diff --git a/docs/src/index.md b/docs/src/index.md index ddb551a..2e0a715 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,12 +1,15 @@ # DifferenceEquations.jl -This package simulates for **initial value problems** for deterministic and stochastic difference equations, with or without a separate observation equation. In addition, the package provides likelihoods for some standard filters for estimating state-space models. +DifferenceEquations.jl solves initial value problems for deterministic and stochastic difference equations, with differentiable solvers and filters. Automatic differentiation is powered by [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) (reverse and forward mode). The package is part of the [SciML](https://sciml.ai/) ecosystem. -Relative to existing solvers, this package is intended to provide **differentiable solvers and filters**. For example, you can simulate a linear gaussian state space model and find the gradient of the solution with respect to the model primitives. Similarly, the likelihood for of Kalman Filter can itself be differentiated with respect to the underlying model primitives. This makes the package especially amenable to estimation and calibration, where the entire solution blocks become auto-differentiable. +## Features -!!! note - - Boundary value problems and difference-algebraic equations are not in scope. See [DifferentiableStateSpaceModels.jl](https://github.com/HighDimensionalEconLab/DifferentiableStateSpaceModels.jl) for experimental support for perturbation solutions and DSGEs. + - **Linear, quadratic, and generic state-space models** -- [`LinearStateSpaceProblem`](@ref), [`QuadraticStateSpaceProblem`](@ref), [`PrunedQuadraticStateSpaceProblem`](@ref), and [`StateSpaceProblem`](@ref) with user-defined callbacks. + - **Kalman filter** for computing the marginal log-likelihood of linear Gaussian models via [`KalmanFilter`](@ref). + - **Differentiable via [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) and [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)** -- Enzyme reverse/forward mode for all problem sizes; ForwardDiff as a lightweight alternative for small models (N ≤ 5). + - **StaticArrays support** for small models where heap allocations dominate runtime. + - **Workspace API** -- [`StateSpaceWorkspace`](@ref) with `init` / `solve!` for allocation-free repeated solves (useful inside AD and tight loops). + - **SciML ecosystem integration** -- `EnsembleProblem` for Monte Carlo, plot recipes, `DataFrame` conversion, symbolic indexing, and `remake`. ## Installation @@ -17,84 +20,48 @@ using Pkg Pkg.add("DifferenceEquations") ``` -For additional functionality, you may want to add `Plots, DiffEqBase`. If you want to explore differentiable filters, you can install `Zygote` - -## Mathematical Specification of a Discrete Problem - -For comparison, see the specifications of the deterministic [Discrete Problem](https://diffeq.sciml.ai/latest/types/discrete_types/#Mathematical-Specification-of-a-Discrete-Problem) (albeit with a small difference in timing conventions) and the [SDE Problem](https://diffeq.sciml.ai/latest/types/sde_types/). Other introductions can be found by [checking out DiffEqTutorials.jl](https://github.com/JuliaDiffEq/DiffEqTutorials.jl). +## Quick Example -The general class of problems intended to be supported in this package is to take an initial condition, $u_0$, and an evolution equation - -```math -u_{n+1} = f(u_n,p,t_n) + g(u_n,p,t_n) w_{n+1} +```@example index +using DifferenceEquations, LinearAlgebra +A = [0.95 0.1; 0.0 0.2] +B = [0.0; 0.01;;] +u0 = zeros(2) +T = 10 +prob = LinearStateSpaceProblem(A, B, u0, (0, T)) +sol = solve(prob) +sol.u[end] # final state ``` -for some functions $f$ and $g$, and where $w_{n+1}$ are IID random shocks to the evolution equation. The $p$ is a vector of potentially differentiable parameters. - -In addition, there is an optional observation equation - -```math -z_n = h(u_n, p, t_n) + v_n -``` - -where $v_n$ is noisy observation error and the size of $z_n$ may be different from $u_n$. - -A few notes on the structure: +## Mathematical Background - 1. Frequently, the $g$ provides the covariance structure, so a reasonable default is $w_{n+1} \sim N(0,I)$, and $v_n \sim N(0, D)$ is a common observation error for some covariance matrix $D$. - 2. If $f,g,h$ are all linear, the shocks are both gaussian, and the prior on the latent space is gaussian, then this is a linear gaussian state-space model. Kalman filters can be used to calculate marginal likelihoods, and simulations can be executed with very little overhead. - 3. ``t_n`` is the current time at which the map is applied, where ``t_n = t_0 + n*dt`` (with `dt=1` being the default). - 4. If $f, g, h$ are not functions of time, then it is a time-invariant state-space model. - -## Likelihood and Filtering Calculations - -Certain `solve` algorithms will run a filter on the unobservable `u` states and compare to the `observables` if provided. In that case, it might do so (1) with unobservable $w_n$ noise; or (2) conditioning on a particular sequence of $w_{n+1}$ shocks, where the likelihood depends on the unknown observational error $v_n$. - -If an algorithm is given for the filtering, then the return type of `solve` will have access to a `logpdf` for the log likelihood. In addition, the solution will provide information on the sequence of posteriors (and smoothed values, if required). - -### Joint Likelihood - -In the case of a joint-likelihood where the `noise` (i.e. $w_n$) is given, it is not a hidden markov model and the log likelihood simply accumulates the likelihood of each observation. The timing is such that given a $u_0$ which is fixed (and often added to the likelihood separately), and observables $z \equiv \{z_1, \ldots z_N\}$ and noise $w \equiv \{w_1, \ldots w_N\}$ then, +The general class of discrete-time state-space models supported by this package takes an initial condition ``u_0`` and an evolution equation ```math -\mathcal{L}(z, u_0, w) = \sum_{n=1}^N \log P\left(v_n, t_n, w_n\right) +u_{n+1} = f(u_n, p, t_n) + g(u_n, p, t_n)\, w_{n+1} ``` -where +for transition function ``f``, noise coefficient ``g``, and IID noise shocks ``w_{n+1}``. The parameter vector ``p`` is potentially differentiable. -```math -v_n = z_n - h(u_n, p, t_n)\\ -u_{n+1} = f(u_n,p,t_n) + g(u_n,p,t_n) w_{n+1} -``` - -The density, $P$, is in the case of the typical Gaussian errors, it would be +An optional observation equation relates the latent state to measured data: ```math -z_n - h(u_n, p, t_n) \sim N(0, D) = P +z_n = h(u_n, p, t_n) + v_n ``` -Ultimately, IID Gaussian observation noise is not required, and though the package currently only supports gaussian observation noise with a diagonal covariance matrix, it could be adapted without significant changes. - -### Linear Filtering for the Marginal Likelihood +where ``v_n`` is observation noise and ``z_n`` may have a different dimension from ``u_n``. -When the system is linear and the prior is gaussian, there is an exact likelihood for the marginal likelihood using the [Kalman Filter](https://en.wikipedia.org/wiki/Kalman_filter#Marginal_likelihood). Unlike the previous example, this is a marginal likelihood and not conditional on the noise, $w$. See the [Kalman Filter Likelihood](https://en.wikipedia.org/wiki/Kalman_filter#Marginal_likelihood) for more details. +### Specializations -## Current Status + - **Linear**: ``f(u) = A\,u``, ``g(u) = B``, ``h(u) = C\,u``. Solved by [`DirectIteration`](@ref) or [`KalmanFilter`](@ref). See [`LinearStateSpaceProblem`](@ref). + - **Quadratic**: Adds second-order terms ``u^\top A_2\, u`` to both transition and observation. Useful for pruned perturbation solutions of DSGE models. See [`QuadraticStateSpaceProblem`](@ref) and [`PrunedQuadraticStateSpaceProblem`](@ref). + - **Generic**: User-supplied `transition` and `observation` callbacks. See [`StateSpaceProblem`](@ref). -At this point, the package does not cover all the variations on these features. In particular, +When the system is linear, the shocks are Gaussian, and a Gaussian prior is provided, the [`KalmanFilter`](@ref) computes the exact marginal log-likelihood. For all other cases, [`DirectIteration`](@ref) iterates the state forward and (optionally) accumulates a joint log-likelihood. - 1. It only supports linear and quadratic $f, g, h$ functions. General $f,g$ simulation are relatively easy to add, but full SciML compliance would require experience with those APIs. The custom rrule for those is also a straightforward variation on the existing linear version. - 2. It only supports time-invariant functions. - 3. There is limited support for non-Gaussian $w_n$ and $v_n$ processes. - 4. It does not support linear or quadratic functions parameterized by the $p$ vector for differentiation. - 5. There are some hard-coded types that prevent it from working with fully generic arrays. - 6. It does not support in-place vs. out-of-place, nor support static arrays, nor matrix-free linear operators. - 7. While many functions in the SciML framework are working, support is incomplete. - 8. There is no complete coverage of gradients for the solution for all parameter inputs/etc. - 9. The package does not support non-gaussian observation noise and is inconsistent with SciML noise process data structures. -10. Many cleanup steps are necessary for full SciML compliance (e.g., enable passing in vectors-of-vectors or noise/observations, standard SciML dispatching). +!!! note -To help contribute to filling in these features, see the [issues](https://github.com/SciML/DifferenceEquations.jl/issues). + Boundary value problems and difference-algebraic equations are not in scope. ## Contributing @@ -104,7 +71,7 @@ To help contribute to filling in these features, see the [issues](https://github - See the [SciML Style Guide](https://github.com/SciML/SciMLStyle) for common coding practices and other style decisions. - There are a few community forums: - + + The #diffeq-bridged and #sciml-bridged channels in the [Julia Slack](https://julialang.org/slack/) + The #diffeq-bridged and #sciml-bridged channels in the diff --git a/docs/src/tutorials/conditional_likelihood.md b/docs/src/tutorials/conditional_likelihood.md new file mode 100644 index 0000000..12879ef --- /dev/null +++ b/docs/src/tutorials/conditional_likelihood.md @@ -0,0 +1,195 @@ +# Conditional Likelihood + +The [`ConditionalLikelihood`](@ref) algorithm computes the prediction error +decomposition log-likelihood for fully-observed state-space models. At each +time step, it predicts the next observation from the *observed* current state +(not the model-predicted state), and accumulates the Gaussian log-likelihood +of the innovation (prediction error). + +This is the standard approach for maximum likelihood estimation of AR(1), +VAR(1), nonlinear DSGE, and other models where the state is directly observed. + +## When to Use Each Algorithm + +| Algorithm | Use Case | +|-----------|----------| +| [`DirectIteration`](@ref) | Simulation, or joint likelihood given a fixed noise sequence | +| [`KalmanFilter`](@ref) | Marginal likelihood for linear models with latent (unobserved) noise | +| [`ConditionalLikelihood`](@ref) | MLE for fully-observed models (AR, VAR, nonlinear) | + +## Mathematical Formulation + +Given a state-space model with transition ``x_{t+1} = f(x_t, w_t)`` and +observation ``z_t = g(x_t)``, the conditional log-likelihood is: + +```math +\log L = \sum_{t=1}^{T} \left[ -\frac{1}{2} \left( M \log(2\pi) + \log|R| + \nu_t^\top R^{-1} \nu_t \right) \right] +``` + +where ``\nu_t = y_t - g(f(y_{t-1}, w_t))`` is the innovation (prediction error), +``R`` is the observation noise covariance, and ``M`` is the observation dimension. + +The key difference from `DirectIteration` is that at each step the state is +**clamped to the observation**: the prediction uses ``f(y_{t-1}, \ldots)`` +rather than ``f(f(\ldots, u_0), \ldots)``. + +## AR(1) Example + +```@example cond_lik +using DifferenceEquations, LinearAlgebra, Random + +rho_true = 0.8 +sigma_e = 0.5 +T = 200 + +# Simulate AR(1) data using the package +Random.seed!(42) +prob_sim = LinearStateSpaceProblem( + fill(rho_true, 1, 1), fill(sigma_e, 1, 1), [0.0], (0, T)) +sol_sim = solve(prob_sim) +y = sol_sim.u[2:end] # observed states y_1, ..., y_T + +# Compute conditional log-likelihood +prob = LinearStateSpaceProblem( + fill(rho_true, 1, 1), nothing, [0.0], (0, T); + observables = y, + observables_noise = Diagonal([sigma_e^2]), +) +sol = solve(prob, ConditionalLikelihood()) +sol.logpdf +``` + +## VAR(1) Example + +The same approach works for multivariate models: + +```@example cond_lik +A = [0.8 0.1; -0.1 0.7] +B = [0.5 0.0; 0.0 0.5] +T_var = 100 + +# Simulate VAR(1) data +Random.seed!(123) +prob_sim_var = LinearStateSpaceProblem(A, B, zeros(2), (0, T_var)) +sol_sim_var = solve(prob_sim_var) +y_var = sol_sim_var.u[2:end] + +prob_var = LinearStateSpaceProblem( + A, nothing, zeros(2), (0, T_var); + observables = y_var, + observables_noise = Diagonal([0.25, 0.25]), +) +sol_var = solve(prob_var, ConditionalLikelihood()) +sol_var.logpdf +``` + +## Nonlinear Example with StateSpaceProblem + +`ConditionalLikelihood` works with all problem types, including user-defined +nonlinear callbacks via [`StateSpaceProblem`](@ref). + +Here we estimate a nonlinear AR(1): ``x_{t+1} = \rho x_t + \alpha x_t^2 + e_t``. +We first simulate data using a generic `StateSpaceProblem`, then compute the +conditional likelihood. + +```@example cond_lik +rho_nl = 0.8 +alpha_nl = 0.05 +sigma_nl = 0.3 +T_nl = 100 + +# Nonlinear transition (supports both mutable and immutable arrays) +function nl_transition!!(x_next, x, w, p, t) + (; rho, alpha) = p + val = rho * x[1] + alpha * x[1]^2 + if ismutable(x_next) + x_next[1] = val + w[1] + return x_next + else + return typeof(x)(val + w[1]) + end +end + +p_nl = (; rho = rho_nl, alpha = alpha_nl) + +# Simulate nonlinear data +Random.seed!(99) +prob_sim_nl = StateSpaceProblem( + nl_transition!!, nothing, [0.0], (0, T_nl), p_nl; + n_shocks = 1, n_obs = 0) +sol_sim_nl = solve(prob_sim_nl) +y_nl = sol_sim_nl.u[2:end] + +# Conditional likelihood (no noise in prediction, noise only in obs) +function nl_transition_noiseless!!(x_next, x, w, p, t) + (; rho, alpha) = p + val = rho * x[1] + alpha * x[1]^2 + if ismutable(x_next) + x_next[1] = val + return x_next + else + return typeof(x)(val) + end +end + +prob_nl = StateSpaceProblem( + nl_transition_noiseless!!, nothing, [0.0], (0, T_nl), p_nl; + n_shocks = 0, n_obs = 0, + observables = y_nl, + observables_noise = Diagonal([sigma_nl^2]), +) +sol_nl = solve(prob_nl, ConditionalLikelihood()) +sol_nl.logpdf +``` + +## Maximum Likelihood Estimation + +`ConditionalLikelihood` is fully differentiable with ForwardDiff.jl, making it +straightforward to use with gradient-based optimization for MLE. + +Use `save_everystep=false` when you only need the log-likelihood (not the +full trajectory). This reduces allocations and speeds up ForwardDiff +gradient computation — up to 7x faster with StaticArrays. + +```@example cond_lik +using ForwardDiff + +# Negative log-likelihood as a function of rho +function neg_loglik(rho_vec) + T_el = eltype(rho_vec) + A_opt = fill(rho_vec[1], 1, 1) + prob_opt = LinearStateSpaceProblem( + A_opt, nothing, [zero(T_el)], (0, length(y)); + observables = y, + observables_noise = Diagonal([T_el(sigma_e^2)]), + ) + return -solve(prob_opt, ConditionalLikelihood(); save_everystep = false).logpdf +end + +# Gradient at the true value +grad = ForwardDiff.gradient(neg_loglik, [rho_true]) +``` + +The gradient is near zero at the true parameter value, confirming the MLE +is correctly identified. For full optimization, use Optimization.jl with +`AutoForwardDiff()` and an optimizer like `LBFGS()`. + +## Using the Workspace + +For repeated solves (e.g., inside an optimizer), use the `init`/`solve!` +pattern to avoid repeated memory allocation: + +```@example cond_lik +ws = init(prob, ConditionalLikelihood()) +sol_ws = solve!(ws) +sol_ws.logpdf +``` + +With `save_everystep=false`, the workspace allocates only 2-element +buffers: + +```@example cond_lik +ws_ep = init(prob, ConditionalLikelihood(); save_everystep = false) +sol_ep = solve!(ws_ep) +length(sol_ep.u) # 2: [u_initial, u_final] +``` diff --git a/docs/src/tutorials/generic_callbacks.md b/docs/src/tutorials/generic_callbacks.md new file mode 100644 index 0000000..adea759 --- /dev/null +++ b/docs/src/tutorials/generic_callbacks.md @@ -0,0 +1,146 @@ +# Generic Callbacks + +The [`StateSpaceProblem`](@ref) type provides a fully generic interface for +discrete-time state-space models. Instead of specifying matrices, you supply +callback functions for the state transition and observation equations. This is +useful for nonlinear models, time-varying dynamics, or any structure that does not +fit the linear or quadratic templates. + +## Callback Signatures + +The two callbacks follow the "bang-bang" convention used throughout SciML: for +mutable arrays, mutate the output buffer in place and return it; for immutable +arrays (e.g., `SVector`), ignore the buffer and return a new value. + +**Transition function:** `f!!(x_next, x, w, p, t) -> x_next` +- `x_next`: pre-allocated output buffer (mutate in place for mutable arrays) +- `x`: current state +- `w`: noise shock at this step (or `nothing` if `n_shocks = 0`) +- `p`: parameters passed to the problem +- `t`: integer time index (0-based) + +**Observation function:** `g!!(y, x, p, t) -> y` +- `y`: pre-allocated output buffer +- `x`: current state +- `p`: parameters +- `t`: integer time index (0-based) + +Pass `nothing` for the observation function if no observations are needed. + +## Example: Linear Model via Callbacks + +We can reproduce the behavior of [`LinearStateSpaceProblem`](@ref) using generic +callbacks. This verifies the interface and demonstrates the pattern. + +```@example generic +using DifferenceEquations, LinearAlgebra, Random + +A = [0.95 6.2; 0.0 0.2] +B = [0.0; 0.01;;] +C = [0.09 0.67; 1.00 0.00] + +linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next +end +linear_g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y +end +p = (; A, B, C) +u0 = zeros(2) +T = 10 + +prob = StateSpaceProblem(linear_f!!, linear_g!!, u0, (0, T), p; + n_shocks = 1, n_obs = 2, syms = (:a, :b)) +sol = solve(prob) +``` + +The solution has the same structure as the linear case: + +```@example generic +sol.u # state trajectory, Vector{Vector} +``` + +```@example generic +sol.z # observations, Vector{Vector} +``` + +We can verify this matches the matrix-based formulation: + +```@example generic +Random.seed!(123) +sol_generic = solve(StateSpaceProblem(linear_f!!, linear_g!!, u0, (0, T), p; + n_shocks = 1, n_obs = 2)) + +Random.seed!(123) +sol_linear = solve(LinearStateSpaceProblem(A, B, u0, (0, T); C)) + +sol_generic.u ≈ sol_linear.u +``` + +## Example: Nonlinear Growth Model + +`StateSpaceProblem` handles arbitrary nonlinear dynamics. Here is a discrete-time logistic growth model with process noise, demonstrating that the generic callback interface works for any transition function: + +```@example generic +# Nonlinear transition: logistic growth with stochastic shocks +logistic_f!! = (x_next, x, w, p, t) -> begin + x_next[1] = p.r * x[1] * (1.0 - x[1] / p.K) + p.sigma * w[1] + return x_next +end + +# Observation: noisy measurement of population +logistic_g!! = (y, x, p, t) -> begin + y[1] = x[1] + return y +end + +p_logistic = (; r = 1.5, K = 100.0, sigma = 2.0) +u0_logistic = [50.0] + +prob_logistic = StateSpaceProblem(logistic_f!!, logistic_g!!, u0_logistic, (0, 50), p_logistic; + n_shocks = 1, n_obs = 1, syms = (:population,), obs_syms = (:measured_pop,)) +sol_logistic = solve(prob_logistic) +``` + +## Parametric Models and `remake` + +The `p` argument holds all model parameters. When exploring different parameter +values, use `remake` to create a new problem without reallocating everything. + +```@example generic +new_u0 = [0.1, 0.2] +new_p = (; A = A * 0.99, B, C) + +prob2 = remake(prob; u0 = new_u0, p = new_p) +sol2 = solve(prob2) +sol2.u[1] # new initial condition +``` + +The `remake` function preserves all keyword arguments (noise, observables, syms, etc.) +from the original problem. + +## Symbolic Indexing + +`StateSpaceProblem` supports the same symbolic indexing as the linear problem types. +Pass `syms` for state variable names and `obs_syms` for observation names. + +```@example generic +D = Diagonal([0.1, 0.1]) +noise = sol.W # reuse noise from earlier + +prob_sym = StateSpaceProblem(linear_f!!, linear_g!!, u0, (0, T), p; + n_shocks = 1, n_obs = 2, + syms = (:capital, :productivity), + obs_syms = (:output, :consumption), + observables_noise = D, noise) +sol_sym = solve(prob_sym) + +sol_sym[:capital] # state time series by name +``` + +```@example generic +sol_sym[:output] # observation time series by name +``` diff --git a/docs/src/tutorials/linear_likelihood.md b/docs/src/tutorials/linear_likelihood.md new file mode 100644 index 0000000..4e357c2 --- /dev/null +++ b/docs/src/tutorials/linear_likelihood.md @@ -0,0 +1,121 @@ +# Likelihood & Kalman Filter + +DifferenceEquations.jl supports two approaches to computing the log-likelihood of +observed data: + +- **Marginal likelihood** via the [`KalmanFilter`](@ref): the probability of the + observed data conditioned on the core model parameters (``A, B, C``, etc.) and the + initial condition prior, with the latent noise sequence analytically integrated out. + This is the standard approach for maximum likelihood estimation (MLE) of structural + parameters. +- **Joint likelihood** via [`DirectIteration`](@ref): the probability of the observed + data AND a specific noise realization, conditioned on the core parameters and initial + conditions. Requires fixing the noise sequence. Useful in Bayesian methods where the + noise is sampled as part of inference (e.g., particle MCMC, HMC on latent variables). + +Both approaches are fully differentiable with Enzyme.jl and ForwardDiff.jl. + +## Simulating Observations + +First, let us simulate a model with observation noise to produce synthetic data. + +```@example linear_lik +using DifferenceEquations, LinearAlgebra, Distributions, Random + +A = [0.95 6.2; 0.0 0.2] +B = [0.0; 0.01;;] +C = [0.09 0.67; 1.00 0.00] +D = Diagonal([0.1, 0.1]) +u0 = zeros(2) +T = 80 + +Random.seed!(42) +prob_sim = LinearStateSpaceProblem(A, B, MvNormal(zeros(2), I(2)), (0, T); + C, observables_noise = D) +sol_sim = solve(prob_sim) +sol_sim.z # simulated observations with noise (Vector{Vector}) +``` + +## Marginal Likelihood with the Kalman Filter + +The Kalman filter computes the marginal log-likelihood by integrating out the latent +noise sequence. It requires a Gaussian prior on the initial state (`u0_prior_mean`, +`u0_prior_var`) and Gaussian observation noise (`observables_noise`). + +!!! note "Timing convention" + + Observations correspond to ``z_1, z_2, \ldots, z_T`` (predictions starting from + the second state). When the simulation produces `T+1` observation vectors + (including ``z_0``), pass `sol.z[2:end]` as the observables. The length of + `observables` must equal the integer distance of `tspan`. + +```@example linear_lik +observables = sol_sim.z[2:end] # Vector{Vector}, length T + +u0_prior_mean = zeros(2) +u0_prior_var = Matrix(1.0 * I(2)) + +prob_kalman = LinearStateSpaceProblem(A, B, u0, (0, length(observables)); C, + observables_noise = D, observables, + u0_prior_mean, u0_prior_var) + +# KalmanFilter is auto-selected when priors + observables + noise covariance are given +sol_kalman = solve(prob_kalman) +sol_kalman.logpdf # marginal log-likelihood +``` + +The Kalman solution also provides filtered state estimates in `sol.u` and posterior +covariances in `sol.P`: + +```@example linear_lik +sol_kalman.u[end] # filtered mean at the final time step +``` + +```@example linear_lik +sol_kalman.P[end] # posterior covariance at the final time step +``` + +## Joint Likelihood with Fixed Noise + +When both `noise` and `observables` are provided, `DirectIteration` iterates the +state transition forward using the given noise and accumulates the joint +log-likelihood of the observations. + +```@example linear_lik +noise = sol_sim.W # realized noise from simulation (Vector{Vector}) + +prob_joint = LinearStateSpaceProblem(A, B, u0, (0, length(observables)); C, + observables_noise = D, observables, noise) +sol_joint = solve(prob_joint) +sol_joint.logpdf # joint log-likelihood conditioned on noise +``` + +## Composing Structural Models + +In practice, the state-space matrices ``A, B, C, D`` are often generated from deeper +structural (or "deep") parameters. The entire pipeline from structural parameters +to log-likelihood is differentiable. + +```@example linear_lik +function generate_model(beta) + A = [beta 6.2; 0.0 0.2] + B = [0.0; 0.001;;] + C = [0.09 0.67; 1.00 0.00] + D = Diagonal([0.01, 0.01]) + return (; A, B, C, D) +end + +function kalman_model_likelihood(beta, observables) + mod = generate_model(beta) + prob = LinearStateSpaceProblem(mod.A, mod.B, zeros(2), (0, length(observables)); + C = mod.C, observables, observables_noise = mod.D, + u0_prior_mean = zeros(2), u0_prior_var = Matrix(1.0 * I(2))) + return solve(prob).logpdf +end + +kalman_model_likelihood(0.95, observables) +``` + +## Next Steps + +To differentiate these likelihoods with Enzyme.jl and use them inside an optimization loop, see the [Enzyme AD](@ref) page. It covers the workspace-based `init`/`solve!` pattern required by Enzyme, gradient computation for both `DirectIteration` and `KalmanFilter`, and a full maximum-likelihood example with Optimization.jl. diff --git a/docs/src/tutorials/linear_simulation.md b/docs/src/tutorials/linear_simulation.md new file mode 100644 index 0000000..5eeb665 --- /dev/null +++ b/docs/src/tutorials/linear_simulation.md @@ -0,0 +1,180 @@ +# Linear Simulation + +This tutorial walks through simulating a linear time-invariant state-space model of the form + +```math +u_{n+1} = A \, u_n + B \, w_{n+1} +``` + +with observation equation + +```math +z_n = C \, u_n + v_n +``` + +where ``w_{n+1} \sim N(0, I)`` and optionally ``v_n \sim N(0, D)``. + +## Simulating a Linear State Space Model + +We begin by defining system matrices and creating a [`LinearStateSpaceProblem`](@ref). +Passing `C` enables the observation equation, `observables_noise` adds Gaussian +measurement noise to the simulated observations, and `syms` attaches symbolic names +to the state variables. + +```@example linear_sim +using DifferenceEquations, LinearAlgebra, Distributions, Random, Plots, DiffEqBase + +A = [0.95 6.2; 0.0 0.2] +B = [0.0; 0.01;;] +C = [0.09 0.67; 1.00 0.00] +D = Diagonal([0.1, 0.1]) +u0 = zeros(2) +T = 10 + +prob = LinearStateSpaceProblem(A, B, u0, (0, T); C, observables_noise = D, syms = (:a, :b)) +sol = solve(prob) +``` + +## Plotting + +The solution object integrates with Plots.jl recipes. When `syms` are provided, +the legend labels correspond to those names. + +```@example linear_sim +plot(sol) +``` + +## Accessing the Solution + +The state trajectory is stored in `sol.u` as a `Vector{Vector}`. Standard indexing +works on the solution object directly. + +```@example linear_sim +sol.u # full state trajectory, Vector{Vector} +``` + +Access a specific time step: + +```@example linear_sim +sol[2] # state at t=1 (second entry; sol[1] is the initial condition u₀) +``` + +Or a specific element of the last state: + +```@example linear_sim +sol[end][1] # first element of the final state +``` + +Observations and noise are also available: + +```@example linear_sim +sol.z # observed trajectory, Vector{Vector} +``` + +```@example linear_sim +sol.W # realized noise sequence, Vector{Vector} +``` + +## Symbolic Indexing + +When `syms` are provided, you can extract the full time series for a state variable +by name: + +```@example linear_sim +sol[:a] # time series for state variable :a +``` + +If `obs_syms` are also provided, observation variables can be accessed similarly: + +```@example linear_sim +prob_obs = LinearStateSpaceProblem(A, B, u0, (0, T); C, observables_noise = D, + syms = (:a, :b), obs_syms = (:output, :consumption)) +sol_obs = solve(prob_obs) +sol_obs[:output] # time series for observation :output +``` + +## Fixed Noise + +We can extract the noise from a previous simulation and use it to reproduce a +trajectory (possibly with different initial conditions). This is essential for +joint likelihood computations. + +```@example linear_sim +noise = sol.W # extract realized noise (Vector{Vector}) +u0_new = [0.1, 0.0] +prob_fixed = LinearStateSpaceProblem(A, B, u0_new, (0, T); C, observables_noise = D, + syms = (:a, :b), noise) +sol_fixed = solve(prob_fixed) +plot(sol_fixed) +``` + +## Impulse Response Functions + +An impulse response function (IRF) applies a one-time unit shock at the first period +and traces the system's response. We construct this by passing a fixed noise sequence +where only the first entry is nonzero. + +```@example linear_sim +function irf(A, B, C, T = 20) + noise = [[i == 1 ? 1.0 : 0.0] for i in 1:T] + problem = LinearStateSpaceProblem(A, B, zeros(2), (0, T); C, noise, syms = (:a, :b)) + return solve(problem) +end +plot(irf(A, B, C)) +``` + +## Deterministic Dynamics (`B = nothing`) + +When the model has no process noise, pass `B = nothing`. The solver will skip +noise generation entirely. No `sol.W` is produced. + +```@example linear_sim +prob_det = LinearStateSpaceProblem(A, nothing, [1.0, 0.5], (0, T); C, syms = (:a, :b)) +sol_det = solve(prob_det) +sol_det.W === nothing # no noise generated +``` + +```@example linear_sim +plot(sol_det) +``` + +## No Observation Equation (`C = nothing`) + +When you only need the state trajectory and don't require observations, +omit `C` (or pass `C = nothing`). No `sol.z` is produced. + +```@example linear_sim +prob_no_obs = LinearStateSpaceProblem(A, B, u0, (0, T); syms = (:a, :b)) +sol_no_obs = solve(prob_no_obs) +sol_no_obs.z === nothing # no observations +``` + +```@example linear_sim +plot(sol_no_obs) +``` + +## Random Initial Conditions + +Passing a `Distribution` for `u0` draws a random initial state at each solve. + +```@example linear_sim +u0_dist = MvNormal([1.0 0.1; 0.1 1.0]) # zero-mean Gaussian +prob_rand = LinearStateSpaceProblem(A, nothing, u0_dist, (0, T); C) +sol_rand = solve(prob_rand) +plot(sol_rand) +``` + +## Ensemble Simulations + +The SciML `EnsembleProblem` interface runs many independent simulations in parallel. +Each trajectory draws a fresh initial condition (when `u0` is a distribution) and/or +fresh noise. `EnsembleSummary` computes quantile bands across the ensemble. + +```@example linear_sim +trajectories = 50 +prob_ens = LinearStateSpaceProblem(A, B, u0_dist, (0, T); C) +ensemble_sol = solve(EnsembleProblem(prob_ens), DirectIteration(), EnsembleThreads(); + trajectories) +summ = EnsembleSummary(ensemble_sol) +plot(summ) +``` diff --git a/docs/src/tutorials/quadratic.md b/docs/src/tutorials/quadratic.md new file mode 100644 index 0000000..b231a5b --- /dev/null +++ b/docs/src/tutorials/quadratic.md @@ -0,0 +1,187 @@ +# Quadratic Models + +DifferenceEquations.jl supports second-order perturbation solutions through +[`QuadraticStateSpaceProblem`](@ref) and [`PrunedQuadraticStateSpaceProblem`](@ref). +These extend the linear state-space model with quadratic terms: + +```math +u_{n+1} = A_0 + A_1 \, u_n + u_n^\top A_2 \, u_n + B \, w_{n+1} +``` + +with observation equation + +```math +z_n = C_0 + C_1 \, u_n + u_n^\top C_2 \, u_n + v_n +``` + +This formulation follows Andreasen, Fernandez-Villaverde, and Rubio-Ramirez (2017), +"The Pruned State-Space System for Non-Linear DSGE Models: Theory and Empirical +Applications." + +## Simulating a Quadratic Model + +We define the quadratic coefficients. The tensors `A_2` and `C_2` are 3-dimensional arrays where `A_2[i, :, :]` gives +the matrix for the `i`-th element of the quadratic form. For a 2-state model, +`A_2` is a `2×2×2` array: the quadratic contribution to state `i` is +``u^\top A_2[i,:,:]\, u``. For example, `A_2[1,:,:]` is the 2×2 matrix whose +quadratic form gives the nonlinear correction to the first state element. + +```@example quad +using DifferenceEquations, LinearAlgebra, Random, Plots + +A_0 = [-7.824904812740593e-5, 0.0] +A_1 = [0.95 6.2; 0.0 0.2] +A_2 = cat([-0.0002 0.0334; 0.0 0.0], [0.034 3.129; 0.0 0.0]; dims = 3) +B = [0.0; 0.01;;] +C_0 = [7.8e-5, 0.0] +C_1 = [0.09 0.67; 1.00 0.00] +C_2 = cat([-0.00019 0.0026; 0.0 0.0], [0.0026 0.313; 0.0 0.0]; dims = 3) +D = Diagonal([0.01, 0.01]) +u0 = zeros(2) +T = 30 + +Random.seed!(42) +prob = QuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, (0, T); + C_0, C_1, C_2, observables_noise = D, syms = (:a, :b)) +sol = solve(prob) +``` + +The solution has the same structure as the linear case: `sol.u` holds the state +trajectory, `sol.z` holds observations, and `sol.W` holds the noise sequence -- all +as `Vector{Vector}`. + +## Plotting and Ensembles + +The standard plotting recipes work identically: + +```@example quad +plot(sol) +``` + +Ensemble simulations follow the same SciML interface: + +```@example quad +using DiffEqBase + +prob_ens = QuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, (0, T); + C_0, C_1, C_2, observables_noise = D, syms = (:a, :b)) +ensemble_sol = solve(EnsembleProblem(prob_ens), DirectIteration(), EnsembleThreads(); + trajectories = 50) +summ = EnsembleSummary(ensemble_sol) +plot(summ) +``` + +## Joint Likelihood + +When both `noise` and `observables` are provided, the solver computes the joint +log-likelihood conditioned on the noise realization. As with linear models, +observables correspond to ``z_1, \ldots, z_T``, so we pass `sol.z[2:end]`. + +```@example quad +observables = sol.z[2:end] # Vector{Vector}, length T +noise = sol.W # Vector{Vector}, length T + +prob_lik = QuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, (0, length(observables)); + C_0, C_1, C_2, observables_noise = D, observables, noise) +sol_lik = solve(prob_lik) +sol_lik.logpdf +``` + +## Pruned Quadratic Models + +The pruned formulation of Andreasen et al. (2017) prevents explosive dynamics in +higher-order perturbation solutions. Instead of applying the quadratic term to the +full state, it maintains a separate linear-part state ``u_f`` and applies the +quadratic form to that: + +```math +u_f^{n+1} = A_1 \, u_f^n + B \, w_{n+1} +``` +```math +u_{n+1} = A_0 + A_1 \, u_n + (u_f^n)^\top A_2 \, u_f^n + B \, w_{n+1} +``` + +The [`PrunedQuadraticStateSpaceProblem`](@ref) takes the same arguments as the +unpruned version: + +```@example quad +Random.seed!(42) +prob_pruned = PrunedQuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, (0, T); + C_0, C_1, C_2, observables_noise = D, syms = (:a, :b)) +sol_pruned = solve(prob_pruned) +``` + +With the same noise, we can compare the pruned and unpruned trajectories: + +```@example quad +Random.seed!(100) +noise_compare = [randn(1) for _ in 1:T] + +sol_unpruned = solve(QuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, (0, T); + C_0, C_1, C_2, noise = noise_compare)) +sol_pruned_compare = solve(PrunedQuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, (0, T); + C_0, C_1, C_2, noise = noise_compare)) + +plot( + plot([sol_unpruned.u[t][1] for t in eachindex(sol_unpruned.u)], label = "unpruned", + title = "State 1"), + plot([sol_pruned_compare.u[t][1] for t in eachindex(sol_pruned_compare.u)], + label = "pruned", title = "State 1 (pruned)"), + layout = (1, 2) +) +``` + +The pruned joint likelihood works the same way: + +```@example quad +obs_pruned = sol_pruned.z[2:end] +noise_pruned = sol_pruned.W + +prob_pruned_lik = PrunedQuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, + (0, length(obs_pruned)); + C_0, C_1, C_2, observables_noise = D, observables = obs_pruned, + noise = noise_pruned) +sol_pruned_lik = solve(prob_pruned_lik) +sol_pruned_lik.logpdf +``` + +## Differentiating with Enzyme + +The joint likelihood for quadratic models can be differentiated with Enzyme.jl, +using the workspace-based `init`/`solve!` pattern. + +!!! note "Enzyme.jl required" + + This example requires Enzyme.jl to be installed. The code is shown but not + executed during documentation build due to Enzyme's compilation overhead. + +```julia +using Enzyme + +function quad_joint_loglik(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, obs, D, + sol, cache)::Float64 + prob = QuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, (0, length(obs)); + C_0, C_1, C_2, observables_noise = D, observables = obs, noise) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + return solve!(ws).logpdf +end + +# Pre-allocate workspace +prob0 = QuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, (0, length(obs_pruned)); + C_0, C_1, C_2, observables_noise = D, observables = obs_pruned, noise = noise_pruned) +ws0 = init(prob0, DirectIteration()) + +# Compute gradient with respect to A_1 (all arguments Duplicated) +dA_1 = zero(A_1) +Enzyme.autodiff(Reverse, quad_joint_loglik, + Duplicated(copy(A_0), zero(A_0)), Duplicated(copy(A_1), dA_1), + Duplicated(copy(A_2), zero(A_2)), Duplicated(copy(B), zero(B)), + Duplicated(copy(C_0), zero(C_0)), Duplicated(copy(C_1), zero(C_1)), + Duplicated(copy(C_2), zero(C_2)), Duplicated(copy(u0), zero(u0)), + Duplicated(deepcopy(noise_pruned), [zeros(size(B, 2)) for _ in noise_pruned]), + Duplicated(deepcopy(obs_pruned), [zeros(length(C_0)) for _ in obs_pruned]), + Duplicated(copy(D), zero(D)), + Duplicated(deepcopy(ws0.output), Enzyme.make_zero(deepcopy(ws0.output))), + Duplicated(deepcopy(ws0.cache), Enzyme.make_zero(deepcopy(ws0.cache)))) +dA_1 # gradient of logpdf with respect to A_1 +``` diff --git a/src/DifferenceEquations.jl b/src/DifferenceEquations.jl index a67fb34..8868f85 100644 --- a/src/DifferenceEquations.jl +++ b/src/DifferenceEquations.jl @@ -1,31 +1,35 @@ module DifferenceEquations -using ChainRulesCore: ChainRulesCore, NoTangent, Tangent, ZeroTangent -using CommonSolve: CommonSolve, solve -using DiffEqBase: DiffEqBase, AbstractDEProblem, get_concrete_u0, get_concrete_p, isconcreteu0, +using CommonSolve: CommonSolve, solve, init, solve! +using ConcreteStructs: @concrete +using DiffEqBase: DiffEqBase, DEProblem, get_concrete_u0, get_concrete_p, isconcreteu0, promote_u0 -using Distributions: Distributions, Distribution, MvNormal, UnivariateDistribution, - ZeroMeanDiagNormal, logpdf -using LinearAlgebra: LinearAlgebra, Cholesky, Diagonal, NoPivot, Symmetric, cholesky!, - dot, ldiv!, lmul!, mul!, rmul!, transpose! -using PDMats: PDMats, PDMat +using LinearAlgebra: LinearAlgebra, Diagonal, NoPivot, Symmetric, cholesky, + cholesky!, dot, ldiv!, mul!, transpose! using SciMLBase: SciMLBase, @add_kwonly, NullParameters, promote_tspan, AbstractRODESolution, - ODEFunction, remake, ConstantInterpolation, build_solution -using SymbolicIndexingInterface: SymbolCache -using UnPack: UnPack, @unpack + ODEFunction, remake, ConstantInterpolation, build_solution, ReturnCode +using StaticArrays: StaticArrays, SVector, SMatrix, StaticMatrix, ismutable +using SymbolicIndexingInterface: SymbolicIndexingInterface, SymbolCache, variable_index +include("utilities_bangbang.jl") include("utilities.jl") include("problems/state_space_problems.jl") +include("problems/quadratic_state_space_problems.jl") include("solutions/state_space_solutions.jl") include("solve.jl") +include("caches.jl") +include("workspace.jl") include("algorithms/linear.jl") +include("algorithms/generic.jl") include("algorithms/quadratic.jl") include("precompilation.jl") # Exports -export AbstractStateSpaceProblem, LinearStateSpaceProblem, QuadraticStateSpaceProblem -export StateSpaceSolution, DirectIteration, KalmanFilter +export AbstractStateSpaceProblem, LinearStateSpaceProblem, StateSpaceProblem +export QuadraticStateSpaceProblem, PrunedQuadraticStateSpaceProblem +export StateSpaceSolution, DirectIteration, KalmanFilter, ConditionalLikelihood +export StateSpaceWorkspace -export solve +export solve, init, solve!, remake end # module diff --git a/src/algorithms/generic.jl b/src/algorithms/generic.jl new file mode 100644 index 0000000..b7d3930 --- /dev/null +++ b/src/algorithms/generic.jl @@ -0,0 +1,31 @@ +# Generic state-space model callbacks for DirectIteration solver +# The solver loop is in linear.jl — these methods define the model-specific behavior. + +# ============================================================================= +# NoiseSpec sentinel — provides size(_, 2) and eltype for get_concrete_noise +# ============================================================================= + +struct NoiseSpec{T} + n_shocks::Int +end +NoiseSpec(n_shocks::Int, ::Type{T}) where {T} = NoiseSpec{T}(n_shocks) +Base.size(ns::NoiseSpec, i::Int) = i == 2 ? ns.n_shocks : 1 +Base.eltype(::NoiseSpec{T}) where {T} = T + +# ============================================================================= +# Model interface methods for StateSpaceProblem +# ============================================================================= + +function _noise_matrix(prob::StateSpaceProblem) + return prob.n_shocks > 0 ? NoiseSpec(prob.n_shocks, eltype(prob.u0)) : nothing +end + +_init_model_state!!(::StateSpaceProblem, cache) = nothing + +@inline function _transition!!(x_next, x, w, prob::StateSpaceProblem, cache, t) + return prob.transition(x_next, x, w, prob.p, t - 2) # 0-based time +end + +@inline function _observation!!(y, x, prob::StateSpaceProblem, cache, t) + return prob.observation(y, x, prob.p, t - 1) # 0-based time +end diff --git a/src/algorithms/linear.jl b/src/algorithms/linear.jl index 5ebf387..68dcba7 100644 --- a/src/algorithms/linear.jl +++ b/src/algorithms/linear.jl @@ -1,481 +1,806 @@ -# See the utilities.jl for all of the "maybe" functions and other utilities. Those provide ways for the same algorithm to implement various permutations of the model definitions -# For example, B = nothing, noise = nothing, observables = nothing are all supported +# ============================================================================= +# Model interface for DirectIteration dispatch +# Each problem type defines these methods to plug into the generic solver loop. +# ============================================================================= -function DiffEqBase.__solve( - prob::LinearStateSpaceProblem{ - uType, uPriorMeanType, - uPriorVarType, tType, - P, NP, F, AType, BType, - CType, RType, ObsType, K, - }, - alg::DirectIteration, args...; - kwargs... - ) where { - uType, uPriorMeanType, uPriorVarType, tType, - P, NP, F, AType, - BType, CType, RType, - ObsType, K, - } +# --- Noise matrix extraction --- +_noise_matrix(prob::LinearStateSpaceProblem) = prob.B + +# --- Cache noise access --- +_cache_noise(cache) = cache.noise + +# --- Model-specific initialization (e.g., quadratic u_f) --- +_init_model_state!!(::LinearStateSpaceProblem, cache) = nothing + +# --- Observation flag --- +_has_observations(sol) = !isnothing(sol.z) + +# ============================================================================= +# Linear state-space callbacks +# ============================================================================= + +""" + _transition!!(x_next, x, w, prob::LinearStateSpaceProblem, cache, t) + +Linear state transition: `x_next = A * x + B * w` +""" +@inline function _transition!!(x_next, x, w, prob::LinearStateSpaceProblem, cache, t) + x_next = mul!!(x_next, prob.A, x) + x_next = muladd!!(x_next, prob.B, w) + return x_next +end + +""" + _observation!!(y, x, prob::LinearStateSpaceProblem, cache, t) + +Linear observation: `y = C * x` +""" +@inline function _observation!!(y, x, prob::LinearStateSpaceProblem, cache, t) + y = mul!!(y, prob.C, x) + return y +end + +# ============================================================================= +# Generic DirectIteration solver — single loop for all problem types +# ============================================================================= + +# --- Default 7-arg fallbacks for save_everystep=false endpoints loops --- +# PrunedQuadratic overrides these; all other problem types fall through. +@inline _transition!!(x_next, x, w, prob, cache, t, ::Val) = + _transition!!(x_next, x, w, prob, cache, t) +@inline _observation!!(y, x, prob, cache, t, ::Val) = + _observation!!(y, x, prob, cache, t) +@inline _init_model_state!!(prob, cache, ::Val) = + _init_model_state!!(prob, cache) + +# Function barrier: _noise_matrix may return a union type for StateSpaceProblem +# (n_shocks is a runtime Int). Splitting here lets Julia specialize the hot loop +# on the concrete B type. +function _solve!( + prob::AbstractStateSpaceProblem, alg::DirectIteration, sol, cache; + save_everystep::Val{SE} = Val(true), kwargs... + ) where {SE} T = convert(Int64, prob.tspan[2] - prob.tspan[1] + 1) - @unpack A, B, C = prob + B = _noise_matrix(prob) + if SE + return _solve_direct_iteration!(prob, alg, sol, cache, B, T; kwargs...) + else + return _solve_direct_iteration_endpoints!(prob, alg, sol, cache, B, T; kwargs...) + end +end + +function _solve_direct_iteration!( + prob, alg, sol, cache, B, T; + perturb_diagonal = 0.0, kwargs... + ) + # Get concrete noise and copy into cache + noise_concrete = get_concrete_noise(prob, prob.noise, B, T - 1) + + # Validate dimensions + if !isnothing(noise_concrete) + length(noise_concrete) == T - 1 || + throw(ArgumentError("noise length $(length(noise_concrete)) must equal T-1 = $(T - 1)")) + length(noise_concrete[1]) == size(B, 2) || + throw(ArgumentError("noise dimension $(length(noise_concrete[1])) must equal number of shocks $(size(B, 2))")) + end + maybe_check_size(prob.observables, 2, T - 1) || + throw(ArgumentError("observables length must equal T-1 = $(T - 1)")) - # checks on bounds - noise = get_concrete_noise(prob, prob.noise, prob.B, T - 1) # concrete noise for simulations as required. - observables_noise = make_observables_noise(prob.observables_noise) + (; u, z) = sol + noise = _cache_noise(cache) - @assert maybe_check_size(noise, 1, prob.B, 2) - @assert maybe_check_size(noise, 2, T - 1) - @assert maybe_check_size(prob.observables, 2, T - 1) + # Copy noise into cache buffers + if !isnothing(noise) && !isnothing(noise_concrete) + copy_noise_to_cache!(noise, noise_concrete) + end - # Initialize - u = [zero(prob.u0) for _ in 1:T] - u[1] .= prob.u0 + # Initialize state + u[1] = assign!!(u[1], prob.u0) + _init_model_state!!(prob, cache) + + # Initial observation + if _has_observations(sol) + z[1] = _observation!!(z[1], u[1], prob, cache, 1) + end - z = allocate_z(prob, C, prob.u0, T) - maybe_mul!(z, 1, C, u, 1) # update the first of z if it isn't nothing + # Pre-compute observation noise Cholesky (used for loglik and/or simulation noise) + has_obs_noise = !isnothing(prob.observables_noise) && !isnothing(cache.R) + has_obs = has_obs_noise && !isnothing(prob.observables) + if has_obs_noise + R_cov = make_observables_covariance_matrix(prob.observables_noise) + R_buf = cache.R + R_chol_buf = cache.R_chol + R_buf = copyto!!(R_buf, R_cov) + R_chol_buf = symmetrize_upper!!(R_chol_buf, R_buf, perturb_diagonal) + F_obs = cholesky!!(R_chol_buf, :U) + end + if has_obs + logdetR = logdet_chol(F_obs) + M_obs = size(R_buf, 1) + log_const = M_obs * log(2π) + logdetR + end + + loglik = zero(eltype(prob.u0)) + is_mutable = ismutable(u[1]) - loglik = 0.0 @inbounds for t in 2:T - mul!(u[t], A, u[t - 1]) - maybe_muladd!(u[t], B, noise, t - 1) # was: mul!(u[t], B, view(noise, :, t - 1), 1, 1) + w_t = isnothing(noise) ? nothing : noise[t - 1] + u[t] = _transition!!(u[t], u[t - 1], w_t, prob, cache, t) + + if _has_observations(sol) + z[t] = _observation!!(z[t], u[t], prob, cache, t) + end + + # Log-likelihood contribution (Cholesky-based, allocation-free) + if has_obs + obs_t = get_observable(prob.observables, t - 1) + ν = cache.innovation[t - 1] + ν = copyto!!(ν, obs_t) + if is_mutable + for i in eachindex(ν) + ν[i] -= z[t][i] + end + else + ν = ν - z[t] + end + cache.innovation[t - 1] = ν + + ν_solved = cache.innovation_solved[t - 1] + ν_solved = ldiv!!(ν_solved, F_obs, ν) + cache.innovation_solved[t - 1] = ν_solved + quad = dot(ν, ν_solved) + loglik -= 0.5 * (log_const + quad) + end + end - maybe_mul!(z, t, C, u, t) # does mul!(z[t], C, u[t]) if C is not nothing - loglik += maybe_logpdf(observables_noise, prob.observables, t - 1, z, t) + # Add observation noise for simulation (when no observables provided) + if has_obs_noise && isnothing(prob.observables) + _add_observation_noise!!(z, F_obs) end - maybe_add_observation_noise!(z, observables_noise, prob.observables) - t_values = prob.tspan[1]:prob.tspan[2] + + t_values = prob.tspan[1]:1:prob.tspan[2] return build_solution( - prob, alg, t_values, u; W = noise, - logpdf = ObsType <: Nothing ? nothing : loglik, z, - retcode = :Success + prob, alg, t_values, u; W = noise_concrete, + logpdf = loglik, z, + retcode = ReturnCode.Success ) end -# Ideally hook into existing sensitivity dispatching -# Trouble with Zygote. The problem isn't the _concrete_solve_adjoint but rather something in the -# adjoint of the basic solve and `solve_up`. Probably promotion on the prob - -# function DiffEqBase._concrete_solve_adjoint(prob::LinearStateSpaceProblem, alg::DirectIteration, -# sensealg, u0, p, args...; kwargs...) -function ChainRulesCore.rrule( - ::typeof(solve), - prob::LinearStateSpaceProblem{ - uType, uPriorMeanType, - uPriorVarType, - tType, - P, NP, F, AType, BType, - CType, RType, ObsType, K, - }, - alg::DirectIteration, args...; - kwargs... - ) where { - uType, uPriorMeanType, uPriorVarType, tType, - P, NP, F, - AType, - BType, CType, RType, - ObsType, K, - } - T = convert(Int64, prob.tspan[2] - prob.tspan[1] + 1) - @unpack A, B, C = prob - # @assert !isnothing(prob.noise) || isnothing(prob.B) # need to have concrete noise or no noise for this simple method +# ============================================================================= +# DirectIteration endpoints solver (save_everystep=false) +# Ping-pong between 2-element u/z, single-slot innovation cache. +# ============================================================================= + +function _solve_direct_iteration_endpoints!( + prob, alg, sol, cache, B, T; + perturb_diagonal = 0.0, kwargs... + ) + noise_concrete = get_concrete_noise(prob, prob.noise, B, T - 1) - # checks on bounds - noise = get_concrete_noise(prob, prob.noise, prob.B, T - 1) # concrete noise for simulations as required. - observables_noise = make_observables_noise(prob.observables_noise) - @assert observables_noise isa Union{ZeroMeanDiagNormal, Nothing} # can extend to more general in rrule later + if !isnothing(noise_concrete) + length(noise_concrete) == T - 1 || + throw(ArgumentError("noise length $(length(noise_concrete)) must equal T-1 = $(T - 1)")) + length(noise_concrete[1]) == size(B, 2) || + throw(ArgumentError("noise dimension $(length(noise_concrete[1])) must equal number of shocks $(size(B, 2))")) + end + maybe_check_size(prob.observables, 2, T - 1) || + throw(ArgumentError("observables length must equal T-1 = $(T - 1)")) - @assert maybe_check_size(noise, 1, prob.B, 2) - @assert maybe_check_size(noise, 2, T - 1) - @assert maybe_check_size(prob.observables, 2, T - 1) + (; u, z) = sol + noise = _cache_noise(cache) + _se = Val(false) - # Initialize - u = [zero(prob.u0) for _ in 1:T] - u[1] .= prob.u0 + if !isnothing(noise) && !isnothing(noise_concrete) + copy_noise_to_cache!(noise, noise_concrete) + end + + # Initialize state at ping-pong slot 1 + u[1] = assign!!(u[1], prob.u0) + _init_model_state!!(prob, cache, _se) + + if _has_observations(sol) + z[1] = _observation!!(z[1], u[1], prob, cache, 1, _se) + end - z = allocate_z(prob, C, prob.u0, T) - maybe_mul!(z, 1, C, u, 1) # update the first of z if it isn't nothing + has_obs_noise = !isnothing(prob.observables_noise) && !isnothing(cache.R) + has_obs = has_obs_noise && !isnothing(prob.observables) + if has_obs_noise + R_cov = make_observables_covariance_matrix(prob.observables_noise) + R_buf = cache.R + R_chol_buf = cache.R_chol + R_buf = copyto!!(R_buf, R_cov) + R_chol_buf = symmetrize_upper!!(R_chol_buf, R_buf, perturb_diagonal) + F_obs = cholesky!!(R_chol_buf, :U) + end + if has_obs + logdetR = logdet_chol(F_obs) + M_obs = size(R_buf, 1) + log_const = M_obs * log(2π) + logdetR + end + + loglik = zero(eltype(prob.u0)) + is_mutable = ismutable(u[1]) - loglik = 0.0 @inbounds for t in 2:T - mul!(u[t], A, u[t - 1]) - maybe_muladd!(u[t], B, noise, t - 1) # was: mul!(u[t], B, view(noise, :, t - 1), 1, 1) + w_t = isnothing(noise) ? nothing : noise[t - 1] + ci = _u_idx_pingpong(t) + pi = _u_idx_pingpong(t - 1) + u[ci] = _transition!!(u[ci], u[pi], w_t, prob, cache, t, _se) + + if _has_observations(sol) + z[ci] = _observation!!(z[ci], u[ci], prob, cache, t, _se) + end + + if has_obs + obs_t = get_observable(prob.observables, t - 1) + ν = cache.innovation[1] + ν = copyto!!(ν, obs_t) + if is_mutable + for i in eachindex(ν) + ν[i] -= z[ci][i] + end + else + ν = ν - z[ci] + end + cache.innovation[1] = ν - maybe_mul!(z, t, C, u, t) # does mul!(z[t], C, u[t]) if C is not nothing - loglik += maybe_logpdf(observables_noise, prob.observables, t - 1, z, t) + ν_solved = cache.innovation_solved[1] + ν_solved = ldiv!!(ν_solved, F_obs, ν) + cache.innovation_solved[1] = ν_solved + quad = dot(ν, ν_solved) + loglik -= 0.5 * (log_const + quad) + end end - maybe_add_observation_noise!(z, observables_noise, prob.observables) - t_values = prob.tspan[1]:prob.tspan[2] - sol = build_solution( - prob, alg, t_values, u; W = noise, - logpdf = ObsType <: Nothing ? nothing : loglik, z, - retcode = :Success - ) - function solve_pb(Δsol) - ΔA = zero(A) - ΔB = maybe_zero(B) - ΔC = maybe_zero(C) - Δnoise = maybe_zero(noise) - Δu = zero(u[1]) - Δu_temp = zero(u[1]) - - # Assert checked above about being diagonal and Normal - observables_noise_cov = prob.observables_noise - - @views @inbounds for t in T:-1:2 - Δz = maybe_zero(z, 1) # zero out in case no logpdf but z observations available - maybe_add_Δ_logpdf!( - Δz, Δsol.logpdf, prob.observables, z, t, - observables_noise_cov - ) - maybe_add_Δ!(Δz, Δsol.z, t) # only accumulte if z provided - - copy!(Δu_temp, Δu) - maybe_muladd_transpose!(Δu_temp, C, Δz) # mul!(Δu_temp, C', Δz, 1, 1) - maybe_add_Δ!(Δu_temp, Δsol.u, t) # only accumulate if not NoTangent and if observables provided - mul!(Δu, A', Δu_temp) - maybe_mul_transpose!(Δnoise, t - 1, B, Δu_temp) - maybe_add_Δ_slice!(Δnoise, Δsol.W, t - 1) - mul!(ΔA, Δu_temp, u[t - 1]', 1, 1) - maybe_muladd_transpose!(ΔB, Δu_temp, noise, t - 1) - maybe_muladd!(ΔC, Δz, u[t]') + # Add observation noise for simulation (when no observables provided) + if has_obs_noise && isnothing(prob.observables) + _add_observation_noise!!(z, F_obs) + end + + # Fixup: ensure u[1]=u0, u[2]=final state + final_idx = _u_idx_pingpong(T) + if final_idx == 1 + u[2] = assign!!(u[2], u[1]) + end + u[1] = assign!!(u[1], prob.u0) + if _has_observations(sol) + if final_idx == 1 + z[2] = assign!!(z[2], z[1]) end - return ( - NoTangent(), - Tangent{typeof(prob)}(; - A = ΔA, B = ΔB, C = ΔC, u0 = Δu, noise = Δnoise, - observables = NoTangent(), # not implemented - observables_noise = NoTangent() - ), NoTangent(), - map(_ -> NoTangent(), args)..., - ) + z[1] = _observation!!(z[1], u[1], prob, cache, 1, _se) end - return sol, solve_pb + + _step = max(1, prob.tspan[2] - prob.tspan[1]) + t_values = prob.tspan[1]:_step:prob.tspan[2] + return build_solution( + prob, alg, t_values, u; W = noise_concrete, + logpdf = loglik, z, + retcode = ReturnCode.Success + ) end +# Single __solve route for all problem types with DirectIteration function DiffEqBase.__solve( - prob::LinearStateSpaceProblem, alg::KalmanFilter, args...; - kwargs... + prob::AbstractStateSpaceProblem, alg::DirectIteration, args...; + save_everystep = true, kwargs... ) + ws = CommonSolve.init(prob, alg; save_everystep, kwargs...) + return CommonSolve.solve!(ws; kwargs...) +end + +# ============================================================================= +# ConditionalLikelihood solver — generic for all problem types +# Prediction error decomposition for fully-observed state-space models. +# ============================================================================= + +# Function barrier: same pattern as DirectIteration +function _solve!( + prob::AbstractStateSpaceProblem, alg::ConditionalLikelihood, sol, cache; + save_everystep::Val{SE} = Val(true), kwargs... + ) where {SE} T = convert(Int64, prob.tspan[2] - prob.tspan[1] + 1) + B = _noise_matrix(prob) + if SE + return _solve_conditional_likelihood!(prob, alg, sol, cache, B, T; kwargs...) + else + return _solve_conditional_likelihood_endpoints!(prob, alg, sol, cache, B, T; kwargs...) + end +end + +function _solve_conditional_likelihood!( + prob, alg, sol, cache, B, T; + perturb_diagonal = 0.0, kwargs... + ) + # Validate requirements + isnothing(prob.observables) && + throw(ArgumentError("ConditionalLikelihood requires observables")) + isnothing(prob.observables_noise) && + throw(ArgumentError("ConditionalLikelihood requires observables_noise")) + maybe_check_size(prob.observables, 2, T - 1) || + throw(ArgumentError("observables length must equal T-1 = $(T - 1)")) + + # Get concrete noise and copy into cache + noise_concrete = get_concrete_noise(prob, prob.noise, B, T - 1) + + (; u, z) = sol + noise = _cache_noise(cache) + has_obs_func = _has_observations(sol) + + if !isnothing(noise) && !isnothing(noise_concrete) + copy_noise_to_cache!(noise, noise_concrete) + end + + # Initialize state + u[1] = assign!!(u[1], prob.u0) + _init_model_state!!(prob, cache) + + # Initial observation (for diagnostics, not used in loglik) + if has_obs_func + z[1] = _observation!!(z[1], u[1], prob, cache, 1) + end + + # Pre-compute observation noise Cholesky + R_cov = make_observables_covariance_matrix(prob.observables_noise) + R_buf = copyto!!(cache.R, R_cov) + R_chol_buf = symmetrize_upper!!(cache.R_chol, R_buf, perturb_diagonal) + F_obs = cholesky!!(R_chol_buf, :U) + logdetR = logdet_chol(F_obs) + M_obs = size(R_buf, 1) + log_const = M_obs * log(2π) + logdetR - # checks on bounds - @assert size(prob.observables, 2) == T - 1 - - @unpack A, B, C, u0_prior_mean, u0_prior_var = prob - N = length(u0_prior_mean) - L = size(C, 1) - - # TODO: move to internal algorithm cache - # This method of preallocation won't work with staticarrays. Note that we can't use eltype(mean(u0)) since it may be special case of FillArrays.zeros - u = [Vector{eltype(u0_prior_var)}(undef, N) for _ in 1:T] # Mean of Kalman filter inferred latent states - P = [Matrix{eltype(u0_prior_var)}(undef, N, N) for _ in 1:T] # Posterior variance of Kalman filter inferred latent states - z = [Vector{eltype(prob.observables)}(undef, size(prob.observables, 1)) for _ in 1:T] # Mean of observables, generated from mean of latent states - - # TODO: these intermediates should be of size T-1 instead as the first was skipped. Left in for checks on timing - # Maintaining allocations for these intermediates is necessary for the rrule, but not for forward only. Code could be refactored along those lines with solid unit tests. - B_prod = Matrix{eltype(u0_prior_var)}(undef, N, N) - u_mid = [Vector{eltype(u0_prior_var)}(undef, N) for _ in 1:T] # intermediate in u calculation - P_mid = [Matrix{eltype(u0_prior_var)}(undef, N, N) for _ in 1:T] # intermediate in P calculation - innovation = [ - Vector{eltype(prob.observables)}(undef, size(prob.observables, 1)) - for _ in 1:T - ] - K = [Matrix{eltype(u0_prior_var)}(undef, N, L) for _ in 1:T] # Gain - CP = [Matrix{eltype(u0_prior_var)}(undef, L, N) for _ in 1:T] # C * P[t] - V = [ - PDMat{eltype(u0_prior_var), Matrix{eltype(u0_prior_var)}}( - Matrix{eltype(u0_prior_var)}(undef, L, L), - Cholesky{eltype(u0_prior_var), Matrix{eltype(u0_prior_var)}}( - Matrix{eltype(u0_prior_var)}(undef, L, L), - 'U', - 0 - ) - ) - for _ in 1:T - ] # preallocated buffers for cholesky and matrix itself - - R = make_observables_covariance_matrix(prob.observables_noise) # Support diagonal or matrix covariance matrices. - mul!(B_prod, B, B') - - # Gaussian Prior - u[1] .= u0_prior_mean - P[1] .= u0_prior_var - z[1] .= C * u[1] - - loglik = 0.0 - - # temp buffers. Could be moved into algorithm settings - temp_N_N = Matrix{eltype(u0_prior_var)}(undef, N, N) - temp_L_L = Matrix{eltype(u0_prior_var)}(undef, L, L) - temp_L_N = Matrix{eltype(u0_prior_var)}(undef, L, N) - - retcode = :Failure - try - @inbounds for t in 2:T - # Kalman iteration - mul!(u_mid[t], A, u[t - 1]) # u[t] = A u[t-1] - mul!(z[t], C, u_mid[t]) # z[t] = C u[t] - - # P[t] = A * P[t - 1] * A' + B * B' - mul!(temp_N_N, P[t - 1], A') - mul!(P_mid[t], A, temp_N_N) - P_mid[t] .+= B_prod - - mul!(CP[t], C, P_mid[t]) # CP[t] = C * P[t] - - # V[t] = CP[t] * C' + R - mul!(V[t].mat, CP[t], C') - V[t].mat .+= R - - # V_t .= (V_t + V_t') / 2 # classic hack to deal with stability of not being quite symmetric - transpose!(temp_L_L, V[t].mat) - V[t].mat .+= temp_L_L - lmul!(0.5, V[t].mat) - - copy!(V[t].chol.factors, V[t].mat) # copy over to the factors for the cholesky and do in place - cholesky!(V[t].chol.factors, NoPivot(); check = false) # inplace uses V_t with cholesky. Now V[t]'s chol is upper-triangular - innovation[t] .= prob.observables[:, t - 1] - z[t] - loglik += logpdf(MvNormal(V[t]), innovation[t]) # no allocations since V[t] is a PDMat - - # K[t] .= CP[t]' / V[t] # Kalman gain - # Can rewrite as K[t]' = V[t] \ CP[t] since V[t] is symmetric - ldiv!(temp_L_N, V[t].chol, CP[t]) - transpose!(K[t], temp_L_N) - - #u[t] += K[t] * innovation[t] - copy!(u[t], u_mid[t]) - mul!(u[t], K[t], innovation[t], 1, 1) - - #P[t] -= K[t] * CP[t] - copy!(P[t], P_mid[t]) - mul!(P[t], K[t], CP[t], -1, 1) + loglik = zero(eltype(prob.u0)) + is_mutable = ismutable(u[1]) + + @inbounds for t in 2:T + w_t = isnothing(noise) ? nothing : noise[t - 1] + + # Predict into u[t] (temporary) + u[t] = _transition!!(u[t], u[t - 1], w_t, prob, cache, t) + + # Predicted observation + if has_obs_func + z[t] = _observation!!(z[t], u[t], prob, cache, t) + z_pred = z[t] + else + z_pred = u[t] + end + + # Innovation: ν = obs_t - z_pred + obs_t = get_observable(prob.observables, t - 1) + ν = cache.innovation[t - 1] + ν = copyto!!(ν, obs_t) + if is_mutable + for i in eachindex(ν) + ν[i] -= z_pred[i] + end + else + ν = ν - z_pred end - retcode = :Success - catch e - loglik = -Inf + cache.innovation[t - 1] = ν + + # Log-likelihood contribution + ν_solved = cache.innovation_solved[t - 1] + ν_solved = ldiv!!(ν_solved, F_obs, ν) + cache.innovation_solved[t - 1] = ν_solved + quad = dot(ν, ν_solved) + loglik -= 0.5 * (log_const + quad) + + # CLAMP: set state to observation for next step + u[t] = assign!!(u[t], obs_t) end - t_values = prob.tspan[1]:prob.tspan[2] + t_values = prob.tspan[1]:1:prob.tspan[2] return build_solution( - prob, alg, t_values, u; P, W = nothing, logpdf = loglik, z, - retcode + prob, alg, t_values, u; W = noise_concrete, logpdf = loglik, z, + retcode = ReturnCode.Success ) end -# NOTE: when moving to ._concrete_solve_adjoint will need to be careful to ensure the u0 sensitivity -# takes into account any promotion in the `remake_model` side. We want u0 to be the prior and have the -# sensitivity of it as a distribution, not a draw from it which might happen in the remake(...) +# ============================================================================= +# ConditionalLikelihood endpoints solver (save_everystep=false) +# ============================================================================= + +function _solve_conditional_likelihood_endpoints!( + prob, alg, sol, cache, B, T; + perturb_diagonal = 0.0, kwargs... + ) + isnothing(prob.observables) && + throw(ArgumentError("ConditionalLikelihood requires observables")) + isnothing(prob.observables_noise) && + throw(ArgumentError("ConditionalLikelihood requires observables_noise")) + maybe_check_size(prob.observables, 2, T - 1) || + throw(ArgumentError("observables length must equal T-1 = $(T - 1)")) + + noise_concrete = get_concrete_noise(prob, prob.noise, B, T - 1) + + (; u, z) = sol + noise = _cache_noise(cache) + has_obs_func = _has_observations(sol) + _se = Val(false) + + if !isnothing(noise) && !isnothing(noise_concrete) + copy_noise_to_cache!(noise, noise_concrete) + end + + u[1] = assign!!(u[1], prob.u0) + _init_model_state!!(prob, cache, _se) + + if has_obs_func + z[1] = _observation!!(z[1], u[1], prob, cache, 1, _se) + end + + R_cov = make_observables_covariance_matrix(prob.observables_noise) + R_buf = copyto!!(cache.R, R_cov) + R_chol_buf = symmetrize_upper!!(cache.R_chol, R_buf, perturb_diagonal) + F_obs = cholesky!!(R_chol_buf, :U) + logdetR = logdet_chol(F_obs) + M_obs = size(R_buf, 1) + log_const = M_obs * log(2π) + logdetR + + loglik = zero(eltype(prob.u0)) + is_mutable = ismutable(u[1]) + + @inbounds for t in 2:T + w_t = isnothing(noise) ? nothing : noise[t - 1] + ci = _u_idx_pingpong(t) + pi = _u_idx_pingpong(t - 1) + + # Predict into u[ci] + u[ci] = _transition!!(u[ci], u[pi], w_t, prob, cache, t, _se) + + # Predicted observation + if has_obs_func + z[ci] = _observation!!(z[ci], u[ci], prob, cache, t, _se) + z_pred = z[ci] + else + z_pred = u[ci] + end + + # Innovation + obs_t = get_observable(prob.observables, t - 1) + ν = cache.innovation[1] + ν = copyto!!(ν, obs_t) + if is_mutable + for i in eachindex(ν) + ν[i] -= z_pred[i] + end + else + ν = ν - z_pred + end + cache.innovation[1] = ν + + ν_solved = cache.innovation_solved[1] + ν_solved = ldiv!!(ν_solved, F_obs, ν) + cache.innovation_solved[1] = ν_solved + quad = dot(ν, ν_solved) + loglik -= 0.5 * (log_const + quad) + + # CLAMP + u[ci] = assign!!(u[ci], obs_t) + end + + # Fixup: u[1]=u0, u[2]=final clamped state + final_idx = _u_idx_pingpong(T) + if final_idx == 1 + u[2] = assign!!(u[2], u[1]) + end + u[1] = assign!!(u[1], prob.u0) + if has_obs_func + if final_idx == 1 + z[2] = assign!!(z[2], z[1]) + end + z[1] = _observation!!(z[1], u[1], prob, cache, 1, _se) + end -# function DiffEqBase._concrete_solve_adjoint(prob::LinearStateSpaceProblem, alg::KalmanFilter, -# sensealg, u0, p, args...; kwargs...) -function ChainRulesCore.rrule( - ::typeof(solve), prob::LinearStateSpaceProblem, - alg::KalmanFilter, args...; kwargs... + _step = max(1, prob.tspan[2] - prob.tspan[1]) + t_values = prob.tspan[1]:_step:prob.tspan[2] + return build_solution( + prob, alg, t_values, u; W = noise_concrete, logpdf = loglik, z, + retcode = ReturnCode.Success + ) +end + +function DiffEqBase.__solve( + prob::AbstractStateSpaceProblem, alg::ConditionalLikelihood, args...; + save_everystep = true, kwargs... ) - # Preallocate values + ws = CommonSolve.init(prob, alg; save_everystep, kwargs...) + return CommonSolve.solve!(ws; kwargs...) +end + +# ============================================================================= +# KalmanFilter solver — specific to LinearStateSpaceProblem +# ============================================================================= + +function _solve!( + prob::LinearStateSpaceProblem, alg::KalmanFilter, sol, cache; + save_everystep::Val{SE} = Val(true), perturb_diagonal = 0.0, kwargs... + ) where {SE} T = convert(Int64, prob.tspan[2] - prob.tspan[1] + 1) - # checks on bounds - @assert size(prob.observables, 2) == T - 1 - - @unpack A, B, C, u0_prior_mean, u0_prior_var = prob - N = length(u0_prior_mean) - L = size(C, 1) - - # TODO: move to internal algorithm cache - # This method of preallocation won't work with staticarrays. Note that we can't use eltype(mean(u0)) since it may be special case of FillArrays.zeros - B_prod = Matrix{eltype(u0_prior_var)}(undef, N, N) - u = [Vector{eltype(u0_prior_var)}(undef, N) for _ in 1:T] # Mean of Kalman filter inferred latent states - P = [Matrix{eltype(u0_prior_var)}(undef, N, N) for _ in 1:T] # Posterior variance of Kalman filter inferred latent states - z = [Vector{eltype(prob.observables)}(undef, size(prob.observables, 1)) for _ in 1:T] # Mean of observables, generated from mean of latent states - - # TODO: these intermediates should be of size T-1 instead as the first was skipped. Left in for checks on timing - # Maintaining allocations for these intermediates is necessary for the rrule, but not for forward only. Code could be refactored along those lines with solid unit tests. - u_mid = [Vector{eltype(u0_prior_var)}(undef, N) for _ in 1:T] # intermediate in u calculation - P_mid = [Matrix{eltype(u0_prior_var)}(undef, N, N) for _ in 1:T] # intermediate in P calculation - innovation = [ - Vector{eltype(prob.observables)}(undef, size(prob.observables, 1)) - for _ in 1:T - ] - K = [Matrix{eltype(u0_prior_var)}(undef, N, L) for _ in 1:T] # Gain - CP = [Matrix{eltype(u0_prior_var)}(undef, L, N) for _ in 1:T] # C * P[t] - V = [ - PDMat{eltype(u0_prior_var), Matrix{eltype(u0_prior_var)}}( - Matrix{eltype(u0_prior_var)}(undef, L, L), - Cholesky{eltype(u0_prior_var), Matrix{eltype(u0_prior_var)}}( - Matrix{eltype(u0_prior_var)}(undef, L, L), - 'U', - 0 - ) - ) - for _ in 1:T - ] # preallocated buffers for cholesky and matrix itself - - R = make_observables_covariance_matrix(prob.observables_noise) # Support diagonal or matrix covariance matrices. - mul!(B_prod, B, B') - - u[1] .= u0_prior_mean - P[1] .= u0_prior_var - z[1] .= C * u[1] - - loglik = 0.0 - - # temp buffers. Could be moved into algorithm settings - temp_N_N = Matrix{eltype(u0_prior_var)}(undef, N, N) - temp_L_L = Matrix{eltype(u0_prior_var)}(undef, L, L) - temp_L_N = Matrix{eltype(u0_prior_var)}(undef, L, N) - temp_N_L = Matrix{eltype(u0_prior_var)}(undef, N, L) - temp_M = Vector{eltype(u0_prior_var)}(undef, L) - temp_N = Vector{eltype(u0_prior_var)}(undef, N) - retcode = :Failure - try - @inbounds for t in 2:T - # Kalman iteration - mul!(u_mid[t], A, u[t - 1]) # u[t] = A u[t-1] - mul!(z[t], C, u_mid[t]) # z[t] = C u[t] - - # P[t] = A * P[t - 1] * A' + B * B' - mul!(temp_N_N, P[t - 1], A') - mul!(P_mid[t], A, temp_N_N) - P_mid[t] .+= B_prod - - mul!(CP[t], C, P_mid[t]) # CP[t] = C * P[t] - - # V[t] = CP[t] * C' + R - mul!(V[t].mat, CP[t], C') - V[t].mat .+= R - - # V_t .= (V_t + V_t') / 2 # classic hack to deal with stability of not being quite symmetric - transpose!(temp_L_L, V[t].mat) - V[t].mat .+= temp_L_L - lmul!(0.5, V[t].mat) - - copy!(V[t].chol.factors, V[t].mat) # copy over to the factors for the cholesky and do in place - cholesky!(V[t].chol.factors, NoPivot(); check = false) # inplace uses V_t with cholesky. Now V[t]'s chol is upper-triangular - innovation[t] .= prob.observables[:, t - 1] - z[t] - loglik += logpdf(MvNormal(V[t]), innovation[t]) # no allocations since V[t] is a PDMat - - # K[t] .= CP[t]' / V[t] # Kalman gain - # Can rewrite as K[t]' = V[t] \ CP[t] since V[t] is symmetric - ldiv!(temp_L_N, V[t].chol, CP[t]) - transpose!(K[t], temp_L_N) - - #u[t] += K[t] * innovation[t] - copy!(u[t], u_mid[t]) - mul!(u[t], K[t], innovation[t], 1, 1) - - #P[t] -= K[t] * CP[t] - copy!(P[t], P_mid[t]) - mul!(P[t], K[t], CP[t], -1, 1) + if !SE + return _solve_kalman_endpoints!( + prob, alg, sol, cache, T; + perturb_diagonal, kwargs... + ) + end + length(prob.observables) == T - 1 || + throw(ArgumentError("observables length $(length(prob.observables)) must equal T-1 = $(T - 1)")) + + (; A, B, C, u0_prior_mean, u0_prior_var) = prob + R = make_observables_covariance_matrix(prob.observables_noise) + + (; u, P, z) = sol + (; B_prod, B_t) = cache + + # Compute B*B' once (mul_aat!! avoids BLAS syrk path for Enzyme AD correctness) + B_prod = mul_aat!!(B_prod, B, B_t) + + # Initialize + u[1] = copyto!!(u[1], u0_prior_mean) + P[1] = copyto!!(P[1], u0_prior_var) + z[1] = mul!!(z[1], C, u[1]) + + loglik = zero(eltype(u0_prior_var)) + is_mutable = ismutable(u[1]) + T_obs = length(cache.mu_pred) + M_obs = size(C, 1) + log_const_kf = M_obs * log(2π) + + @inbounds for t in 1:T_obs + # Get scratch buffers for this timestep + μp = cache.mu_pred[t] + Σp = cache.sigma_pred[t] + AΣ = cache.A_sigma[t] + ΣGt = cache.sigma_Gt[t] + ν = cache.innovation[t] + S = cache.innovation_cov[t] + S_chol_buf = cache.S_chol[t] + ν_solved = cache.innovation_solved[t] + rhs = cache.gain_rhs[t] + K_t = cache.gain[t] + KG = cache.gainG[t] + KGS = cache.KgSigma[t] + μu = cache.mu_update[t] + + # Current state (from solution output) + μt = u[t] + Σt = P[t] + + # Predict mean: μp = A * μt + μp = mul!!(μp, A, μt) + + # Predict covariance: Σp = A * Σt * A' + B * B' + AΣ = mul!!(AΣ, A, Σt) + Σp = mul!!(Σp, AΣ, transpose(A)) + if is_mutable + @inbounds for i in eachindex(Σp) + Σp[i] += B_prod[i] + end + else + Σp = Σp + B_prod + end + + # Predicted observation: z[t+1] = C * μp + z[t + 1] = mul!!(z[t + 1], C, μp) + + # Innovation: ν = observables[t] - z[t+1] + obs_t = get_observable(prob.observables, t) + ν = copyto!!(ν, obs_t) + ν = mul!!(ν, C, μp, -1.0, 1.0) + + # Innovation covariance: S = C * Σp * C' + R + ΣGt = mul!!(ΣGt, Σp, transpose(C)) + S = mul!!(S, C, ΣGt) + if is_mutable + @inbounds for i in eachindex(S) + S[i] += R[i] + end + else + S = S + R + end + + # Symmetrize and Cholesky + S_chol_buf = symmetrize_upper!!(S_chol_buf, S, perturb_diagonal) + F = cholesky!!(S_chol_buf, :U) + + # Kalman gain: K = Σp * C' * S^{-1} + rhs = transpose!!(rhs, ΣGt) + rhs = ldiv!!(F, rhs) + K_t = transpose!!(K_t, rhs) + + # Update mean: u[t+1] = μp + K * ν + μu = mul!!(μu, K_t, ν) + if is_mutable + @inbounds for i in eachindex(μp) + u[t + 1][i] = μp[i] + μu[i] + end + else + cache.mu_pred[t] = μp + cache.mu_update[t] = μu + u[t + 1] = μp + μu end - retcode = :Success - catch e - loglik = -Inf + + # Update covariance: P[t+1] = Σp - K * C * Σp + KG = mul!!(KG, K_t, C) + KGS = mul!!(KGS, KG, Σp) + if is_mutable + @inbounds for i in eachindex(Σp) + P[t + 1][i] = Σp[i] - KGS[i] + end + else + cache.sigma_pred[t] = Σp + cache.KgSigma[t] = KGS + P[t + 1] = Σp - KGS + end + + # Log-likelihood contribution (allocation-free) + ν_solved = ldiv!!(ν_solved, F, ν) + cache.innovation[t] = ν + cache.innovation_solved[t] = ν_solved + logdetS = logdet_chol(F) + quad = dot(ν_solved, ν) + loglik -= 0.5 * (log_const_kf + logdetS + quad) end - t_values = prob.tspan[1]:prob.tspan[2] - sol = build_solution( - prob, alg, t_values, u; P, W = nothing, logpdf = loglik, z, - retcode + + t_values = prob.tspan[1]:1:prob.tspan[2] + return build_solution( + prob, alg, t_values, sol.u; P = sol.P, W = nothing, logpdf = loglik, + z = sol.z, retcode = ReturnCode.Success + ) +end + +# ============================================================================= +# KalmanFilter endpoints solver (save_everystep=false) +# Pure recursive filter: 2-element sol, 1-slot cache arrays. +# ============================================================================= + +function _solve_kalman_endpoints!( + prob, alg, sol, cache, T; + perturb_diagonal = 0.0, kwargs... ) - function solve_pb(Δsol) - # Currently only changes in the logpdf are supported in the rrule - @assert Δsol.u == ZeroTangent() - @assert Δsol.W == ZeroTangent() - @assert Δsol.P == ZeroTangent() - @assert Δsol.z == ZeroTangent() - - Δlogpdf = Δsol.logpdf - - if iszero(Δlogpdf) - return ( - NoTangent(), Tangent{typeof(prob)}(), NoTangent(), - map(_ -> NoTangent(), args)..., - ) + T_obs = T - 1 + length(prob.observables) == T_obs || + throw(ArgumentError("observables length $(length(prob.observables)) must equal T-1 = $(T_obs)")) + + (; A, B, C, u0_prior_mean, u0_prior_var) = prob + R = make_observables_covariance_matrix(prob.observables_noise) + + (; u, P, z) = sol + (; B_prod, B_t) = cache + + B_prod = mul_aat!!(B_prod, B, B_t) + + # Initialize at ping-pong slot 1 + u[1] = copyto!!(u[1], u0_prior_mean) + P[1] = copyto!!(P[1], u0_prior_var) + z[1] = mul!!(z[1], C, u[1]) + + loglik = zero(eltype(u0_prior_var)) + is_mutable = ismutable(u[1]) + M_obs = size(C, 1) + log_const_kf = M_obs * log(2π) + + @inbounds for t in 1:T_obs + ci = _u_idx_pingpong(t) # current filtered state + ni = _u_idx_pingpong(t + 1) # next filtered state + + # Single-slot cache buffers + μp = cache.mu_pred[1] + Σp = cache.sigma_pred[1] + AΣ = cache.A_sigma[1] + ΣGt = cache.sigma_Gt[1] + ν = cache.innovation[1] + S = cache.innovation_cov[1] + S_chol_buf = cache.S_chol[1] + ν_solved = cache.innovation_solved[1] + rhs = cache.gain_rhs[1] + K_t = cache.gain[1] + KG = cache.gainG[1] + KGS = cache.KgSigma[1] + μu = cache.mu_update[1] + + μt = u[ci] + Σt = P[ci] + + # Predict mean + μp = mul!!(μp, A, μt) + + # Predict covariance + AΣ = mul!!(AΣ, A, Σt) + Σp = mul!!(Σp, AΣ, transpose(A)) + if is_mutable + @inbounds for i in eachindex(Σp) + Σp[i] += B_prod[i] + end + else + Σp = Σp + B_prod + end + + # Predicted observation + z[ni] = mul!!(z[ni], C, μp) + + # Innovation + obs_t = get_observable(prob.observables, t) + ν = copyto!!(ν, obs_t) + ν = mul!!(ν, C, μp, -1.0, 1.0) + + # Innovation covariance + ΣGt = mul!!(ΣGt, Σp, transpose(C)) + S = mul!!(S, C, ΣGt) + if is_mutable + @inbounds for i in eachindex(S) + S[i] += R[i] + end + else + S = S + R + end + + S_chol_buf = symmetrize_upper!!(S_chol_buf, S, perturb_diagonal) + F = cholesky!!(S_chol_buf, :U) + + # Kalman gain + rhs = transpose!!(rhs, ΣGt) + rhs = ldiv!!(F, rhs) + K_t = transpose!!(K_t, rhs) + + # Update mean + μu = mul!!(μu, K_t, ν) + if is_mutable + @inbounds for i in eachindex(μp) + u[ni][i] = μp[i] + μu[i] + end + else + cache.mu_pred[1] = μp + cache.mu_update[1] = μu + u[ni] = μp + μu end - # Buffers - ΔP = zero(P[1]) - Δu = zero(u[1]) - ΔA = zero(A) - ΔB = zero(B) - ΔC = zero(C) - ΔK = zero(K[1]) - ΔP_mid = zero(ΔP) - ΔP_mid_sum = zero(ΔP) - ΔCP = zero(CP[1]) - Δu_mid = zero(u_mid[1]) - Δz = zero(z[1]) - ΔV = zero(V[1].mat) - - # If it was a failure, just return and hope the gradients are ignored! - if retcode == :Success - for t in T:-1:2 - # The inverse is used throughout, including in quadratic forms. For large systems this might not be stable - inv_V = Symmetric(inv(V[t].chol)) # use cholesky factorization to invert. Symmetric - - # Sensitivity accumulation - copy!(ΔP_mid, ΔP) - mul!(ΔK, ΔP, CP[t]', -1, 0) # i.e. ΔK = -ΔP * CP[t]' - mul!(ΔCP, K[t]', ΔP, -1, 0) # i.e. ΔCP = - K[t]' * ΔP - copy!(Δu_mid, Δu) - mul!(ΔK, Δu, innovation[t]', 1, 1) # ΔK += Δu * innovation[t]' - mul!(Δz, K[t]', Δu, -1, 0) # i.e, Δz = -K[t]'* Δu - mul!(ΔCP, inv_V, ΔK', 1, 1) # ΔCP += inv_V * ΔK' - - # ΔV .= -inv_V * CP[t] * ΔK * inv_V - mul!(temp_L_N, inv_V, CP[t]) - mul!(temp_N_L, ΔK, inv_V) - mul!(ΔV, temp_L_N, temp_N_L, -1, 0) - - mul!(ΔC, ΔCP, P_mid[t]', 1, 1) # ΔC += ΔCP * P_mid[t]' - mul!(ΔP_mid, C', ΔCP, 1, 1) # ΔP_mid += C' * ΔCP - mul!(Δz, inv_V, innovation[t], Δlogpdf, 1) # Δz += Δlogpdf * inv_V * innovation[t] # Σ^-1 * (z_obs - z) - - #ΔV -= Δlogpdf * 0.5 * (inv_V - inv_V * innovation[t] * innovation[t]' * inv_V) # -0.5 * (Σ^-1 - Σ^-1(z_obs - z)(z_obx - z)'Σ^-1) - mul!(temp_M, inv_V, innovation[t]) - mul!(temp_L_L, temp_M, temp_M') - temp_L_L .-= inv_V - rmul!(temp_L_L, Δlogpdf * 0.5) - ΔV += temp_L_L - - #ΔC += ΔV * C * P_mid[t]' + ΔV' * C * P_mid[t] - mul!(temp_L_N, C, P_mid[t]) - transpose!(temp_L_L, ΔV) - temp_L_L .+= ΔV - mul!(ΔC, temp_L_L, temp_L_N, 1, 1) - - # ΔP_mid += C' * ΔV * C - mul!(temp_L_N, ΔV, C) - mul!(ΔP_mid, C', temp_L_N, 1, 1) - - mul!(ΔC, Δz, u_mid[t]', 1, 1) # ΔC += Δz * u_mid[t]' - mul!(Δu_mid, C', Δz, 1, 1) # Δu_mid += C' * Δz - - # Calculates (ΔP_mid + ΔP_mid') - transpose!(ΔP_mid_sum, ΔP_mid) - ΔP_mid_sum .+= ΔP_mid - - # ΔA += (ΔP_mid + ΔP_mid') * A * P[t - 1] - mul!(temp_N_N, A, P[t - 1]) - mul!(ΔA, ΔP_mid_sum, temp_N_N, 1, 1) - - # ΔP .= A' * ΔP_mid * A # pass into next period - mul!(temp_N_N, ΔP_mid, A) - mul!(ΔP, A', temp_N_N) - - mul!(ΔB, ΔP_mid_sum, B, 1, 1) # ΔB += ΔP_mid_sum * B - mul!(ΔA, Δu_mid, u[t - 1]', 1, 1) # ΔA += Δu_mid * u[t - 1]' - mul!(Δu, A', Δu_mid) + + # Update covariance + KG = mul!!(KG, K_t, C) + KGS = mul!!(KGS, KG, Σp) + if is_mutable + @inbounds for i in eachindex(Σp) + P[ni][i] = Σp[i] - KGS[i] end + else + cache.sigma_pred[1] = Σp + cache.KgSigma[1] = KGS + P[ni] = Σp - KGS end - return ( - NoTangent(), - Tangent{typeof(prob)}(; - A = ΔA, B = ΔB, C = ΔC, u0 = ZeroTangent(), # u0 not used in kalman filter - u0_prior_mean = Δu, u0_prior_var = ΔP - ), - NoTangent(), map(_ -> NoTangent(), args)..., - ) + + # Log-likelihood + ν_solved = ldiv!!(ν_solved, F, ν) + cache.innovation[1] = ν + cache.innovation_solved[1] = ν_solved + logdetS = logdet_chol(F) + quad = dot(ν_solved, ν) + loglik -= 0.5 * (log_const_kf + logdetS + quad) end - return sol, solve_pb + + # Fixup: u[1]=initial, u[2]=final, P[1]=initial, P[2]=final, z[1]=initial, z[2]=final + final_idx = _u_idx_pingpong(T) # where final state ended up + if final_idx == 1 + u[2] = assign!!(u[2], u[1]) + P[2] = copyto!!(P[2], P[1]) + z[2] = assign!!(z[2], z[1]) + end + u[1] = copyto!!(u[1], u0_prior_mean) + P[1] = copyto!!(P[1], u0_prior_var) + z[1] = mul!!(z[1], C, u[1]) + + _step = max(1, prob.tspan[2] - prob.tspan[1]) + t_values = prob.tspan[1]:_step:prob.tspan[2] + return build_solution( + prob, alg, t_values, sol.u; P = sol.P, W = nothing, logpdf = loglik, + z = sol.z, retcode = ReturnCode.Success + ) +end + +function DiffEqBase.__solve( + prob::LinearStateSpaceProblem, alg::KalmanFilter, args...; + save_everystep = true, kwargs... + ) + ws = CommonSolve.init(prob, alg; save_everystep, kwargs...) + return CommonSolve.solve!(ws; kwargs...) end diff --git a/src/algorithms/quadratic.jl b/src/algorithms/quadratic.jl index 82a0729..e5de916 100644 --- a/src/algorithms/quadratic.jl +++ b/src/algorithms/quadratic.jl @@ -1,206 +1,115 @@ -# This should be ported over to use the "maybe" utilities of the linear model, which will expand the number of model variations that are available. -function DiffEqBase.__solve( - prob::QuadraticStateSpaceProblem{ - uType, uPriorMeanType, - uPriorVarType, - tType, P, NP, F, A0Type, - A1Type, A2Type, BType, C0Type, - C1Type, - C2Type, RType, ObsType, K, - }, - alg::DirectIteration, args...; - kwargs... - ) where { - uType, uPriorMeanType, uPriorVarType, tType, - P, NP, F, - A0Type, A1Type, A2Type, - BType, C0Type, C1Type, C2Type, RType, ObsType, - K, - } - T = convert(Int64, prob.tspan[2] - prob.tspan[1] + 1) - noise = get_concrete_noise(prob, prob.noise, prob.B, T - 1) # concrete noise for simulations as required. - observables_noise = make_observables_noise(prob.observables_noise) - # checks on bounds - @assert size(noise, 1) == size(prob.B, 2) - @assert size(noise, 2) == T - 1 - @assert maybe_check_size(prob.observables, 2, T - 1) +# Quadratic state-space model dispatches for DirectIteration solver +# Two variants: unpruned (quad on x) and pruned (quad on linear-part u_f) +# Both plug into the generic _solve_direct_iteration! loop via these methods. - @unpack A_0, A_1, A_2, B, C_0, C_1, C_2 = prob +# --- Noise matrix extraction --- +_noise_matrix(prob::AnyQuadraticProblem) = prob.B - # These should be be the native datastructure and replace A_2 and C_2 - # See https://github.com/SciML/DifferenceEquations.jl/issues/54 - C_2_vec = [C_2[i, :, :] for i in 1:size(C_2, 1)] - A_2_vec = [A_2[i, :, :] for i in 1:size(A_2, 1)] +# --- Model initialization --- +_init_model_state!!(::QuadraticStateSpaceProblem, cache) = nothing - u_f = [zero(prob.u0) for _ in 1:T] - u = [zero(prob.u0) for _ in 1:T] - z = [zero(C_0) for _ in 1:T] - - u[1] .= prob.u0 - u_f[1] .= prob.u0 - z[1] .= C_0 - mul!(z[1], C_1, prob.u0, 1, 1) - quad_muladd!(z[1], C_2_vec, prob.u0) #z0 .+= quad(C_2, prob.u0) - - loglik = 0.0 - @inbounds @views for t in 2:T - mul!(u_f[t], A_1, u_f[t - 1]) - mul!(u_f[t], B, view(noise, :, t - 1), 1, 1) +function _init_model_state!!(prob::PrunedQuadraticStateSpaceProblem, cache) + cache.u_f[1] = assign!!(cache.u_f[1], prob.u0) + return nothing +end - u[t] .= A_0 - mul!(u[t], A_1, u[t - 1], 1, 1) - quad_muladd!(u[t], A_2_vec, u_f[t - 1]) # u[t] .+= quad(A_2, u_f[t - 1]) - mul!(u[t], B, view(noise, :, t - 1), 1, 1) +# --- Observation flag (shared with linear, already defined) --- +# _has_observations(sol) = !isnothing(sol.z) # defined in linear.jl - z[t] .= C_0 - mul!(z[t], C_1, u[t], 1, 1) - quad_muladd!(z[t], C_2_vec, u_f[t]) # z[t] .+= quad(C_2, u_f[t]) - loglik += maybe_logpdf(observables_noise, prob.observables, t - 1, z, t) +# --- Quadratic form helper --- +# Computes q[i] = v' * A_2[i, :, :] * v for each output dimension +@inline function _add_quadratic!!(y, A_2, v) + if ismutable(y) + @inbounds for i in 1:length(y) + y[i] += dot(v, view(A_2, i, :, :), v) + end + return y + else + n = length(y) + return y + typeof(y)(ntuple(i -> dot(v, view(A_2, i, :, :), v), n)) end - - maybe_add_observation_noise!(z, observables_noise, prob.observables) - t_values = prob.tspan[1]:prob.tspan[2] - return build_solution( - prob, alg, t_values, u; W = noise, - logpdf = ObsType <: Nothing ? nothing : loglik, z, - retcode = :Success - ) end -# Note: this repeats the primal calculation because so many of the internal buffers are useful for the rrule. Refactoring could enable directly shared buffers. -function ChainRulesCore.rrule( - ::typeof(solve), prob::QuadraticStateSpaceProblem, - alg::DirectIteration, args...; kwargs... - ) - T = convert(Int64, prob.tspan[2] - prob.tspan[1] + 1) - noise = get_concrete_noise(prob, prob.noise, prob.B, T - 1) # concrete noise for simulations as required. - @assert !isnothing(prob.noise) # need to have concrete noise for this simple method - # checks on bounds - observables_noise = make_observables_noise(prob.observables_noise) - @assert observables_noise isa ZeroMeanDiagNormal # can extend to more general in rrule - - @assert size(noise, 1) == size(prob.B, 2) - @assert maybe_check_size(prob.observables, 2, T - 1) - @assert size(noise, 2) == T - 1 - - @unpack A_0, A_1, A_2, B, C_0, C_1, C_2 = prob +# ============================================================================= +# Unpruned quadratic: quad(A_2, x) +# ============================================================================= + +@inline function _transition!!(x_next, x, w, prob::QuadraticStateSpaceProblem, cache, t) + (; A_0, A_1, A_2, B) = prob + x_next = copyto!!(x_next, A_0) + x_next = mul!!(x_next, A_1, x, 1.0, 1.0) + x_next = _add_quadratic!!(x_next, A_2, x) + x_next = muladd!!(x_next, B, w) + return x_next +end - # These should be be the native datastructure and replace A_2 and C_2 - # See https://github.com/SciML/DifferenceEquations.jl/issues/54 - C_2_vec = [C_2[i, :, :] for i in 1:size(C_2, 1)] - A_2_vec = [A_2[i, :, :] for i in 1:size(A_2, 1)] +@inline function _observation!!(y, x, prob::QuadraticStateSpaceProblem, cache, t) + (; C_0, C_1, C_2) = prob + y = copyto!!(y, C_0) + y = mul!!(y, C_1, x, 1.0, 1.0) + y = _add_quadratic!!(y, C_2, x) + return y +end - u_f = [zero(prob.u0) for _ in 1:T] - u = [zero(prob.u0) for _ in 1:T] - z = [zero(C_0) for _ in 1:T] +# ============================================================================= +# Pruned quadratic: quad(A_2, u_f) where u_f tracks the linear-part state +# ============================================================================= + +@inline function _transition!!(x_next, x, w, prob::PrunedQuadraticStateSpaceProblem, cache, t) + (; A_0, A_1, A_2, B) = prob + u_f_prev = cache.u_f[t - 1] + # Advance u_f: u_f[t] = A_1 * u_f[t-1] + B * w + u_f_new = mul!!(cache.u_f[t], A_1, u_f_prev) + u_f_new = muladd!!(u_f_new, B, w) + cache.u_f[t] = u_f_new + # Full transition: x_next = A_0 + A_1*x + quad(A_2, u_f_prev) + B*w + x_next = copyto!!(x_next, A_0) + x_next = mul!!(x_next, A_1, x, 1.0, 1.0) + x_next = _add_quadratic!!(x_next, A_2, u_f_prev) + x_next = muladd!!(x_next, B, w) + return x_next +end - u[1] .= prob.u0 - u_f[1] .= prob.u0 - z[1] .= C_0 - mul!(z[1], C_1, prob.u0, 1, 1) - quad_muladd!(z[1], C_2_vec, prob.u0) #z0 .+= quad(C_2, prob.u0) +@inline function _observation!!(y, x, prob::PrunedQuadraticStateSpaceProblem, cache, t) + (; C_0, C_1, C_2) = prob + u_f = cache.u_f[t] + y = copyto!!(y, C_0) + y = mul!!(y, C_1, x, 1.0, 1.0) + y = _add_quadratic!!(y, C_2, u_f) + return y +end - loglik = 0.0 - @inbounds @views for t in 2:T - mul!(u_f[t], A_1, u_f[t - 1]) - mul!(u_f[t], B, view(noise, :, t - 1), 1, 1) +# --- Pruned quadratic: save_everystep=false overloads (ping-pong u_f) --- - u[t] .= A_0 - mul!(u[t], A_1, u[t - 1], 1, 1) - quad_muladd!(u[t], A_2_vec, u_f[t - 1]) # u[t] .+= quad(A_2, u_f[t - 1]) - mul!(u[t], B, view(noise, :, t - 1), 1, 1) +function _init_model_state!!(prob::PrunedQuadraticStateSpaceProblem, cache, ::Val{false}) + cache.u_f[1] = assign!!(cache.u_f[1], prob.u0) + return nothing +end - z[t] .= C_0 - mul!(z[t], C_1, u[t], 1, 1) - quad_muladd!(z[t], C_2_vec, u_f[t]) # z[t] .+= quad(C_2, u_f[t]) - loglik += logpdf(observables_noise, view(prob.observables, :, t - 1) - z[t]) - end - t_values = prob.tspan[1]:prob.tspan[2] - maybe_add_observation_noise!(z, observables_noise, prob.observables) - sol = build_solution( - prob, alg, t_values, u; W = noise, logpdf = loglik, z, - retcode = :Success +@inline function _transition!!( + x_next, x, w, prob::PrunedQuadraticStateSpaceProblem, cache, t, ::Val{false} ) + (; A_0, A_1, A_2, B) = prob + uf_prev_idx = _u_idx_pingpong(t - 1) + uf_curr_idx = _u_idx_pingpong(t) + u_f_prev = cache.u_f[uf_prev_idx] + u_f_new = mul!!(cache.u_f[uf_curr_idx], A_1, u_f_prev) + u_f_new = muladd!!(u_f_new, B, w) + cache.u_f[uf_curr_idx] = u_f_new + x_next = copyto!!(x_next, A_0) + x_next = mul!!(x_next, A_1, x, 1.0, 1.0) + x_next = _add_quadratic!!(x_next, A_2, u_f_prev) + x_next = muladd!!(x_next, B, w) + return x_next +end - function solve_pb(Δsol) - # Currently only changes in the logpdf are supported in the rrule - @assert Δsol.u == ZeroTangent() - @assert Δsol.W == ZeroTangent() - - Δlogpdf = Δsol.logpdf - if iszero(Δlogpdf) - return ( - NoTangent(), Tangent{typeof(prob)}(), NoTangent(), - map(_ -> NoTangent(), args)..., - ) - end - ΔA_0 = zero(A_0) - ΔA_1 = zero(A_1) - ΔA_2_vec = [zero(A) for A in A_2_vec] # should be native datastructure - ΔA_2 = zero(A_2) - - ΔB = zero(B) - ΔC_0 = zero(C_0) - ΔC_1 = zero(C_1) - ΔC_2_vec = [zero(A) for A in C_2_vec] # should be native datastructure - ΔC_2 = zero(C_2) - Δu_f_sum = zero(u[1]) - - Δnoise = similar(noise) - Δu = [zero(prob.u0) for _ in 1:T] - Δu_f = [zero(prob.u0) for _ in 1:T] - A_2_vec_sum = [(A + A') for A in A_2_vec] # prep the sum since we will use it repeatedly - C_2_vec_sum = [(A + A') for A in C_2_vec] # prep the sum since we will use it repeatedly - - # Assert checked above about being diagonal - observables_noise_cov = prob.observables_noise - - @views @inbounds for t in T:-1:2 - Δz = Δlogpdf * (view(prob.observables, :, t - 1) - z[t]) ./ - observables_noise_cov # More generally, it should be Σ^-1 * (z_obs - z) - - # inplace adoint of quadratic form with accumulation - quad_muladd_pb!(ΔC_2_vec, Δu_f[t], Δz, C_2_vec_sum, u_f[t]) - mul!(Δu[t], C_1', Δz, 1, 1) - - quad_muladd_pb!(ΔA_2_vec, Δu_f[t - 1], Δu[t], A_2_vec_sum, u_f[t - 1]) - mul!(Δu[t - 1], A_1', Δu[t]) - mul!(Δu_f[t - 1], A_1', Δu_f[t], 1, 1) - - # Δu_f_sum = Δu[t] .+ Δu_f[t] - copy!(Δu_f_sum, Δu[t]) - Δu_f_sum .+= Δu_f[t] - - mul!(view(Δnoise, :, t - 1), B', Δu_f_sum) - # Now, deal with the coefficients - ΔA_0 += Δu[t] - mul!(ΔA_1, Δu[t], u[t - 1]', 1, 1) - mul!(ΔA_1, Δu_f[t], u_f[t - 1]', 1, 1) - mul!(ΔB, Δu_f_sum, view(noise, :, t - 1)', 1, 1) - ΔC_0 += Δz - mul!(ΔC_1, Δz, u[t]', 1, 1) - end - - # Remove once the vector of matrices or column-major organized 3-tensor is the native datastructure for C_2/A_2 - @views @inbounds for (i, ΔA_2_slice) in enumerate(ΔA_2_vec) - ΔA_2[i, :, :] .= ΔA_2_slice - end - @views @inbounds for (i, ΔC_2_slice) in enumerate(ΔC_2_vec) - ΔC_2[i, :, :] .= ΔC_2_slice - end - - return ( - NoTangent(), - Tangent{typeof(prob)}(; - A_0 = ΔA_0, A_1 = ΔA_1, A_2 = ΔA_2, B = ΔB, - C_0 = ΔC_0, - C_1 = ΔC_1, C_2 = ΔC_2, u0 = Δu[1] + Δu_f[1], - noise = Δnoise, - observables = NoTangent(), # not implemented - observables_noise = NoTangent() - ), NoTangent(), - map(_ -> NoTangent(), args)..., - ) - end - return sol, solve_pb +@inline function _observation!!( + y, x, prob::PrunedQuadraticStateSpaceProblem, cache, t, ::Val{false} + ) + (; C_0, C_1, C_2) = prob + u_f = cache.u_f[_u_idx_pingpong(t)] + y = copyto!!(y, C_0) + y = mul!!(y, C_1, x, 1.0, 1.0) + y = _add_quadratic!!(y, C_2, u_f) + return y end diff --git a/src/caches.jl b/src/caches.jl new file mode 100644 index 0000000..e5fa38d --- /dev/null +++ b/src/caches.jl @@ -0,0 +1,414 @@ +# Cache allocation: pre-allocated solution output + scratch workspace buffers +# Named-tuple storage, vector-of-vectors format + +# ============================================================================= +# Solution output allocation (u, P, z — owned by workspace, returned to user) +# ============================================================================= + +""" + alloc_sol(prob::LinearStateSpaceProblem, ::DirectIteration, T) + +Allocate solution output arrays for DirectIteration. +""" +function alloc_sol(prob::LinearStateSpaceProblem, ::DirectIteration, T) + (; u0, C) = prob + M = isnothing(C) ? 0 : size(C, 1) + return (; + u = [alloc_like(u0) for _ in 1:T], + z = isnothing(C) ? nothing : [alloc_like(u0, M) for _ in 1:T], + ) +end + +""" + alloc_sol(prob::LinearStateSpaceProblem, ::KalmanFilter, T) + +Allocate solution output arrays for KalmanFilter (filtered means, covariances, observations). +""" +function alloc_sol(prob::LinearStateSpaceProblem, ::KalmanFilter, T) + (; u0_prior_mean, u0_prior_var, C) = prob + L = size(C, 1) + return (; + u = [alloc_like(u0_prior_mean) for _ in 1:T], + P = [alloc_like(u0_prior_var) for _ in 1:T], + z = [alloc_like(u0_prior_mean, L) for _ in 1:T], + ) +end + +""" + alloc_sol(prob::StateSpaceProblem, ::DirectIteration, T) + +Allocate solution output arrays for generic StateSpaceProblem. +""" +function alloc_sol(prob::StateSpaceProblem, ::DirectIteration, T) + (; u0, n_obs) = prob + return (; + u = [alloc_like(u0) for _ in 1:T], + z = n_obs > 0 ? [alloc_like(u0, n_obs) for _ in 1:T] : nothing, + ) +end + +# --- Quadratic solution output (same structure as linear) --- + +function alloc_sol(prob::AnyQuadraticProblem, ::DirectIteration, T) + (; u0, C_0) = prob + M = isnothing(C_0) ? 0 : length(C_0) + return (; + u = [alloc_like(u0) for _ in 1:T], + z = isnothing(C_0) ? nothing : [alloc_like(u0, M) for _ in 1:T], + ) +end + +# ============================================================================= +# Scratch cache allocation (temporary workspace buffers only) +# ============================================================================= + +""" + alloc_cache(prob::LinearStateSpaceProblem, ::DirectIteration, T) + +Allocate scratch workspace for DirectIteration (noise buffers, loglik workspace). +""" +function alloc_cache(prob::LinearStateSpaceProblem, ::DirectIteration, T) + (; B, C, u0) = prob + M = isnothing(C) ? 0 : size(C, 1) + has_obs_noise = !isnothing(prob.observables_noise) + return _alloc_di_base_cache(B, u0, M, T, has_obs_noise) +end + +_alloc_noise(B::AbstractMatrix, T) = [alloc_like(B, size(B, 2)) for _ in 1:(T - 1)] +_alloc_noise(B, T) = [Vector{eltype(B)}(undef, size(B, 2)) for _ in 1:(T - 1)] +_alloc_noise(::Nothing, T) = nothing + +# --- Shared base cache for DirectIteration (noise + loglik workspace) --- + +function _alloc_di_base_cache(B, u0, M, T, has_obs_noise) + T_obs = T - 1 + return (; + noise = _alloc_noise(B, T), + R = has_obs_noise ? alloc_like(u0, M, M) : nothing, + R_chol = has_obs_noise ? alloc_like(u0, M, M) : nothing, + innovation = has_obs_noise ? [alloc_like(u0, M) for _ in 1:T_obs] : nothing, + innovation_solved = has_obs_noise ? [alloc_like(u0, M) for _ in 1:T_obs] : nothing, + ) +end + +# --- Unpruned quadratic cache (same as linear) --- + +function alloc_cache(prob::QuadraticStateSpaceProblem, ::DirectIteration, T) + (; B, C_0, u0) = prob + M = isnothing(C_0) ? 0 : length(C_0) + has_obs_noise = !isnothing(prob.observables_noise) + return _alloc_di_base_cache(B, u0, M, T, has_obs_noise) +end + +# --- Pruned quadratic cache (base + u_f buffer) --- + +function alloc_cache(prob::PrunedQuadraticStateSpaceProblem, ::DirectIteration, T) + (; B, C_0, u0) = prob + M = isnothing(C_0) ? 0 : length(C_0) + has_obs_noise = !isnothing(prob.observables_noise) + base = _alloc_di_base_cache(B, u0, M, T, has_obs_noise) + u_f = [alloc_like(u0) for _ in 1:T] + return (; base..., u_f) +end + +# ============================================================================= +# ConditionalLikelihood allocation +# alloc_sol follows DirectIteration (z conditional on observation function). +# alloc_cache always allocates innovation buffers (has_obs_noise = true). +# ============================================================================= + +function alloc_sol(prob::LinearStateSpaceProblem, ::ConditionalLikelihood, T) + (; u0, C) = prob + M = isnothing(C) ? 0 : size(C, 1) + return (; + u = [alloc_like(u0) for _ in 1:T], + z = isnothing(C) ? nothing : [alloc_like(u0, M) for _ in 1:T], + ) +end + +function alloc_sol(prob::StateSpaceProblem, ::ConditionalLikelihood, T) + (; u0, n_obs) = prob + return (; + u = [alloc_like(u0) for _ in 1:T], + z = n_obs > 0 ? [alloc_like(u0, n_obs) for _ in 1:T] : nothing, + ) +end + +function alloc_sol(prob::AnyQuadraticProblem, ::ConditionalLikelihood, T) + (; u0, C_0) = prob + M = isnothing(C_0) ? 0 : length(C_0) + return (; + u = [alloc_like(u0) for _ in 1:T], + z = isnothing(C_0) ? nothing : [alloc_like(u0, M) for _ in 1:T], + ) +end + +function alloc_cache(prob::LinearStateSpaceProblem, ::ConditionalLikelihood, T) + (; B, C, u0) = prob + M = isnothing(C) ? length(u0) : size(C, 1) + return _alloc_di_base_cache(B, u0, M, T, true) +end + +function alloc_cache(prob::StateSpaceProblem, ::ConditionalLikelihood, T) + (; u0, n_obs) = prob + B = _noise_matrix(prob) + M = n_obs > 0 ? n_obs : length(u0) + T_obs = T - 1 + return (; + noise = _alloc_noise(B, T), + R = alloc_like(u0, M, M), + R_chol = alloc_like(u0, M, M), + innovation = [alloc_like(u0, M) for _ in 1:T_obs], + innovation_solved = [alloc_like(u0, M) for _ in 1:T_obs], + ) +end + +function alloc_cache(prob::QuadraticStateSpaceProblem, ::ConditionalLikelihood, T) + (; B, C_0, u0) = prob + M = isnothing(C_0) ? length(u0) : length(C_0) + return _alloc_di_base_cache(B, u0, M, T, true) +end + +function alloc_cache(prob::PrunedQuadraticStateSpaceProblem, ::ConditionalLikelihood, T) + (; B, C_0, u0) = prob + M = isnothing(C_0) ? length(u0) : length(C_0) + base = _alloc_di_base_cache(B, u0, M, T, true) + u_f = [alloc_like(u0) for _ in 1:T] + return (; base..., u_f) +end + +""" + alloc_cache(prob::LinearStateSpaceProblem, ::KalmanFilter, T) + +Allocate scratch workspace for KalmanFilter (prediction, innovation, gain buffers). +""" +function alloc_cache(prob::LinearStateSpaceProblem, ::KalmanFilter, T) + (; A, B, C, u0_prior_mean, u0_prior_var) = prob + N = length(u0_prior_mean) + L = size(C, 1) + T_obs = T - 1 + K_noise = size(B, 2) + + return (; + mu_pred = [alloc_like(u0_prior_mean) for _ in 1:T_obs], + sigma_pred = [alloc_like(u0_prior_var) for _ in 1:T_obs], + A_sigma = [alloc_like(u0_prior_var) for _ in 1:T_obs], + sigma_Gt = [alloc_like(u0_prior_var, N, L) for _ in 1:T_obs], + innovation = [alloc_like(u0_prior_mean, L) for _ in 1:T_obs], + innovation_cov = [alloc_like(u0_prior_var, L, L) for _ in 1:T_obs], + S_chol = [alloc_like(u0_prior_var, L, L) for _ in 1:T_obs], + innovation_solved = [alloc_like(u0_prior_mean, L) for _ in 1:T_obs], + gain_rhs = [alloc_like(C) for _ in 1:T_obs], + gain = [alloc_like(u0_prior_var, N, L) for _ in 1:T_obs], + gainG = [alloc_like(u0_prior_var) for _ in 1:T_obs], + KgSigma = [alloc_like(u0_prior_var) for _ in 1:T_obs], + mu_update = [alloc_like(u0_prior_mean) for _ in 1:T_obs], + B_prod = alloc_like(u0_prior_var), + B_t = alloc_like(B, K_noise, N), + ) +end + +""" + alloc_cache(prob::StateSpaceProblem, ::DirectIteration, T) + +Allocate scratch workspace for generic StateSpaceProblem. +""" +function alloc_cache(prob::StateSpaceProblem, ::DirectIteration, T) + (; u0, n_obs) = prob + B = _noise_matrix(prob) + M = n_obs + T_obs = T - 1 + has_obs_noise = !isnothing(prob.observables_noise) && M > 0 + return (; + noise = _alloc_noise(B, T), + R = has_obs_noise ? alloc_like(u0, M, M) : nothing, + R_chol = has_obs_noise ? alloc_like(u0, M, M) : nothing, + innovation = has_obs_noise ? [alloc_like(u0, M) for _ in 1:T_obs] : nothing, + innovation_solved = has_obs_noise ? [alloc_like(u0, M) for _ in 1:T_obs] : nothing, + ) +end + +# ============================================================================= +# save_everystep=false allocation (endpoints only: 2-element sol, 1-slot cache) +# ============================================================================= + +# --- Shared base cache for endpoints (1-slot innovation buffers) --- + +function _alloc_di_base_cache_endpoints(B, u0, M, T, has_obs_noise) + return (; + noise = _alloc_noise(B, T), + R = has_obs_noise ? alloc_like(u0, M, M) : nothing, + R_chol = has_obs_noise ? alloc_like(u0, M, M) : nothing, + innovation = has_obs_noise ? [alloc_like(u0, M)] : nothing, + innovation_solved = has_obs_noise ? [alloc_like(u0, M)] : nothing, + ) +end + +# --- alloc_sol endpoints: DirectIteration --- + +function alloc_sol(prob::LinearStateSpaceProblem, ::DirectIteration, T, ::Val{false}) + (; u0, C) = prob + M = isnothing(C) ? 0 : size(C, 1) + return (; + u = [alloc_like(u0) for _ in 1:2], + z = isnothing(C) ? nothing : [alloc_like(u0, M) for _ in 1:2], + ) +end + +function alloc_sol(prob::StateSpaceProblem, ::DirectIteration, T, ::Val{false}) + (; u0, n_obs) = prob + return (; + u = [alloc_like(u0) for _ in 1:2], + z = n_obs > 0 ? [alloc_like(u0, n_obs) for _ in 1:2] : nothing, + ) +end + +function alloc_sol(prob::AnyQuadraticProblem, ::DirectIteration, T, ::Val{false}) + (; u0, C_0) = prob + M = isnothing(C_0) ? 0 : length(C_0) + return (; + u = [alloc_like(u0) for _ in 1:2], + z = isnothing(C_0) ? nothing : [alloc_like(u0, M) for _ in 1:2], + ) +end + +# --- alloc_cache endpoints: DirectIteration --- + +function alloc_cache(prob::LinearStateSpaceProblem, ::DirectIteration, T, ::Val{false}) + (; B, C, u0) = prob + M = isnothing(C) ? 0 : size(C, 1) + has_obs_noise = !isnothing(prob.observables_noise) + return _alloc_di_base_cache_endpoints(B, u0, M, T, has_obs_noise) +end + +function alloc_cache(prob::StateSpaceProblem, ::DirectIteration, T, ::Val{false}) + (; u0, n_obs) = prob + B = _noise_matrix(prob) + M = n_obs + has_obs_noise = !isnothing(prob.observables_noise) && M > 0 + return (; + noise = _alloc_noise(B, T), + R = has_obs_noise ? alloc_like(u0, M, M) : nothing, + R_chol = has_obs_noise ? alloc_like(u0, M, M) : nothing, + innovation = has_obs_noise ? [alloc_like(u0, M)] : nothing, + innovation_solved = has_obs_noise ? [alloc_like(u0, M)] : nothing, + ) +end + +function alloc_cache(prob::QuadraticStateSpaceProblem, ::DirectIteration, T, ::Val{false}) + (; B, C_0, u0) = prob + M = isnothing(C_0) ? 0 : length(C_0) + has_obs_noise = !isnothing(prob.observables_noise) + return _alloc_di_base_cache_endpoints(B, u0, M, T, has_obs_noise) +end + +function alloc_cache(prob::PrunedQuadraticStateSpaceProblem, ::DirectIteration, T, ::Val{false}) + (; B, C_0, u0) = prob + M = isnothing(C_0) ? 0 : length(C_0) + has_obs_noise = !isnothing(prob.observables_noise) + base = _alloc_di_base_cache_endpoints(B, u0, M, T, has_obs_noise) + u_f = [alloc_like(u0) for _ in 1:2] + return (; base..., u_f) +end + +# --- alloc_sol endpoints: ConditionalLikelihood --- + +function alloc_sol(prob::LinearStateSpaceProblem, ::ConditionalLikelihood, T, ::Val{false}) + (; u0, C) = prob + M = isnothing(C) ? 0 : size(C, 1) + return (; + u = [alloc_like(u0) for _ in 1:2], + z = isnothing(C) ? nothing : [alloc_like(u0, M) for _ in 1:2], + ) +end + +function alloc_sol(prob::StateSpaceProblem, ::ConditionalLikelihood, T, ::Val{false}) + (; u0, n_obs) = prob + return (; + u = [alloc_like(u0) for _ in 1:2], + z = n_obs > 0 ? [alloc_like(u0, n_obs) for _ in 1:2] : nothing, + ) +end + +function alloc_sol(prob::AnyQuadraticProblem, ::ConditionalLikelihood, T, ::Val{false}) + (; u0, C_0) = prob + M = isnothing(C_0) ? 0 : length(C_0) + return (; + u = [alloc_like(u0) for _ in 1:2], + z = isnothing(C_0) ? nothing : [alloc_like(u0, M) for _ in 1:2], + ) +end + +# --- alloc_cache endpoints: ConditionalLikelihood --- + +function alloc_cache(prob::LinearStateSpaceProblem, ::ConditionalLikelihood, T, ::Val{false}) + (; B, C, u0) = prob + M = isnothing(C) ? length(u0) : size(C, 1) + return _alloc_di_base_cache_endpoints(B, u0, M, T, true) +end + +function alloc_cache(prob::StateSpaceProblem, ::ConditionalLikelihood, T, ::Val{false}) + (; u0, n_obs) = prob + B = _noise_matrix(prob) + M = n_obs > 0 ? n_obs : length(u0) + return (; + noise = _alloc_noise(B, T), + R = alloc_like(u0, M, M), + R_chol = alloc_like(u0, M, M), + innovation = [alloc_like(u0, M)], + innovation_solved = [alloc_like(u0, M)], + ) +end + +function alloc_cache(prob::QuadraticStateSpaceProblem, ::ConditionalLikelihood, T, ::Val{false}) + (; B, C_0, u0) = prob + M = isnothing(C_0) ? length(u0) : length(C_0) + return _alloc_di_base_cache_endpoints(B, u0, M, T, true) +end + +function alloc_cache( + prob::PrunedQuadraticStateSpaceProblem, ::ConditionalLikelihood, T, ::Val{false} + ) + (; B, C_0, u0) = prob + M = isnothing(C_0) ? length(u0) : length(C_0) + base = _alloc_di_base_cache_endpoints(B, u0, M, T, true) + u_f = [alloc_like(u0) for _ in 1:2] + return (; base..., u_f) +end + +# --- alloc_sol/alloc_cache endpoints: KalmanFilter --- + +function alloc_sol(prob::LinearStateSpaceProblem, ::KalmanFilter, T, ::Val{false}) + (; u0_prior_mean, u0_prior_var, C) = prob + L = size(C, 1) + return (; + u = [alloc_like(u0_prior_mean) for _ in 1:2], + P = [alloc_like(u0_prior_var) for _ in 1:2], + z = [alloc_like(u0_prior_mean, L) for _ in 1:2], + ) +end + +function alloc_cache(prob::LinearStateSpaceProblem, ::KalmanFilter, T, ::Val{false}) + (; A, B, C, u0_prior_mean, u0_prior_var) = prob + N = length(u0_prior_mean) + L = size(C, 1) + K_noise = size(B, 2) + + return (; + mu_pred = [alloc_like(u0_prior_mean)], + sigma_pred = [alloc_like(u0_prior_var)], + A_sigma = [alloc_like(u0_prior_var)], + sigma_Gt = [alloc_like(u0_prior_var, N, L)], + innovation = [alloc_like(u0_prior_mean, L)], + innovation_cov = [alloc_like(u0_prior_var, L, L)], + S_chol = [alloc_like(u0_prior_var, L, L)], + innovation_solved = [alloc_like(u0_prior_mean, L)], + gain_rhs = [alloc_like(C)], + gain = [alloc_like(u0_prior_var, N, L)], + gainG = [alloc_like(u0_prior_var)], + KgSigma = [alloc_like(u0_prior_var)], + mu_update = [alloc_like(u0_prior_mean)], + B_prod = alloc_like(u0_prior_var), + B_t = alloc_like(B, K_noise, N), + ) +end diff --git a/src/precompilation.jl b/src/precompilation.jl index a946c38..adb080d 100644 --- a/src/precompilation.jl +++ b/src/precompilation.jl @@ -2,64 +2,109 @@ using PrecompileTools: PrecompileTools, @setup_workload, @compile_workload using LinearAlgebra: I @setup_workload begin - # Minimal setup data for precompilation workload - # Use simple 2x2 system that's typical for state-space models - @compile_workload begin # Common matrices for state-space models (2x2 system) A = [0.9 0.1; 0.0 0.8] B = reshape([0.0; 0.1], 2, 1) C = [1.0 0.0; 0.0 1.0] - D = [0.01, 0.01] + D = Diagonal([0.01, 0.01]) u0 = [0.0, 0.0] T = 10 - # === LinearStateSpaceProblem with DirectIteration (most common) === + # === LinearStateSpaceProblem with DirectIteration === # Simulation without observations prob_sim = LinearStateSpaceProblem(A, B, u0, (0, T)) sol_sim = solve(prob_sim) # Simulation with observation equation - prob_obs = LinearStateSpaceProblem(A, B, u0, (0, T); C = C) + prob_obs = LinearStateSpaceProblem(A, B, u0, (0, T); C) sol_obs = solve(prob_obs) # Simulation with observation noise - prob_noise = LinearStateSpaceProblem(A, B, u0, (0, T); C = C, observables_noise = D) + prob_noise = LinearStateSpaceProblem(A, B, u0, (0, T); C, observables_noise = D) sol_noise = solve(prob_noise) + # === init/solve! API === + ws = CommonSolve.init(prob_obs, DirectIteration()) + sol_ws = CommonSolve.solve!(ws) + # === LinearStateSpaceProblem with KalmanFilter === - # Generate fake observables for Kalman filter - # For tspan = (0, T), we get T+1 time points, so need T observables - observables = randn(2, T) + observables = [randn(2) for _ in 1:T] u0_prior_mean = zeros(2) u0_prior_var = Matrix{Float64}(I, 2, 2) prob_kalman = LinearStateSpaceProblem( - A, B, u0, (0, size(observables, 2)); - C = C, observables_noise = D, observables = observables, - u0_prior_mean = u0_prior_mean, u0_prior_var = u0_prior_var + A, B, u0, (0, length(observables)); + C, observables_noise = D, observables, + u0_prior_mean, u0_prior_var ) sol_kalman = solve(prob_kalman) + # Kalman init/solve! + ws_k = CommonSolve.init(prob_kalman, KalmanFilter()) + sol_k = CommonSolve.solve!(ws_k) + + # === ConditionalLikelihood === + # With C + prob_cl = LinearStateSpaceProblem( + A, nothing, u0, (0, length(observables)); + C, observables_noise = D, observables + ) + sol_cl = solve(prob_cl, ConditionalLikelihood()) + + # Without C (identity observation) + prob_cl_no_c = LinearStateSpaceProblem( + A, nothing, u0, (0, length(observables)); + observables_noise = Diagonal([0.01, 0.01]), observables + ) + sol_cl_no_c = solve(prob_cl_no_c, ConditionalLikelihood()) + + # init/solve! + ws_cl = CommonSolve.init(prob_cl, ConditionalLikelihood()) + sol_cl_ws = CommonSolve.solve!(ws_cl) + # === LinearStateSpaceProblem with no noise matrix === - prob_no_noise = LinearStateSpaceProblem(A, nothing, u0, (0, T); C = C) + prob_no_noise = LinearStateSpaceProblem(A, nothing, u0, (0, T); C) sol_no_noise = solve(prob_no_noise) - # === QuadraticStateSpaceProblem with DirectIteration === - # Use proper dimensions: B has 1 column, so noise needs 1 row - # For tspan = (0, T), we get T+1 time points, so need T noise samples - A_0 = zeros(2) - A_1 = A - A_2 = zeros(2, 2, 2) - C_0 = zeros(2) - C_1 = C - C_2 = zeros(2, 2, 2) - noise_quad = randn(1, T) - - prob_quad = QuadraticStateSpaceProblem( - A_0, A_1, A_2, B, u0, (0, T); - C_0 = C_0, C_1 = C_1, C_2 = C_2, noise = noise_quad + # === StateSpaceProblem with DirectIteration === + gen_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, A, x) + mul!(x_next, B, w, 1.0, 1.0) + return x_next + end + gen_g!! = (y, x, p, t) -> begin + mul!(y, C, x) + return y + end + prob_gen = StateSpaceProblem( + gen_f!!, gen_g!!, u0, (0, T); + n_shocks = 1, n_obs = 2, + syms = (:x1, :x2), obs_syms = (:y1, :y2) + ) + sol_gen = solve(prob_gen) + sol_gen[:x1] # precompile state indexing + sol_gen[:y1] # precompile obs indexing + + # Generic init/solve! + ws_gen = CommonSolve.init(prob_gen, DirectIteration()) + sol_gen_ws = CommonSolve.solve!(ws_gen) + + # Generic without observations + prob_gen_no_obs = StateSpaceProblem( + gen_f!!, nothing, u0, (0, T); + n_shocks = 1, n_obs = 0 ) - sol_quad = solve(prob_quad) + sol_gen_no_obs = solve(prob_gen_no_obs) + + # === StaticArrays workload === + A_s = SMatrix{2, 2}(0.9, 0.0, 0.1, 0.8) + B_s = SMatrix{2, 1}(0.0, 0.1) + C_s = SMatrix{2, 2}(1.0, 0.0, 0.0, 1.0) + u0_s = SVector{2}(0.0, 0.0) + noise_s = [SVector{1}(randn()) for _ in 1:T] + + prob_s = LinearStateSpaceProblem(A_s, B_s, u0_s, (0, T); C = C_s, noise = noise_s) + sol_s = solve(prob_s) end end diff --git a/src/problems/quadratic_state_space_problems.jl b/src/problems/quadratic_state_space_problems.jl new file mode 100644 index 0000000..694e46d --- /dev/null +++ b/src/problems/quadratic_state_space_problems.jl @@ -0,0 +1,148 @@ +# Quadratic state-space problem types +# Two variants: unpruned (quad on x) and pruned (quad on linear-part u_f) +# Union type for shared dispatch (cache allocation, noise matrix, etc.) + +# --- Unpruned quadratic --- +# x[t+1] = A_0 + A_1 * x[t] + quad(A_2, x[t]) + B * w[t] +# z[t] = C_0 + C_1 * x[t] + quad(C_2, x[t]) + +""" + QuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, tspan[, p]; kwargs...) + +Define a second-order (quadratic) state-space model: + +```math +u_{n+1} = A_0 + A_1 \\, u_n + u_n^\\top A_2 \\, u_n + B \\, w_{n+1} +``` + +with optional observation equation +``z_n = C_0 + C_1 \\, u_n + u_n^\\top C_2 \\, u_n + v_n``. + +# Positional Arguments +- `A_0`: Constant drift vector (length n). +- `A_1`: Linear transition matrix (n×n). +- `A_2`: Quadratic transition tensor (n×n×n). Entry `A_2[i,:,:]` gives the matrix + for the `i`-th element of the quadratic term. +- `B`: Noise input matrix (n×k), or `nothing`. +- `u0`: Initial state vector. +- `tspan`: Time span as `(t0, t_end)`. + +# Keyword Arguments +- `C_0`, `C_1`, `C_2`: Observation equation coefficients (analogous to `A_0`, `A_1`, `A_2`). +- `observables_noise`, `observables`, `noise`, `syms`, `obs_syms`: Same as + [`LinearStateSpaceProblem`](@ref). + +# References +- Andreasen, Fernandez-Villaverde, and Rubio-Ramirez (2017), + "The Pruned State-Space System for Non-Linear DSGE Models: Theory and Empirical Applications." + +See also: [`PrunedQuadraticStateSpaceProblem`](@ref), [`LinearStateSpaceProblem`](@ref). +""" +@concrete struct QuadraticStateSpaceProblem <: AbstractStateSpaceProblem + f # ODEFunction (SciML interface/syms only) + A_0 # Constant drift vector + A_1 # Linear transition matrix + A_2 # Quadratic transition tensor (N, N, N) + B # Noise input matrix (or nothing) + C_0 # Observation constant (or nothing) + C_1 # Observation linear matrix (or nothing) + C_2 # Observation quadratic tensor (or nothing) + observables_noise + observables + u0 + tspan + p + noise + obs_syms + kwargs +end + +function QuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, tspan, p = NullParameters(); + C_0 = nothing, C_1 = nothing, C_2 = nothing, + observables_noise = nothing, observables = nothing, + noise = nothing, syms = nothing, obs_syms = nothing, kwargs... + ) + f = ODEFunction{false}( + (u, p, t) -> error("not implemented"); + sys = SymbolCache(syms) + ) + _tspan = promote_tspan(tspan) + _dt = _tspan[2] - _tspan[1] + isinteger(_dt) || throw(ArgumentError("tspan must have integer distance, got $_dt")) + return QuadraticStateSpaceProblem( + f, A_0, A_1, A_2, B, C_0, C_1, C_2, + observables_noise, observables, u0, _tspan, p, noise, obs_syms, kwargs + ) +end + +# --- Pruned quadratic --- +# u_f[t+1] = A_1 * u_f[t] + B * w[t] +# x[t+1] = A_0 + A_1 * x[t] + quad(A_2, u_f[t]) + B * w[t] +# z[t] = C_0 + C_1 * x[t] + quad(C_2, u_f[t]) + +""" + PrunedQuadraticStateSpaceProblem(A_0, A_1, A_2, B, u0, tspan[, p]; kwargs...) + +Define a pruned second-order state-space model. Unlike [`QuadraticStateSpaceProblem`](@ref), +the quadratic terms operate on a separate linear-part state ``u_f`` rather than the full state: + +```math +u_f^{n+1} = A_1 \\, u_f^n + B \\, w_{n+1} +``` +```math +u_{n+1} = A_0 + A_1 \\, u_n + (u_f^n)^\\top A_2 \\, u_f^n + B \\, w_{n+1} +``` + +The observation equation similarly uses ``u_f``: +``z_n = C_0 + C_1 \\, u_n + (u_f^n)^\\top C_2 \\, u_f^n + v_n``. + +This pruning approach prevents explosive dynamics in higher-order perturbation solutions. +Arguments are identical to [`QuadraticStateSpaceProblem`](@ref). + +# References +- Andreasen, Fernandez-Villaverde, and Rubio-Ramirez (2017), + "The Pruned State-Space System for Non-Linear DSGE Models: Theory and Empirical Applications." + +See also: [`QuadraticStateSpaceProblem`](@ref). +""" +@concrete struct PrunedQuadraticStateSpaceProblem <: AbstractStateSpaceProblem + f # ODEFunction (SciML interface/syms only) + A_0 # Constant drift vector + A_1 # Linear transition matrix + A_2 # Quadratic transition tensor (N, N, N) + B # Noise input matrix (or nothing) + C_0 # Observation constant (or nothing) + C_1 # Observation linear matrix (or nothing) + C_2 # Observation quadratic tensor (or nothing) + observables_noise + observables + u0 + tspan + p + noise + obs_syms + kwargs +end + +function PrunedQuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, tspan, p = NullParameters(); + C_0 = nothing, C_1 = nothing, C_2 = nothing, + observables_noise = nothing, observables = nothing, + noise = nothing, syms = nothing, obs_syms = nothing, kwargs... + ) + f = ODEFunction{false}( + (u, p, t) -> error("not implemented"); + sys = SymbolCache(syms) + ) + _tspan = promote_tspan(tspan) + _dt = _tspan[2] - _tspan[1] + isinteger(_dt) || throw(ArgumentError("tspan must have integer distance, got $_dt")) + return PrunedQuadraticStateSpaceProblem( + f, A_0, A_1, A_2, B, C_0, C_1, C_2, + observables_noise, observables, u0, _tspan, p, noise, obs_syms, kwargs + ) +end + +# Union for shared dispatch (cache allocation, noise matrix, etc.) +const AnyQuadraticProblem = Union{QuadraticStateSpaceProblem, PrunedQuadraticStateSpaceProblem} diff --git a/src/problems/state_space_problems.jl b/src/problems/state_space_problems.jl index 351d7d3..ce20305 100644 --- a/src/problems/state_space_problems.jl +++ b/src/problems/state_space_problems.jl @@ -1,13 +1,20 @@ -abstract type AbstractStateSpaceProblem <: AbstractDEProblem end -abstract type AbstractPerturbationProblem <: AbstractStateSpaceProblem end +""" + AbstractStateSpaceProblem <: DEProblem + +Abstract supertype for all discrete-time state-space problems in DifferenceEquations.jl. + +Subtypes include [`LinearStateSpaceProblem`](@ref), [`QuadraticStateSpaceProblem`](@ref), +[`PrunedQuadraticStateSpaceProblem`](@ref), and [`StateSpaceProblem`](@ref). +""" +abstract type AbstractStateSpaceProblem <: DEProblem end # TODO: Can add in more checks on the algorithm choice DiffEqBase.check_prob_alg_pairing(prob::AbstractStateSpaceProblem, alg) = nothing -# Perturbation problesm don't have f, g +# Perturbation problems don't have f, g # In discrete time, tspan should not have a sensitivity so the concretization is less obvious function DiffEqBase.get_concrete_problem( - prob::AbstractPerturbationProblem, isadapt; + prob::AbstractStateSpaceProblem, isadapt; kwargs... ) p = get_concrete_p(prob, kwargs) @@ -20,20 +27,55 @@ function DiffEqBase.get_concrete_problem( p === prob.p && return prob else - return remake(prob; u0 = u0_promote, p = p) + return remake(prob; u0 = u0_promote, p) end end -SciMLBase.isinplace(prob::AbstractPerturbationProblem) = false # necessary for the get_concrete_u0 overloads +SciMLBase.isinplace(prob::AbstractStateSpaceProblem) = false # necessary for the get_concrete_u0 overloads + +""" + LinearStateSpaceProblem(A, B, u0, tspan[, p]; kwargs...) + +Define a linear time-invariant state-space model: + +```math +u_{n+1} = A \\, u_n + B \\, w_{n+1} +``` + +with optional observation equation ``z_n = C \\, u_n + v_n``. + +# Positional Arguments +- `A`: Transition matrix (n×n). +- `B`: Noise input matrix (n×k), or `nothing` for deterministic dynamics. +- `u0`: Initial state vector, or a `Distribution` for random initial conditions. +- `tspan`: Time span as `(t0, t_end)` with integer distance (e.g., `(0, T)`). +- `p`: Parameters (default: `NullParameters()`). + +# Keyword Arguments +- `C`: Observation matrix (m×n). If `nothing`, no observations are computed. +- `observables_noise`: Observation noise covariance matrix (`AbstractMatrix`, e.g. `Diagonal(d)` or `Symmetric(H * H')`). +- `observables`: Observed data as `Vector{Vector{T}}` for likelihood computation. +- `noise`: Fixed noise sequence as `Vector{Vector{T}}`. If `nothing`, noise is drawn randomly. +- `u0_prior_mean`: Prior mean for Kalman filtering. +- `u0_prior_var`: Prior covariance matrix for Kalman filtering. +- `syms`: State variable names (e.g., `(:x, :y)`) for symbolic indexing. +- `obs_syms`: Observation variable names for symbolic indexing. + +# Notes +- Providing `u0_prior_mean`, `u0_prior_var`, `observables`, and `observables_noise` + (with `noise = nothing`) triggers automatic selection of [`KalmanFilter`](@ref). +- The `observables` timing convention: observations correspond to ``z_1, z_2, \\ldots`` + (starting from the second state), so pass `T` observations for a `tspan` of `(0, T)`. -# the {iip} isn't relevant here at this point, but if we remove it then there are failures in the "remake" call above -# when using the Ensemble unit tests +See also: [`StateSpaceProblem`](@ref), [`QuadraticStateSpaceProblem`](@ref), +[`DirectIteration`](@ref), [`KalmanFilter`](@ref). +""" struct LinearStateSpaceProblem{ uType, uPriorMeanType, uPriorVarType, tType, P, NP, F, AType, BType, CType, - RType, ObsType, K, + RType, ObsType, OS, K, } <: - AbstractPerturbationProblem + AbstractStateSpaceProblem f::F # HACK: used only for standard interfaces/syms/etc., not used in calculations A::AType B::BType @@ -46,6 +88,7 @@ struct LinearStateSpaceProblem{ tspan::tType p::P noise::NP + obs_syms::OS kwargs::K @add_kwonly function LinearStateSpaceProblem{iip}( A, B, u0, tspan, p = NullParameters(); @@ -55,31 +98,34 @@ struct LinearStateSpaceProblem{ observables = nothing, noise = nothing, syms = nothing, - f = ODEFunction{false}( - (u, p, t) -> error("not implemented"); - sys = syms === nothing ? nothing : SymbolCache(collect(syms)) - ), + obs_syms = nothing, + f = nothing, kwargs... ) where {iip} + if f === nothing + f = ODEFunction{false}( + (u, p, t) -> error("not implemented"); + sys = SymbolCache(syms) + ) + end _tspan = promote_tspan(tspan) - # _observables = promote_vv(observables) _observables = observables - # Require integer distances between time periods for now. Later could check with dt != 1 - @assert round(_tspan[2] - _tspan[1]) - (_tspan[2] - _tspan[1]) ≈ 0.0 + _dt = _tspan[2] - _tspan[1] + isinteger(_dt) || throw(ArgumentError("tspan must have integer distance, got $_dt")) return new{ typeof(u0), typeof(u0_prior_mean), typeof(u0_prior_var), typeof(_tspan), typeof(p), typeof(noise), typeof(f), typeof(A), typeof(B), typeof(C), typeof(observables_noise), - typeof(_observables), + typeof(_observables), typeof(obs_syms), typeof(kwargs), }( f, A, B, C, observables_noise, _observables, u0, u0_prior_mean, u0_prior_var, - _tspan, p, noise, kwargs + _tspan, p, noise, obs_syms, kwargs ) end end @@ -88,93 +134,87 @@ function LinearStateSpaceProblem(args...; kwargs...) return LinearStateSpaceProblem{false}(args...; kwargs...) end -# """ -# u_f(t+1) = A_1 u_f(t) .+ B * noise(t+1) -# u(t+1) = A_0 + A_1 u(t) + quad(A_2, u_f(t)) .+ B noise(t+1) -# z(t) = C_0 + C_1 u(t) + quad(C_2, u_f(t)) -# z_tilde(t) = z(t) + v(t+1) -# """ -struct QuadraticStateSpaceProblem{ - uType, uPriorMeanType, uPriorVarType, tType, P, NP, F, - A0Type, A1Type, - A2Type, BType, C0Type, - C1Type, C2Type, RType, ObsType, K, - } <: - AbstractPerturbationProblem +""" + StateSpaceProblem(transition, observation, u0, tspan[, p]; n_shocks, kwargs...) + +Define a generic state-space model with user-provided callback functions: + +```math +u_{n+1} = f(u_n, w_{n+1}, p, t_n), \\quad z_n = g(u_n, p, t_n) +``` + +# Positional Arguments +- `transition`: Callback `f!!(x_next, x, w, p, t) -> x_next`. For mutable arrays, mutate + `x_next` in place and return it; for immutable arrays (e.g., `SVector`), return a new value. +- `observation`: Callback `g!!(y, x, p, t) -> y`, or `nothing` for no observations. +- `u0`: Initial state vector, or a `Distribution` for random initial conditions. +- `tspan`: Time span as `(t0, t_end)` with integer distance. +- `p`: Parameters passed to the callbacks (default: `NullParameters()`). + +# Keyword Arguments +- `n_shocks::Int`: Number of noise dimensions (required). +- `n_obs::Int`: Number of observation dimensions (default: `0`). +- `observables_noise`: Observation noise covariance matrix (`AbstractMatrix`, e.g. `Diagonal(d)` or `Symmetric(H * H')`). +- `observables`: Observed data as `Vector{Vector{T}}`. +- `noise`: Fixed noise sequence as `Vector{Vector{T}}`. +- `syms`: State variable names for symbolic indexing. +- `obs_syms`: Observation variable names for symbolic indexing. + +See also: [`LinearStateSpaceProblem`](@ref), [`DirectIteration`](@ref). +""" +struct StateSpaceProblem{ + uType, tType, P, NP, TF, GF, F, + RType, ObsType, OS, K, + } <: AbstractStateSpaceProblem f::F # HACK: used only for standard interfaces/syms/etc., not used in calculations - A_0::A0Type - A_1::A1Type - A_2::A2Type - B::BType - C_0::C0Type - C_1::C1Type - C_2::C2Type + transition::TF # f!!(x_next, x, w, p, t) -> x_next + observation::GF # g!!(y, x, p, t) -> y (or nothing) observables_noise::RType observables::ObsType u0::uType - u0_prior_mean::uPriorMeanType - u0_prior_var::uPriorVarType tspan::tType p::P noise::NP + n_shocks::Int + n_obs::Int # 0 if no observation equation + obs_syms::OS kwargs::K - @add_kwonly function QuadraticStateSpaceProblem{iip}( - A_0, A_1, A_2, B, u0, tspan, - p = NullParameters(); - u0_prior_mean = nothing, - u0_prior_var = nothing, - C_0 = nothing, C_1 = nothing, - C_2 = nothing, + @add_kwonly function StateSpaceProblem{iip}( + transition, observation, u0, tspan, p = NullParameters(); + n_shocks, + n_obs = 0, observables_noise = nothing, observables = nothing, noise = nothing, syms = nothing, - f = ODEFunction{false}( - (u, p, t) -> error("not implemented"); - sys = syms === nothing ? nothing : SymbolCache(collect(syms)) - ), + obs_syms = nothing, + f = nothing, kwargs... ) where {iip} + if f === nothing + f = ODEFunction{false}( + (u, p, t) -> error("not implemented"); + sys = SymbolCache(syms) + ) + end _tspan = promote_tspan(tspan) - # _observables = promote_vv(observables) _observables = observables - # Require integer distances between time periods for now. Later could check with dt != 1 - @assert round(_tspan[2] - _tspan[1]) - (_tspan[2] - _tspan[1]) ≈ 0.0 + _dt = _tspan[2] - _tspan[1] + isinteger(_dt) || throw(ArgumentError("tspan must have integer distance, got $_dt")) return new{ - typeof(u0), typeof(u0_prior_mean), typeof(u0_prior_var), typeof(_tspan), - typeof(p), - typeof(noise), typeof(f), - typeof(A_0), typeof(A_1), typeof(A_2), typeof(B), typeof(C_0), - typeof(C_1), - typeof(C_2), typeof(observables_noise), typeof(_observables), + typeof(u0), typeof(_tspan), typeof(p), typeof(noise), + typeof(transition), typeof(observation), typeof(f), + typeof(observables_noise), typeof(_observables), typeof(obs_syms), typeof(kwargs), }( - f, - A_0, - A_1, - A_2, - B, - C_0, - C_1, - C_2, - observables_noise, - _observables, - u0, - u0_prior_mean, - u0_prior_var, - _tspan, - p, - noise, - kwargs + f, transition, observation, observables_noise, _observables, + u0, _tspan, p, noise, n_shocks, n_obs, obs_syms, kwargs ) end end # just forwards to a iip = false case -function QuadraticStateSpaceProblem(args...; kwargs...) - return QuadraticStateSpaceProblem{false}( - args...; - kwargs... - ) +function StateSpaceProblem(args...; kwargs...) + return StateSpaceProblem{false}(args...; kwargs...) end diff --git a/src/solutions/state_space_solutions.jl b/src/solutions/state_space_solutions.jl index c99891c..de246dc 100644 --- a/src/solutions/state_space_solutions.jl +++ b/src/solutions/state_space_solutions.jl @@ -1,3 +1,32 @@ +""" + StateSpaceSolution + +Solution type returned by `solve` for all state-space problems. + +# Fields +- `u`: State trajectory as `Vector{Vector{T}}`. +- `t`: Time values. +- `z`: Observation trajectory as `Vector{Vector{T}}`, or `nothing`. +- `W`: Noise sequence as `Vector{Vector{T}}`, or `nothing` (e.g., for `KalmanFilter`). +- `P`: Posterior covariances as `Vector{Matrix{T}}` (`KalmanFilter` only), or `nothing`. +- `logpdf`: Log-likelihood value. Zero when no `observables` are provided. +- `retcode`: `ReturnCode.Success`. Errors are thrown as exceptions, not encoded in the return code. +- `prob`: The original problem. +- `alg`: The algorithm used. + +# Symbolic Indexing +Access time series by symbol name: +```julia +sol[:x] # state variable time series (requires `syms`) +sol[:output] # observation time series (requires `obs_syms`) +``` + +# Standard Indexing +```julia +sol[i] # state at time step i (same as sol.u[i]) +sol[end] # final state +``` +""" struct StateSpaceSolution{ T, N, uType, uType2, DType, tType, randType, P, A, IType, DE, PosteriorType, @@ -14,7 +43,7 @@ struct StateSpaceSolution{ dense::Bool tslocation::Int stats::DE - retcode::Symbol + retcode::SciMLBase.ReturnCode.T P::PosteriorType logpdf::logpdfType z::zType @@ -27,7 +56,7 @@ function SciMLBase.build_solution( dense = false, dense_errors = dense, calculate_error = true, interp = ConstantInterpolation(t, u), - retcode = :Default, + retcode = ReturnCode.Default, stats = nothing, z = nothing, kwargs... ) T = eltype(eltype(u)) @@ -48,11 +77,29 @@ function SciMLBase.build_solution( return sol end -# Just using ConstantInterpolation for now. Worth specializing? -# (sol::StateSpaceSolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} = _interpolate(sol, t, idxs) -# _interpolate(sol::StateSpaceSolution, t::Integer, idxs::Nothing) = sol.u[t] -# _interpolate(sol::StateSpaceSolution, t::Number, idxs::Nothing) = sol.u[Integer(round(t))] -# _interpolate(sol::StateSpaceSolution, t::Integer, idxs) = sol.u[t][idxs] +# TODO: Worth specializing interpolation beyond ConstantInterpolation? + +"""Return observation symbols from the problem, or nothing.""" +obs_syms(sol::StateSpaceSolution) = sol.prob.obs_syms + +Base.@propagate_inbounds function Base.getindex(sol::StateSpaceSolution, sym::Symbol) + # Check observation symbols first + _obs_syms = sol.prob.obs_syms + if _obs_syms !== nothing + idx = findfirst(==(sym), _obs_syms) + if idx !== nothing + sol.z === nothing && + error("Observation symbol $sym found but no observations in solution") + return [sol.z[t][idx] for t in eachindex(sol.z)] + end + end + # Check state symbols via the ODEFunction's SymbolCache + state_idx = variable_index(sol.prob.f.sys, sym) + if state_idx !== nothing + return [sol.u[t][state_idx] for t in eachindex(sol.u)] + end + throw(ArgumentError("Symbol $sym not found in state or observation symbols")) +end # For recipes SciMLBase.getindepsym(sol::StateSpaceSolution) = :t diff --git a/src/solve.jl b/src/solve.jl index b739c33..d4a385f 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,9 +1,63 @@ -using DiffEqBase: AbstractDEAlgorithm, KeywordArgSilent +using DiffEqBase: DEAlgorithm, KeywordArgSilent -abstract type AbstractDifferenceEquationAlgorithm <: AbstractDEAlgorithm end +abstract type AbstractDifferenceEquationAlgorithm <: DEAlgorithm end + +""" + DirectIteration() + +Forward iteration algorithm for state-space problems. Iterates the state transition +equation forward in time, computing the state trajectory `u`, observations `z`, +noise history `W`, and (if `observables` are provided) the joint log-likelihood `logpdf`. + +This is the default algorithm for all problem types. + +See also: [`KalmanFilter`](@ref). +""" struct DirectIteration <: AbstractDifferenceEquationAlgorithm end + +""" + KalmanFilter() + +Kalman filter algorithm for [`LinearStateSpaceProblem`](@ref). Computes filtered +state estimates, posterior covariances, and the marginal log-likelihood. + +Automatically selected when the problem provides: +- `u0_prior_mean` and `u0_prior_var` (Gaussian prior), +- `observables` (observed data), +- `observables_noise` (observation noise covariance), +- `noise = nothing` (latent noise is not fixed). + +The solution contains filtered means in `sol.u`, posterior covariances in `sol.P`, +predicted observations in `sol.z`, and the marginal log-likelihood in `sol.logpdf`. + +See also: [`DirectIteration`](@ref). +""" struct KalmanFilter <: AbstractDifferenceEquationAlgorithm end +""" + ConditionalLikelihood() + +Conditional likelihood (prediction error decomposition) algorithm for +fully-observed state-space models. At each step, predicts the next +observation from the *observed* current state using the transition equation, +and accumulates the Gaussian log-likelihood of the innovation. + +Works with all problem types (`LinearStateSpaceProblem`, `StateSpaceProblem`, +`QuadraticStateSpaceProblem`, `PrunedQuadraticStateSpaceProblem`). The only +requirement is additive Gaussian observation noise. + +Requires: +- `observables` (observed data y₁, …, y_T), +- `observables_noise` (innovation covariance R). + +The solution contains predicted observations in `sol.z` (when an observation +equation is present), the conditional log-likelihood in `sol.logpdf`, and the +state trajectory (clamped to observables) in `sol.u`. + +See also: [`DirectIteration`](@ref), [`KalmanFilter`](@ref). +""" +struct ConditionalLikelihood <: AbstractDifferenceEquationAlgorithm end + # The typical algorithm in discrete-time is DirectIteration() # Unlike continuous time, there aren't many simple variations default_alg(prob::AbstractStateSpaceProblem) = DirectIteration() @@ -13,8 +67,7 @@ function default_alg( prob::LinearStateSpaceProblem{ uType, uPriorMeanType, uPriorVarType, tType, P, NP, F, AType, BType, CType, - RType, ObsType, - K, + RType, ObsType, OS, K, } ) where { uType, @@ -31,13 +84,10 @@ function default_alg( CType <: AbstractMatrix, RType <: - Union{ - AbstractVector, - AbstractMatrix, - }, - ObsType <: AbstractMatrix, - K, + ObsType <: + AbstractVector, + OS, K, } return KalmanFilter() end diff --git a/src/utilities.jl b/src/utilities.jl index 810f7a7..d76d617 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -1,15 +1,63 @@ -# This file contains utilities for use in the algorithms to make the code more generic and able to handle different model variations (e.g., no observables, no observation equation, etc.) - -# Temporary. Eventually, move to use sciml NoiseProcess with better rng support/etc. -get_concrete_noise(prob, noise, B, T) = noise # maybe do a promotion to an AbstractVectorOfVector type -get_concrete_noise(prob, noise, B::Nothing, T) = nothing # if no noise matrix given, do not create noise -get_concrete_noise(prob, noise::Nothing, B::Nothing, T) = nothing # if no noise matrix given, do not create noise -get_concrete_noise(prob, noise::Nothing, B, T) = randn(eltype(B), size(B, 2), T) # default is unit Gaussian -get_concrete_noise(prob, noise::UnivariateDistribution, B, T) = rand(noise, size(B, 2), T) # iid -get_concrete_noise(prob, noise::UnivariateDistribution, B::Nothing, T) = nothing # disambiguation: no noise matrix takes precedence - -# Utility functions to conditionally check size if not-nothing -maybe_check_size(m::AbstractMatrix, index::Integer, val::Integer) = (size(m, index) == val) +# Utilities for algorithms to handle different model variations +# (e.g., no observables, no observation equation, etc.) + +# ============================================================================= +# Noise handling — vector of vectors only +# ============================================================================= + +# Pass-through: already a vector of vectors +get_concrete_noise(prob, noise::AbstractVector{<:AbstractVector}, B, T) = noise + +# No noise matrix: no noise regardless of noise argument +get_concrete_noise(prob, noise, B::Nothing, T) = nothing +get_concrete_noise(prob, noise::Nothing, B::Nothing, T) = nothing +# Disambiguation: B=nothing takes precedence +get_concrete_noise(prob, noise::AbstractVector{<:AbstractVector}, B::Nothing, T) = nothing + +# Generate random noise as vector of vectors +function get_concrete_noise(prob, noise::Nothing, B, T) + return [randn(eltype(B), size(B, 2)) for _ in 1:T] +end +function get_concrete_noise(prob, noise::Nothing, B::StaticMatrix, T) + K = size(B, 2) + return [SVector{K}(randn(eltype(B), K)) for _ in 1:T] +end + +# ============================================================================= +# Copy noise into cache buffers +# ============================================================================= + +""" + copy_noise_to_cache!(cache_noise, noise) + +Copy concrete noise into preallocated cache noise buffers. +""" +function copy_noise_to_cache!(cache_noise, noise) + @inbounds for t in eachindex(cache_noise) + cache_noise[t] = assign!!(cache_noise[t], noise[t]) + end + return cache_noise +end +copy_noise_to_cache!(cache_noise, ::Nothing) = nothing +copy_noise_to_cache!(::Nothing, ::Nothing) = nothing + +# ============================================================================= +# Observables handling — vector of vectors only +# ============================================================================= + +""" + get_observable(observables::AbstractVector{<:AbstractVector}, t) + +Get observation at time t from vector-of-vectors observables. +""" +Base.@propagate_inbounds @inline get_observable( + observables::AbstractVector{<:AbstractVector}, t +) = observables[t] + +# ============================================================================= +# Conditional size checking +# ============================================================================= + maybe_check_size(m::AbstractVector, index::Integer, val::Integer) = (index == 1 ? length(m) == val : true) maybe_check_size(m::Nothing, index::Integer, val::Integer) = true @@ -17,179 +65,70 @@ function maybe_check_size( m1::AbstractArray, index1::Integer, m2::AbstractArray, index2::Integer ) - return ( - size(m1, index1) == - size(m2, index2) - ) + return size(m1, index1) == size(m2, index2) end maybe_check_size(m1::Nothing, index1::Integer, m2, index2::Integer) = true maybe_check_size(m1, index1::Integer, m2::Nothing, index2::Integer) = true maybe_check_size(m1::Nothing, index1::Integer, m2::Nothing, index2::Integer) = true -Base.@propagate_inbounds @inline function maybe_logpdf( - observables_noise::Distribution, - observables::AbstractMatrix, t, - z::AbstractVector, s - ) - return logpdf( - observables_noise, - view( - observables, - :, - t - ) - - z[s] - ) -end -# Don't accumulate likelihoods if no observations or observatino noise -maybe_logpdf(observables_noise, observable, t, z, s) = 0.0 - -# If no noise process is given, don't add in noise in simulation -Base.@propagate_inbounds @inline function maybe_muladd!(x, B, noise, t) - return mul!(x, B, view(noise, :, t), 1, 1) -end -maybe_muladd!(x, B::Nothing, noise, t) = nothing - -Base.@propagate_inbounds @inline maybe_muladd!(x, A, B) = mul!(x, A, B, 1, 1) -maybe_muladd!(x, A::Nothing, B) = nothing - -# need transpose versions for gradients -Base.@propagate_inbounds @inline maybe_muladd_transpose!(x, C, Δz) = mul!(x, C', Δz, 1, 1) -maybe_muladd_transpose!(x, C::Nothing, Δz) = nothing -Base.@propagate_inbounds @inline function maybe_muladd_transpose!( - ΔB::AbstractMatrix, - Δu_temp, - noise::AbstractMatrix, t - ) - mul!(ΔB, Δu_temp, view(noise, :, t)', 1, 1) - return nothing -end -maybe_muladd_transpose!(ΔB, Δu_temp, noise, t) = nothing -Base.@propagate_inbounds @inline maybe_mul!(x, t, A, y, s) = mul!(x[t], A, y[s]) -maybe_mul!(x::Nothing, t, A, y, s) = nothing -# Need transpose versions for rrule -Base.@propagate_inbounds @inline maybe_mul_transpose!(x, t, A, y, s) = mul!(x[t], A', y[s]) -maybe_mul_transpose!(x::Nothing, t, A, y, s) = nothing -Base.@propagate_inbounds @inline function maybe_mul_transpose!(Δnoise, t, B, y) - return mul!( - view(Δnoise, :, t), - B', y - ) +# Size check for vector of vectors +function maybe_check_size(m::AbstractVector{<:AbstractVector}, index::Integer, val::Integer) + if index == 1 + return isempty(m) || length(m[1]) == val + elseif index == 2 + return length(m) == val + end + return true end -maybe_mul_transpose!(Δnoise::Nothing, t, B, y) = nothing -# Utilities to get distribution for logpdf from observation error argument -make_observables_noise(observables_noise::Nothing) = nothing -make_observables_noise(observables_noise::AbstractMatrix) = MvNormal(observables_noise) -function make_observables_noise(observables_noise::AbstractVector) - return MvNormal(Diagonal(observables_noise)) -end +# ============================================================================= +# Ping-pong index for save_everystep=false (2-element buffer) +# ============================================================================= -# Utilities to get covariance matrix from observation error argument for kalman filter. e.g. vector is diagonal, etc. -make_observables_covariance_matrix(observables_noise::AbstractMatrix) = observables_noise -function make_observables_covariance_matrix(observables_noise::AbstractVector) - return Diagonal(observables_noise) -end +""" + _u_idx_pingpong(t) -#Add in observation noise to the output if simulated (i.e, observables not given) and there is observation_noise provided -function maybe_add_observation_noise!( - z, observables_noise::Distribution, - observables::Nothing - ) - # add noise to the vector of vectors componentwise - for z_val in z - z_val .+= rand(observables_noise) - end - return nothing -end -maybe_add_observation_noise!(z, observables_noise, observables) = nothing #otherwise do nothing +Ping-pong index mapping for 2-element buffers: t=1→1, t=2→2, t=3→1, t=4→2, ... +Used by `save_everystep=false` solver loops to alternate between two scratch slots. +""" +@inline _u_idx_pingpong(t) = 2 - t % 2 -#Maybe add observation noise, if observables and their adjoints given -Base.@propagate_inbounds @inline function maybe_add_Δ!(Δz, Δsol_z::AbstractVector, t) - Δz .+= Δsol_z[t] - return nothing -end -maybe_add_Δ!(Δz, Δsol_z, t) = nothing +# ============================================================================= +# Observation noise covariance +# ============================================================================= -Base.@propagate_inbounds @inline function maybe_add_Δ_slice!( - Δnoise::AbstractMatrix, - ΔW::AbstractMatrix, t +# Covariance matrix for Kalman filter and loglik computation. +# observables_noise must be an AbstractMatrix (e.g., Diagonal(d), Symmetric(H*H'), or Matrix). +make_observables_covariance_matrix(observables_noise::AbstractMatrix) = observables_noise +function make_observables_covariance_matrix(observables_noise::AbstractVector) + return error( + "observables_noise must be an AbstractMatrix (e.g., Diagonal(d)). " * + "Got a Vector. Use Diagonal(d) to construct a diagonal covariance matrix." ) - Δnoise[:, t] .+= view(ΔW, :, t) - return nothing -end -maybe_add_Δ_slice!(Δz, Δsol_A, t) = nothing - -# Don't add logpdf to observables unless provided -# TODO: check if this can be replaced with the following and if it has a performance regression for diagonal noise covariance -# ldiv!(Δz, observables_noise.Σ.chol, innovation[t]) -# rmul!(Δlogpdf, Δz) -Base.@propagate_inbounds @inline function maybe_add_Δ_logpdf!( - Δz::AbstractArray{<:Real, 1}, - Δlogpdf::Number, - observables::AbstractArray{ - <:Real, - 2, - }, - z::AbstractArray{T, 1}, - t, - observables_noise_cov::AbstractArray{ - <:Real, - 1, - } - ) where { - T, - } - Δz .= Δlogpdf * (view(observables, :, t - 1) - z[t]) ./ - observables_noise_cov - return nothing -end -# Otherwise do nothing -function maybe_add_Δ_logpdf!(Δz, Δlogpdf, observables, z, t, observables_noise_cov) - return nothing -end - -# Only allocate if observation equation -allocate_z(prob, C, u0, T) = [zeros(size(C, 1)) for _ in 1:T] -allocate_z(prob, C::Nothing, u0, T) = nothing - -# Maybe zero -maybe_zero(A::AbstractArray) = zero(A) -maybe_zero(A::Nothing) = nothing -maybe_zero(A::AbstractArray, i::Int64) = zero(A[i]) -maybe_zero(A::Nothing, i) = nothing - -# old quad and adjoint replaced by inplace accumulation versions. -# function quad(A::AbstractArray{<:Number,3}, x) -# return map(j -> dot(x, view(A, j, :, :), x), 1:size(A, 1)) -# end -# # quadratic form pullback -# function quad_pb(Δres::AbstractVector, A::AbstractArray{<:Number,3}, x::AbstractVector) -# ΔA = similar(A) -# Δx = zeros(length(x)) -# tmp = x * x' -# for i in 1:size(A, 1) -# ΔA[i, :, :] .= tmp .* Δres[i] -# Δx += (A[i, :, :] + A[i, :, :]') * x .* Δres[i] -# end -# return ΔA, Δx -# end - -# y += quad(A, x) -# The quad_muladd! uses on a vector of matrices for A -function quad_muladd!(y, A, x) - @inbounds for j in 1:size(A, 1) - @views y[j] += dot(x, A[j], x) - end - return y end -# inplace version with accumulation and using the cache of A[i] + A[i]', etc. -function quad_muladd_pb!(ΔA_vec, Δx, Δres, A_vec_sum, x) - tmp = x * x' # could add in a temp here - @inbounds for (i, A_sum) in enumerate(A_vec_sum) # @views @inbounds ADD - ΔA_vec[i] .+= tmp .* Δres[i] - Δx .+= A_sum * x .* Δres[i] +# ============================================================================= +# Observation noise simulation (for DirectIteration without observables) +# ============================================================================= + +""" + _add_observation_noise!!(z, F_chol) + +Add observation noise to simulated observations using a pre-computed Cholesky factor. +`F_chol` is an upper-triangular Cholesky factor (R = U'U), so L = U'. +Bang-bang: works with both mutable (Vector) and immutable (SVector) observation elements. +""" +function _add_observation_noise!!(z, F_chol) + M = size(F_chol, 1) + @inbounds for t in eachindex(z) + if !isnothing(z[t]) + noise = F_chol.L * randn(M) + if ismutable(z[t]) + z[t] .+= noise + else + z[t] = z[t] + noise + end + end end return nothing end diff --git a/src/utilities_bangbang.jl b/src/utilities_bangbang.jl new file mode 100644 index 0000000..2886b19 --- /dev/null +++ b/src/utilities_bangbang.jl @@ -0,0 +1,255 @@ +# Utility functions for generic array operations +# These work with both mutable arrays (Vector) and immutable arrays (SVector) + +""" + mul!!(Y, A, B) + +Computes `Y = A * B`. +- If `Y` is mutable (e.g., Vector), it mutates `Y` in-place using `mul!`. +- If `Y` is immutable (e.g., SVector), it returns a new result. +""" +@inline function mul!!(Y, A, B) + if ismutable(Y) + mul!(Y, A, B) + return Y + else + return A * B + end +end + +""" + mul!!(Y, A, B, α, β) + +Computes `Y = α * A * B + β * Y` (5-argument form). +- If `Y` is mutable, it mutates `Y` in-place using `mul!(Y, A, B, α, β)`. +- If `Y` is immutable, it returns `α * (A * B) + β * Y`. +""" +@inline function mul!!(Y, A, B, α, β) + if ismutable(Y) + mul!(Y, A, B, α, β) + return Y + else + return α * (A * B) + β * Y + end +end + +""" + muladd!!(Y, A, B) + +Computes `Y = Y + A * B`. +- If `Y` is mutable (e.g., Vector), it mutates `Y` in-place using `mul!`. +- If `Y` is immutable (e.g., SVector), it returns a new generic result. +- If `A` or `B` is `nothing`, returns `Y` unchanged (no-op). +""" +@inline function muladd!!(Y, A, B) + if ismutable(Y) + mul!(Y, A, B, 1.0, 1.0) + return Y + else + return Y + A * B + end +end + +# Specializations for nothing arguments (no-op) +@inline muladd!!(Y, ::Nothing, B) = Y +@inline muladd!!(Y, A, ::Nothing) = Y +@inline muladd!!(Y, ::Nothing, ::Nothing) = Y + +""" + ldiv!!(y, F, x) + +Computes `y = F \\ x` (linear solve with factorization F). +- If `y` is mutable, it mutates `y` in-place using `ldiv!(y, F, x)`. +- If `y` is immutable, it returns `F \\ x`. +""" +@inline function ldiv!!(y, F, x) + if ismutable(y) + ldiv!(y, F, x) + return y + else + return F \ x + end +end + +""" + ldiv!!(F, x) + +Computes `x = F \\ x` in-place (2-argument form). +- If `x` is mutable, it modifies `x` in-place using `ldiv!(F, x)`. +- If `x` is immutable, it returns `F \\ x`. +""" +@inline function ldiv!!(F, x) + if ismutable(x) + ldiv!(F, x) + return x + else + return F \ x + end +end + +""" + copyto!!(Y, X) + +Copies `X` to `Y`. +- If `Y` is mutable, it mutates `Y` in-place using `copyto!(Y, X)`. +- If `Y` is immutable, it returns `X` directly (immutables are values). +""" +@inline function copyto!!(Y, X) + if ismutable(Y) + copyto!(Y, X) + return Y + else + return X + end +end + +""" + assign!!(Y, X) + +Copy `X` into `Y` using an explicit loop (Enzyme-safe activity analysis). +- If `Y` is mutable, copies element-by-element with `@inbounds` and returns `Y`. +- If `Y` is immutable (e.g., `SVector`), returns `X` directly. +""" +@inline function assign!!(Y, X) + if ismutable(Y) + @inbounds for i in eachindex(X) + Y[i] = X[i] + end + return Y + else + return X + end +end + +""" + cholesky!!(A, uplo::Symbol=:U) + +Computes Cholesky factorization of symmetric matrix A. +- If `A` is mutable, uses `cholesky!(Symmetric(A, uplo), NoPivot(); check=false)`. +- If `A` is immutable, uses `cholesky(Symmetric(A, uplo))`. +""" +@inline function cholesky!!(A, uplo::Symbol = :U) + if ismutable(A) + return cholesky!(Symmetric(A, uplo), NoPivot(); check = false) + else + return cholesky(Symmetric(A, uplo)) + end +end + +""" + transpose!!(Y, X) + +Transposes `X` into `Y`. +- If `Y` is mutable, uses `transpose!(Y, X)`. +- If `Y` is immutable, returns `transpose(X)`. +""" +@inline function transpose!!(Y, X) + if ismutable(Y) + transpose!(Y, X) + return Y + else + return transpose(X) + end +end + +""" + mul_aat!!(Y, A, A_t) + +Computes `Y = A * A'` without triggering the BLAS `syrk` self-transpose path. +Workaround for Enzyme syrk adjoint bug (https://github.com/EnzymeAD/Enzyme.jl/issues/2355): +when `A` is rectangular, `mul!(Y, A, transpose(A))` dispatches to `syrk` whose Enzyme +reverse-mode rule generates a `DSYMM` call with invalid leading dimension. + +- If `Y` is mutable, materializes `transpose(A)` into buffer `A_t`, then calls `mul!(Y, A, A_t)`. +- If `Y` is immutable, returns `A * transpose(A)` (StaticArrays don't use BLAS). +""" +@inline function mul_aat!!(Y, A, A_t) + if ismutable(Y) + transpose!(A_t, A) + mul!(Y, A, A_t) + return Y + else + return A * transpose(A) + end +end + +""" + logdet_chol(F) + +Compute log-determinant from Cholesky factorization without allocations. +Uses: logdet(A) = logdet(U'U) = 2*sum(log(diag(U))) for upper Cholesky. +""" +@inline function logdet_chol(F) + U = F.U + result = zero(eltype(U)) + @inbounds for i in axes(U, 1) + result += log(U[i, i]) + end + return 2 * result +end + +""" + symmetrize_upper!!(L, A, eps=0.0) + +Symmetrize matrix A into upper triangular form with optional diagonal perturbation. +- If `L` is mutable, modifies in-place and returns `L` +- If `L` is immutable, returns `(A + A')/2 + eps*I` +""" +@inline function symmetrize_upper!!(L, A, eps = 0.0) + if ismutable(L) + @inbounds for j in axes(A, 2) + for i in 1:j + v = (A[i, j] + A[j, i]) * 0.5 + L[i, j] = (i == j) ? v + eps : v + end + for i in (j + 1):size(A, 1) + L[i, j] = 0 + end + end + return L + else + sym = (A + A') / 2 + if eps != 0 + return sym + eps * one(A) + else + return sym + end + end +end + +# ============================================================================= +# Prototype-based allocation utilities +# ============================================================================= + +""" + alloc_like(x) + alloc_like(x, dims::Int...) + +Allocate an array matching the type family of `x`. +""" +@inline alloc_like(x::AbstractArray) = similar(x) +@inline alloc_like(::SVector{N, T}) where {N, T} = zeros(SVector{N, T}) +@inline alloc_like(::SMatrix{N, M, T}) where {N, M, T} = zeros(SMatrix{N, M, T}) + +# Different dimensions +@inline alloc_like(x::AbstractArray, dims::Int...) = similar(x, dims...) +@inline alloc_like(::SVector{<:Any, T}, n::Int) where {T} = zeros(SVector{n, T}) +@inline alloc_like(::SMatrix{<:Any, <:Any, T}, n::Int, m::Int) where {T} = + zeros(SMatrix{n, m, T}) +@inline alloc_like(::SMatrix{<:Any, <:Any, T}, n::Int) where {T} = zeros(SVector{n, T}) + +# ============================================================================= +# Generic zeroing utility for Enzyme compatibility +# ============================================================================= + +""" + fill_zero!!(x) + +Zero out all elements of `x`. +""" +@inline fill_zero!!(::SVector{N, T}) where {N, T} = zeros(SVector{N, T}) +@inline fill_zero!!(::SMatrix{N, M, T}) where {N, M, T} = zeros(SMatrix{N, M, T}) +@inline function fill_zero!!(x::AbstractArray{T}) where {T} + fill!(x, zero(T)) + return x +end diff --git a/src/workspace.jl b/src/workspace.jl new file mode 100644 index 0000000..9cb224d --- /dev/null +++ b/src/workspace.jl @@ -0,0 +1,63 @@ +# SciML-compatible init / solve! API +# Workspace holds pre-allocated output arrays + scratch cache. + +""" + StateSpaceWorkspace + +Workspace for state-space problem solvers. Holds the problem, algorithm, +pre-allocated output arrays, and scratch cache. +Created by `CommonSolve.init` and consumed by `CommonSolve.solve!`. +""" +@concrete mutable struct StateSpaceWorkspace + prob + alg + output # pre-allocated output arrays (u, P, z) — NamedTuple + cache # scratch workspace buffers + save_everystep::Bool +end + +# Public 4-arg constructor — assumes save_everystep=true (full trajectory storage). +# This is the form used by Enzyme wrappers and direct workspace construction. +# The 5-arg form with the Bool is only called internally by init(). +function StateSpaceWorkspace(prob, alg, output, cache) + return StateSpaceWorkspace(prob, alg, output, cache, true) +end + +""" + CommonSolve.init(prob::AbstractStateSpaceProblem, alg=default_alg(prob); save_everystep=true, kwargs...) + +Create a `StateSpaceWorkspace` with pre-allocated output arrays and scratch cache. +When `save_everystep=false`, allocates minimal 2-element buffers (endpoints only). +""" +function CommonSolve.init( + prob::AbstractStateSpaceProblem, alg = default_alg(prob); + save_everystep = true, kwargs... + ) + T = convert(Int64, prob.tspan[2] - prob.tspan[1] + 1) + if save_everystep + output = alloc_sol(prob, alg, T) + cache = alloc_cache(prob, alg, T) + else + se = Val(false) + output = alloc_sol(prob, alg, T, se) + cache = alloc_cache(prob, alg, T, se) + end + return StateSpaceWorkspace(prob, alg, output, cache, save_everystep) +end + +""" + CommonSolve.solve!(ws::StateSpaceWorkspace; kwargs...) + +Solve the state-space problem. Mutates `ws.output` arrays in place, then +wraps them in a `StateSpaceSolution` and returns it. +""" +function CommonSolve.solve!(ws::StateSpaceWorkspace; kwargs...) + if ws.save_everystep + return _solve!(ws.prob, ws.alg, ws.output, ws.cache; kwargs...) + else + return _solve!( + ws.prob, ws.alg, ws.output, ws.cache; + save_everystep = Val(false), kwargs... + ) + end +end diff --git a/test/Project.toml b/test/Project.toml index 655a9fa..ecd2306 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,23 +1,27 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferenceEquations = "e0ca9c66-1f9e-11ec-127a-1304ce62169c" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" -FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[sources] +DifferenceEquations = {path = ".."} [compat] Aqua = "0.8" +Enzyme = "0.13" +EnzymeTestUtils = "0.2" ExplicitImports = "1" diff --git a/test/cache_reuse.jl b/test/cache_reuse.jl new file mode 100644 index 0000000..d12e7ae --- /dev/null +++ b/test/cache_reuse.jl @@ -0,0 +1,137 @@ +using DifferenceEquations, Distributions, LinearAlgebra, Test +using DelimitedFiles, DiffEqBase, Random +using DifferenceEquations: init, solve! + +A_rbc = [0.9568351489231076 6.209371005755285; 3.0153731819288737e-18 0.20000000000000007] +B_rbc = reshape([0.0; -0.01], 2, 1) +C_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] +D_rbc = abs2.([0.1, 0.1]) +u0_rbc = zeros(2) + +observables_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_observables.csv"), ',' +)' |> collect +noise_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), ',' +)' |> collect +T = 5 +observables_rbc = [observables_rbc_matrix[:, t] for t in 1:T] +noise_rbc = [noise_rbc_matrix[:, t] for t in 1:T] + +# --- Generic callbacks for StateSpaceProblem tests --- +linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next +end +linear_g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y +end + +# --- LinearStateSpaceProblem cache reuse --- + +@testset "init/solve! matches solve for DirectIteration" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); + C = C_rbc, observables_noise = Diagonal(D_rbc), noise = noise_rbc, + observables = observables_rbc + ) + sol_direct = solve(prob) + + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.logpdf ≈ sol_direct.logpdf +end + +@testset "init/solve! matches solve for KalmanFilter" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); + C = C_rbc, observables_noise = Diagonal(D_rbc), observables = observables_rbc, + u0_prior_mean = u0_rbc, + u0_prior_var = diagm(ones(length(u0_rbc))) + ) + sol_direct = solve(prob) + + ws = init(prob, KalmanFilter()) + sol_ws = solve!(ws) + + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.logpdf ≈ sol_direct.logpdf + @test sol_ws.P ≈ sol_direct.P +end + +@testset "repeated solve! gives consistent results" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); + C = C_rbc, observables_noise = Diagonal(D_rbc), noise = noise_rbc, + observables = observables_rbc + ) + + ws = init(prob, DirectIteration()) + sol1 = solve!(ws) + sol2 = solve!(ws) + + @test sol1.u ≈ sol2.u + @test sol1.z ≈ sol2.z + @test sol1.logpdf ≈ sol2.logpdf +end + +@testset "repeated solve! for KalmanFilter" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); + C = C_rbc, observables_noise = Diagonal(D_rbc), observables = observables_rbc, + u0_prior_mean = u0_rbc, + u0_prior_var = diagm(ones(length(u0_rbc))) + ) + + ws = init(prob, KalmanFilter()) + sol1 = solve!(ws) + sol2 = solve!(ws) + + @test sol1.logpdf ≈ sol2.logpdf + @test sol1.u ≈ sol2.u + @test sol1.P ≈ sol2.P +end + +# --- StateSpaceProblem cache reuse --- + +@testset "Generic init/solve! matches solve" begin + p = (; A = A_rbc, B = B_rbc, C = C_rbc) + + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), noise = noise_rbc, observables = observables_rbc + ) + + sol_direct = solve(prob) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.logpdf ≈ sol_direct.logpdf +end + +@testset "Generic repeated solve! gives consistent results" begin + p = (; A = A_rbc, B = B_rbc, C = C_rbc) + + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), noise = noise_rbc, observables = observables_rbc + ) + + ws = init(prob, DirectIteration()) + sol1 = solve!(ws) + sol2 = solve!(ws) + + @test sol1.u ≈ sol2.u + @test sol1.z ≈ sol2.z + @test sol1.logpdf ≈ sol2.logpdf +end diff --git a/test/conditional_likelihood.jl b/test/conditional_likelihood.jl new file mode 100644 index 0000000..0227700 --- /dev/null +++ b/test/conditional_likelihood.jl @@ -0,0 +1,505 @@ +# ConditionalLikelihood tests — prediction error decomposition for fully-observed models. +# Validates correctness against manual log-likelihood, type stability, workspace API, +# and StaticArrays support. + +using DifferenceEquations, Distributions, LinearAlgebra, Test, Random +using DifferenceEquations: init, solve! +using StaticArrays + +# ============================================================================= +# AR(1) manual log-likelihood helper +# ============================================================================= + +function manual_ar1_loglik(y, rho, sigma_e; u0 = 0.0) + T = length(y) + loglik = 0.0 + x_prev = u0 + for t in 1:T + mu = rho * x_prev + loglik += logpdf(Normal(mu, sigma_e), y[t]) + x_prev = y[t] + end + return loglik +end + +function manual_var_loglik(y, A, R; u0 = zeros(size(A, 1))) + T = length(y) + loglik = 0.0 + x_prev = u0 + M = size(R, 1) + dist = MvNormal(zeros(M), R) + for t in 1:T + mu = A * x_prev + loglik += logpdf(MvNormal(mu, R), y[t]) + x_prev = y[t] + end + return loglik +end + +# ============================================================================= +# AR(1) — C = nothing (identity observation) +# ============================================================================= + +@testset "ConditionalLikelihood — AR(1), C=nothing" begin + rho = 0.8 + sigma_e = 0.5 + T = 50 + + # Generate AR(1) data + Random.seed!(123) + y_scalar = zeros(T) + x = 0.0 + for t in 1:T + x = rho * x + sigma_e * randn() + y_scalar[t] = x + end + y = [[yi] for yi in y_scalar] + + prob = LinearStateSpaceProblem( + fill(rho, 1, 1), nothing, [0.0], (0, T); + observables = y, + observables_noise = Diagonal([sigma_e^2]), + ) + sol = solve(prob, ConditionalLikelihood()) + + expected = manual_ar1_loglik(y_scalar, rho, sigma_e) + @test sol.logpdf ≈ expected atol = 1.0e-12 + @test sol.z === nothing + @test length(sol.u) == T + 1 + # State should be clamped to observations + for t in 1:T + @test sol.u[t + 1] ≈ y[t] + end +end + +# ============================================================================= +# AR(1) — C = I (explicit observation matrix, same result) +# ============================================================================= + +@testset "ConditionalLikelihood — AR(1), C=I" begin + rho = 0.8 + sigma_e = 0.5 + T = 50 + + Random.seed!(123) + y_scalar = zeros(T) + x = 0.0 + for t in 1:T + x = rho * x + sigma_e * randn() + y_scalar[t] = x + end + y = [[yi] for yi in y_scalar] + + prob_no_c = LinearStateSpaceProblem( + fill(rho, 1, 1), nothing, [0.0], (0, T); + observables = y, + observables_noise = Diagonal([sigma_e^2]), + ) + sol_no_c = solve(prob_no_c, ConditionalLikelihood()) + + prob_with_c = LinearStateSpaceProblem( + fill(rho, 1, 1), nothing, [0.0], (0, T); + C = fill(1.0, 1, 1), + observables = y, + observables_noise = Diagonal([sigma_e^2]), + ) + sol_with_c = solve(prob_with_c, ConditionalLikelihood()) + + @test sol_no_c.logpdf ≈ sol_with_c.logpdf atol = 1.0e-12 + @test !isnothing(sol_with_c.z) +end + +# ============================================================================= +# VAR(1) — multivariate +# ============================================================================= + +@testset "ConditionalLikelihood — VAR(1)" begin + A = [0.8 0.1; -0.1 0.7] + R = Diagonal([0.25, 0.25]) + T = 30 + + Random.seed!(456) + y = Vector{Vector{Float64}}(undef, T) + x = zeros(2) + for t in 1:T + x = A * x + cholesky(R).L * randn(2) + y[t] = copy(x) + end + + prob = LinearStateSpaceProblem( + A, nothing, zeros(2), (0, T); + observables = y, + observables_noise = R, + ) + sol = solve(prob, ConditionalLikelihood()) + + expected = manual_var_loglik(y, A, R) + @test sol.logpdf ≈ expected atol = 1.0e-10 +end + +# ============================================================================= +# With B and noise — prediction includes noise term +# ============================================================================= + +@testset "ConditionalLikelihood — with B and explicit noise" begin + rho = 0.8 + sigma_e = 0.5 + T = 20 + + Random.seed!(789) + noise = [[randn()] for _ in 1:T] + y_scalar = zeros(T) + x = 0.0 + B_val = 0.1 + for t in 1:T + x = rho * x + B_val * noise[t][1] + y_scalar[t] = x + sigma_e * randn() + end + y = [[yi] for yi in y_scalar] + + prob = LinearStateSpaceProblem( + fill(rho, 1, 1), fill(B_val, 1, 1), [0.0], (0, T); + observables = y, + observables_noise = Diagonal([sigma_e^2]), + noise = noise, + ) + sol = solve(prob, ConditionalLikelihood()) + + # Manual: prediction is rho * y[t-1] + B * w[t] + loglik = 0.0 + x_prev = 0.0 + for t in 1:T + mu = rho * x_prev + B_val * noise[t][1] + loglik += logpdf(Normal(mu, sigma_e), y_scalar[t]) + x_prev = y_scalar[t] + end + @test sol.logpdf ≈ loglik atol = 1.0e-12 +end + +# ============================================================================= +# Generic StateSpaceProblem — nonlinear AR(1) +# ============================================================================= + +@testset "ConditionalLikelihood — generic nonlinear AR(1)" begin + rho = 0.8 + alpha = 0.05 + sigma_e = 0.3 + T = 30 + + Random.seed!(111) + y_scalar = zeros(T) + x = 0.0 + for t in 1:T + x = rho * x + alpha * x^2 + sigma_e * randn() + y_scalar[t] = x + end + y = [[yi] for yi in y_scalar] + + nl_f!! = (x_next, x, w, p, t) -> begin + (; rho, alpha) = p + val = rho * x[1] + alpha * x[1]^2 + if ismutable(x_next) + x_next[1] = val + return x_next + else + return typeof(x)(val) + end + end + + p = (; rho, alpha) + prob = StateSpaceProblem( + nl_f!!, nothing, [0.0], (0, T), p; + n_shocks = 0, n_obs = 0, + observables = y, + observables_noise = Diagonal([sigma_e^2]), + ) + sol = solve(prob, ConditionalLikelihood()) + + # Manual + loglik = 0.0 + x_prev = 0.0 + for t in 1:T + mu = rho * x_prev + alpha * x_prev^2 + loglik += logpdf(Normal(mu, sigma_e), y_scalar[t]) + x_prev = y_scalar[t] + end + @test sol.logpdf ≈ loglik atol = 1.0e-12 +end + +# ============================================================================= +# QuadraticStateSpaceProblem +# ============================================================================= + +@testset "ConditionalLikelihood — QuadraticStateSpaceProblem" begin + rho = 0.8 + alpha = 0.05 + sigma_e = 0.3 + T = 30 + + Random.seed!(111) + y_scalar = zeros(T) + x = 0.0 + for t in 1:T + x = rho * x + alpha * x^2 + sigma_e * randn() + y_scalar[t] = x + end + y = [[yi] for yi in y_scalar] + + A_0 = [0.0] + A_1 = fill(rho, 1, 1) + A_2 = fill(alpha, 1, 1, 1) + + prob = QuadraticStateSpaceProblem( + A_0, A_1, A_2, nothing, [0.0], (0, T); + C_0 = [0.0], C_1 = fill(1.0, 1, 1), C_2 = zeros(1, 1, 1), + observables = y, + observables_noise = Diagonal([sigma_e^2]), + ) + sol = solve(prob, ConditionalLikelihood()) + + # Manual — same as the generic nonlinear test + loglik = 0.0 + x_prev = 0.0 + for t in 1:T + mu = rho * x_prev + alpha * x_prev^2 + loglik += logpdf(Normal(mu, sigma_e), y_scalar[t]) + x_prev = y_scalar[t] + end + @test sol.logpdf ≈ loglik atol = 1.0e-10 +end + +# ============================================================================= +# Type stability +# ============================================================================= + +@testset "ConditionalLikelihood — type stability" begin + T = 5 + Random.seed!(42) + y = [randn(2) for _ in 1:T] + + prob_linear = LinearStateSpaceProblem( + [0.8 0.1; -0.1 0.7], nothing, zeros(2), (0, T); + observables = y, + observables_noise = Diagonal([0.25, 0.25]), + ) + @test @inferred(solve(prob_linear, ConditionalLikelihood())) isa Any + + prob_linear_c = LinearStateSpaceProblem( + [0.8 0.1; -0.1 0.7], nothing, zeros(2), (0, T); + C = [1.0 0.0; 0.0 1.0], + observables = y, + observables_noise = Diagonal([0.25, 0.25]), + ) + @test @inferred(solve(prob_linear_c, ConditionalLikelihood())) isa Any +end + +# ============================================================================= +# Workspace init/solve! +# ============================================================================= + +@testset "ConditionalLikelihood — solve!() matches solve()" begin + T = 20 + rho = 0.8 + sigma_e = 0.5 + + Random.seed!(321) + y = [randn(1) for _ in 1:T] + + prob = LinearStateSpaceProblem( + fill(rho, 1, 1), nothing, [0.0], (0, T); + observables = y, + observables_noise = Diagonal([sigma_e^2]), + ) + sol_direct = solve(prob, ConditionalLikelihood()) + + ws = init(prob, ConditionalLikelihood()) + sol_ws = solve!(ws) + @test sol_ws.logpdf ≈ sol_direct.logpdf + @test sol_ws.u ≈ sol_direct.u +end + +@testset "ConditionalLikelihood — solve!() with C matrix" begin + T = 10 + A = [0.8 0.1; -0.1 0.7] + C = [1.0 0.0; 0.0 1.0] + R = Diagonal([0.1, 0.1]) + + Random.seed!(654) + y = [randn(2) for _ in 1:T] + + prob = LinearStateSpaceProblem( + A, nothing, zeros(2), (0, T); + C = C, observables = y, observables_noise = R, + ) + sol_direct = solve(prob, ConditionalLikelihood()) + + ws = init(prob, ConditionalLikelihood()) + sol_ws = solve!(ws) + @test sol_ws.logpdf ≈ sol_direct.logpdf + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z +end + +@testset "ConditionalLikelihood — solve!() repeated is idempotent" begin + T = 10 + rho = 0.8 + sigma_e = 0.5 + + Random.seed!(987) + y = [randn(1) for _ in 1:T] + + prob = LinearStateSpaceProblem( + fill(rho, 1, 1), nothing, [0.0], (0, T); + observables = y, + observables_noise = Diagonal([sigma_e^2]), + ) + ws = init(prob, ConditionalLikelihood()) + sol1 = solve!(ws) + sol2 = solve!(ws) + @test sol1.logpdf ≈ sol2.logpdf + @test sol1.u ≈ sol2.u +end + +# ============================================================================= +# Error handling +# ============================================================================= + +@testset "ConditionalLikelihood — error handling" begin + @testset "missing observables" begin + prob = LinearStateSpaceProblem( + fill(0.8, 1, 1), nothing, [0.0], (0, 5); + observables_noise = Diagonal([0.25]), + ) + @test_throws ArgumentError solve(prob, ConditionalLikelihood()) + end + + @testset "missing observables_noise" begin + y = [randn(1) for _ in 1:5] + prob = LinearStateSpaceProblem( + fill(0.8, 1, 1), nothing, [0.0], (0, 5); + observables = y, + ) + @test_throws ArgumentError solve(prob, ConditionalLikelihood()) + end + + @testset "observables wrong length" begin + y = [randn(1) for _ in 1:3] + prob = LinearStateSpaceProblem( + fill(0.8, 1, 1), nothing, [0.0], (0, 5); + observables = y, + observables_noise = Diagonal([0.25]), + ) + @test_throws ArgumentError solve(prob, ConditionalLikelihood()) + end +end + +# ============================================================================= +# StaticArrays +# ============================================================================= + +@testset "ConditionalLikelihood — StaticArrays AR(1)" begin + rho = 0.8 + sigma_e = 0.5 + T = 20 + + Random.seed!(555) + y_scalar = zeros(T) + x = 0.0 + for t in 1:T + x = rho * x + sigma_e * randn() + y_scalar[t] = x + end + + y_mut = [[yi] for yi in y_scalar] + y_static = [SVector{1}(yi) for yi in y_scalar] + + prob_mut = LinearStateSpaceProblem( + fill(rho, 1, 1), nothing, [0.0], (0, T); + observables = y_mut, + observables_noise = Diagonal([sigma_e^2]), + ) + sol_mut = solve(prob_mut, ConditionalLikelihood()) + + prob_static = LinearStateSpaceProblem( + SMatrix{1, 1}(rho), nothing, SVector{1}(0.0), (0, T); + observables = y_static, + observables_noise = Diagonal(SVector{1}(sigma_e^2)), + ) + sol_static = solve(prob_static, ConditionalLikelihood()) + + @test sol_static.logpdf ≈ sol_mut.logpdf atol = 1.0e-12 +end + +@testset "ConditionalLikelihood — StaticArrays VAR(1)" begin + A = [0.8 0.1; -0.1 0.7] + R = Diagonal([0.25, 0.25]) + T = 15 + + Random.seed!(666) + y_mut = [randn(2) for _ in 1:T] + y_static = [SVector{2}(yi) for yi in y_mut] + + prob_mut = LinearStateSpaceProblem( + A, nothing, zeros(2), (0, T); + observables = y_mut, + observables_noise = R, + ) + sol_mut = solve(prob_mut, ConditionalLikelihood()) + + prob_static = LinearStateSpaceProblem( + SMatrix{2, 2}(A), nothing, SVector{2}(0.0, 0.0), (0, T); + observables = y_static, + observables_noise = Diagonal(SVector{2}(0.25, 0.25)), + ) + sol_static = solve(prob_static, ConditionalLikelihood()) + + @test sol_static.logpdf ≈ sol_mut.logpdf atol = 1.0e-12 +end + +@testset "ConditionalLikelihood — StaticArrays generic nonlinear" begin + rho = 0.8 + alpha = 0.05 + sigma_e = 0.3 + T = 15 + + Random.seed!(777) + y_scalar = zeros(T) + x = 0.0 + for t in 1:T + x = rho * x + alpha * x^2 + sigma_e * randn() + y_scalar[t] = x + end + + y_mut = [[yi] for yi in y_scalar] + y_static = [SVector{1}(yi) for yi in y_scalar] + + nl_f!! = (x_next, x, w, p, t) -> begin + (; rho, alpha) = p + val = rho * x[1] + alpha * x[1]^2 + if ismutable(x_next) + x_next[1] = val + return x_next + else + return typeof(x)(val) + end + end + + p = (; rho, alpha) + prob_mut = StateSpaceProblem( + nl_f!!, nothing, [0.0], (0, T), p; + n_shocks = 0, n_obs = 0, + observables = y_mut, + observables_noise = Diagonal([sigma_e^2]), + ) + sol_mut = solve(prob_mut, ConditionalLikelihood()) + + prob_static = StateSpaceProblem( + nl_f!!, nothing, SVector{1}(0.0), (0, T), p; + n_shocks = 0, n_obs = 0, + observables = y_static, + observables_noise = Diagonal(SVector{1}(sigma_e^2)), + ) + sol_static = solve(prob_static, ConditionalLikelihood()) + + @test sol_static.logpdf ≈ sol_mut.logpdf atol = 1.0e-12 +end diff --git a/test/conditional_likelihood_enzyme.jl b/test/conditional_likelihood_enzyme.jl new file mode 100644 index 0000000..77eb63d --- /dev/null +++ b/test/conditional_likelihood_enzyme.jl @@ -0,0 +1,163 @@ +# Enzyme AD tests for ConditionalLikelihood +# prob passed as Duplicated — observables get zero shadow automatically. +# GC disabled to avoid Enzyme reverse-mode GC corruption (#2355). + +GC.gc() +GC.enable(false) + +using LinearAlgebra, Test, Enzyme, EnzymeTestUtils, Random +using DifferenceEquations +using DifferenceEquations: init, solve!, StateSpaceWorkspace +using FiniteDifferences: central_fdm + +include("enzyme_test_utils.jl") # vech helpers + +# ============================================================================= +# Test data +# ============================================================================= + +const N_cl_e = 2 +const T_cl_e = 5 + +const A_cl_e = [0.8 0.1; -0.1 0.7] +const H_cl_e = [0.1 0.0; 0.0 0.1] +const u0_cl_e = zeros(N_cl_e) + +Random.seed!(42) +const y_cl_e = [randn(N_cl_e) for _ in 1:T_cl_e] + +# max_range needed: FD perturbation of observables_noise inside prob can push +# the matrix non-positive-definite, causing DomainError in logdet_chol. +const _fdm_cl = central_fdm(5, 1; max_range = 1.0e-3) + +# ============================================================================= +# Wrappers — prob as single Duplicated arg +# ============================================================================= + +function cl_forward_prob!(prob, sol, cache) + ws = StateSpaceWorkspace(prob, ConditionalLikelihood(), sol, cache) + solve!(ws) + return sol.u +end + +function cl_loglik_prob(prob, sol, cache)::Float64 + ws = StateSpaceWorkspace(prob, ConditionalLikelihood(), sol, cache) + return solve!(ws).logpdf +end + +# Vech: separate args (y stays Duplicated — remake doesn't work with Enzyme shadows) +function cl_loglik_vech(A, v_R, u0, y, sol, cache)::Float64 + R = make_posdef_from_vech(v_R, size(A, 1)) + prob = LinearStateSpaceProblem( + A, nothing, u0, (0, length(y)); + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, ConditionalLikelihood(), sol, cache) + return solve!(ws).logpdf +end + +# ============================================================================= +# Sanity +# ============================================================================= + +@testset "ConditionalLikelihood loglik via solve!() — Enzyme sanity" begin + prob = LinearStateSpaceProblem( + A_cl_e, nothing, u0_cl_e, (0, T_cl_e); + observables_noise = H_cl_e * H_cl_e', observables = y_cl_e + ) + ws = init(prob, ConditionalLikelihood()) + loglik = cl_loglik_prob(prob, ws.output, ws.cache) + @test isfinite(loglik) + loglik2 = cl_loglik_prob(prob, ws.output, ws.cache) + @test loglik ≈ loglik2 rtol = 1.0e-12 +end + +# ============================================================================= +# Forward — prob as Duplicated +# ============================================================================= + +@testset "EnzymeTestUtils — CL forward (prob Duplicated)" begin + prob = LinearStateSpaceProblem( + A_cl_e, nothing, u0_cl_e, (0, T_cl_e); + observables_noise = H_cl_e * H_cl_e', observables = y_cl_e + ) + ws = init(prob, ConditionalLikelihood()) + + test_forward( + cl_forward_prob!, Const, + (prob, Duplicated), + (ws.output, Duplicated), (ws.cache, Duplicated); + fdm = _fdm_cl, + ) +end + +# ============================================================================= +# Reverse — prob as Duplicated +# ============================================================================= + +@testset "EnzymeTestUtils — CL reverse (prob Duplicated)" begin + prob = LinearStateSpaceProblem( + A_cl_e, nothing, u0_cl_e, (0, T_cl_e); + observables_noise = H_cl_e * H_cl_e', observables = y_cl_e + ) + ws = init(prob, ConditionalLikelihood()) + + test_reverse( + cl_loglik_prob, Active, + (prob, Duplicated), + (deepcopy(ws.output), Duplicated), (deepcopy(ws.cache), Duplicated); + fdm = _fdm_cl, + ) +end + +# ============================================================================= +# Reverse — with C matrix +# ============================================================================= + +@testset "EnzymeTestUtils — CL reverse with C (prob Duplicated)" begin + C_cl = [1.0 0.0; 0.0 1.0] + prob = LinearStateSpaceProblem( + A_cl_e, nothing, u0_cl_e, (0, T_cl_e); + C = C_cl, observables_noise = H_cl_e * H_cl_e', observables = y_cl_e + ) + ws = init(prob, ConditionalLikelihood()) + + test_reverse( + cl_loglik_prob, Active, + (prob, Duplicated), + (deepcopy(ws.output), Duplicated), (deepcopy(ws.cache), Duplicated); + fdm = _fdm_cl, + ) +end + +# ============================================================================= +# Reverse — non-diagonal R via vech +# ============================================================================= + +# Vech test: separate args (y as Duplicated — can't avoid due to struct storage). +# Tighter max_range needed for vech: make_posdef_from_vech has high curvature. +@testset "EnzymeTestUtils — CL reverse non-diagonal R (vech)" begin + H_offdiag = [0.1 0.05; 0.02 0.08] + R_offdiag = H_offdiag * H_offdiag' + v0 = make_vech_for(R_offdiag) + _fdm_vech = central_fdm(5, 1) + + prob_v = LinearStateSpaceProblem( + A_cl_e, nothing, u0_cl_e, (0, T_cl_e); + observables_noise = R_offdiag, observables = y_cl_e + ) + ws_v = init(prob_v, ConditionalLikelihood()) + + test_reverse( + cl_loglik_vech, Active, + (copy(A_cl_e), Duplicated), + (copy(v0), Duplicated), + (copy(u0_cl_e), Duplicated), + ([copy(yi) for yi in y_cl_e], Duplicated), + (deepcopy(ws_v.output), Duplicated), + (deepcopy(ws_v.cache), Duplicated); + fdm = _fdm_vech, + ) +end + +GC.enable(true) diff --git a/test/conditional_likelihood_forwarddiff.jl b/test/conditional_likelihood_forwarddiff.jl new file mode 100644 index 0000000..00a8dc6 --- /dev/null +++ b/test/conditional_likelihood_forwarddiff.jl @@ -0,0 +1,234 @@ +# ForwardDiff AD tests for ConditionalLikelihood +# Tests gradient correctness against FiniteDifferences.jl central FD. + +using LinearAlgebra, Test, ForwardDiff, StaticArrays, Random +using DifferenceEquations +using FiniteDifferences: central_fdm, grad + +include("forwarddiff_test_utils.jl") # promote_array only + +const _fdm_cl_fd = central_fdm(5, 1) + +# ============================================================================= +# Problem setup +# ============================================================================= + +const N_cl_fd = 2 +const M_cl_fd = 2 +const T_cl_fd = 10 + +const A_cl_fd = [0.8 0.1; -0.1 0.7] +const H_cl_fd = [0.1 0.0; 0.0 0.1] +const u0_cl_fd = zeros(N_cl_fd) + +# Generate observables from an AR process +Random.seed!(42) +const y_cl_fd = let + y = Vector{Vector{Float64}}(undef, T_cl_fd) + x = zeros(N_cl_fd) + for t in 1:T_cl_fd + x = A_cl_fd * x + H_cl_fd * randn(N_cl_fd) + y[t] = copy(x) + end + y +end + +# ============================================================================= +# Mutable arrays — ForwardDiff gradient tests +# ============================================================================= + +function cl_loglik_fd(A, u0, y, H) + T_el = promote_type(eltype(A), eltype(u0), eltype(H)) + R = promote_array(T_el, H) * promote_array(T_el, H)' + prob = LinearStateSpaceProblem( + promote_array(T_el, A), nothing, + promote_array(T_el, u0), (0, length(y)); + observables_noise = R, + observables = y, + ) + return solve(prob, ConditionalLikelihood()).logpdf +end + +@testset "ForwardDiff - ConditionalLikelihood loglik (mutable)" begin + @testset "primal sanity" begin + loglik_val = cl_loglik_fd(A_cl_fd, u0_cl_fd, y_cl_fd, H_cl_fd) + @test isfinite(loglik_val) + end + + @testset "gradient w.r.t. A" begin + f(a_vec) = cl_loglik_fd( + reshape(a_vec, N_cl_fd, N_cl_fd), u0_cl_fd, y_cl_fd, H_cl_fd + ) + x0 = vec(copy(A_cl_fd)) + @test ForwardDiff.gradient(f, x0) ≈ grad(_fdm_cl_fd, f, x0)[1] rtol = 1.0e-4 + end + + @testset "gradient w.r.t. u0" begin + f(u_vec) = cl_loglik_fd(A_cl_fd, u_vec, y_cl_fd, H_cl_fd) + x0 = [0.1, -0.1] + @test ForwardDiff.gradient(f, x0) ≈ grad(_fdm_cl_fd, f, x0)[1] rtol = 1.0e-4 + end + + @testset "gradient w.r.t. H" begin + f(h_vec) = cl_loglik_fd( + A_cl_fd, u0_cl_fd, y_cl_fd, reshape(h_vec, M_cl_fd, M_cl_fd) + ) + x0 = vec(copy(H_cl_fd)) + @test ForwardDiff.gradient(f, x0) ≈ grad(_fdm_cl_fd, f, x0)[1] rtol = 1.0e-4 + end +end + +# ============================================================================= +# Non-diagonal R — ForwardDiff gradient tests +# ============================================================================= + +const H_cl_fd_offdiag = [0.1 0.05; 0.02 0.08] + +@testset "ForwardDiff - ConditionalLikelihood non-diagonal R (mutable)" begin + @testset "primal sanity" begin + loglik_val = cl_loglik_fd(A_cl_fd, u0_cl_fd, y_cl_fd, H_cl_fd_offdiag) + @test isfinite(loglik_val) + end + + @testset "gradient w.r.t. H (off-diagonal)" begin + f(h_vec) = cl_loglik_fd( + A_cl_fd, u0_cl_fd, y_cl_fd, reshape(h_vec, M_cl_fd, M_cl_fd) + ) + x0 = vec(copy(H_cl_fd_offdiag)) + @test ForwardDiff.gradient(f, x0) ≈ grad(_fdm_cl_fd, f, x0)[1] rtol = 1.0e-4 + end + + @testset "gradient w.r.t. A (with off-diagonal R)" begin + f(a_vec) = cl_loglik_fd( + reshape(a_vec, N_cl_fd, N_cl_fd), u0_cl_fd, y_cl_fd, H_cl_fd_offdiag + ) + x0 = vec(copy(A_cl_fd)) + @test ForwardDiff.gradient(f, x0) ≈ grad(_fdm_cl_fd, f, x0)[1] rtol = 1.0e-4 + end +end + +# ============================================================================= +# With C matrix — ForwardDiff gradient tests +# ============================================================================= + +const C_cl_fd = [1.0 0.0; 0.0 1.0] + +function cl_loglik_fd_with_c(A, C, u0, y, H) + T_el = promote_type(eltype(A), eltype(C), eltype(u0), eltype(H)) + R = promote_array(T_el, H) * promote_array(T_el, H)' + prob = LinearStateSpaceProblem( + promote_array(T_el, A), nothing, + promote_array(T_el, u0), (0, length(y)); + C = promote_array(T_el, C), + observables_noise = R, + observables = y, + ) + return solve(prob, ConditionalLikelihood()).logpdf +end + +@testset "ForwardDiff - ConditionalLikelihood with C (mutable)" begin + @testset "gradient w.r.t. A (with C)" begin + f(a_vec) = cl_loglik_fd_with_c( + reshape(a_vec, N_cl_fd, N_cl_fd), C_cl_fd, u0_cl_fd, y_cl_fd, H_cl_fd + ) + x0 = vec(copy(A_cl_fd)) + @test ForwardDiff.gradient(f, x0) ≈ grad(_fdm_cl_fd, f, x0)[1] rtol = 1.0e-4 + end +end + +# ============================================================================= +# StaticArrays — ForwardDiff gradient tests +# ============================================================================= + +const y_cl_fd_s = [SVector{M_cl_fd}(yi) for yi in y_cl_fd] + +@testset "ForwardDiff - ConditionalLikelihood loglik (static)" begin + @testset "gradient w.r.t. A" begin + function _cl_static_A(a_vec) + T_el = eltype(a_vec) + A_d = SMatrix{N_cl_fd, N_cl_fd}(reshape(a_vec, N_cl_fd, N_cl_fd)) + H_d = SMatrix{M_cl_fd, M_cl_fd}(T_el.(H_cl_fd)) + prob = LinearStateSpaceProblem( + A_d, nothing, + SVector{N_cl_fd}(zeros(T_el, N_cl_fd)), (0, length(y_cl_fd_s)); + observables_noise = H_d * H_d', + observables = y_cl_fd_s, + ) + return solve(prob, ConditionalLikelihood()).logpdf + end + x0 = collect(vec(Matrix(A_cl_fd))) + @test ForwardDiff.gradient(_cl_static_A, x0) ≈ + grad(_fdm_cl_fd, _cl_static_A, x0)[1] rtol = 1.0e-4 + end + + @testset "gradient w.r.t. H" begin + function _cl_static_H(h_vec) + T_el = eltype(h_vec) + A_d = SMatrix{N_cl_fd, N_cl_fd}(T_el.(A_cl_fd)) + H_d = SMatrix{M_cl_fd, M_cl_fd}(reshape(h_vec, M_cl_fd, M_cl_fd)) + prob = LinearStateSpaceProblem( + A_d, nothing, + SVector{N_cl_fd}(zeros(T_el, N_cl_fd)), (0, length(y_cl_fd_s)); + observables_noise = H_d * H_d', + observables = y_cl_fd_s, + ) + return solve(prob, ConditionalLikelihood()).logpdf + end + x0 = collect(vec(Matrix(H_cl_fd))) + @test ForwardDiff.gradient(_cl_static_H, x0) ≈ + grad(_fdm_cl_fd, _cl_static_H, x0)[1] rtol = 1.0e-4 + end +end + +# ============================================================================= +# Generic nonlinear StateSpaceProblem — ForwardDiff gradient +# ============================================================================= + +@testset "ForwardDiff - ConditionalLikelihood generic nonlinear (mutable)" begin + T_nl = 15 + sigma_e_nl = 0.3 + + Random.seed!(99) + y_nl = let + y = Vector{Vector{Float64}}(undef, T_nl) + x = 0.0 + for t in 1:T_nl + x = 0.8 * x + 0.05 * x^2 + sigma_e_nl * randn() + y[t] = [x] + end + y + end + + nl_f!! = (x_next, x, w, p, t) -> begin + (; rho, alpha) = p + val = rho * x[1] + alpha * x[1]^2 + if ismutable(x_next) + x_next[1] = val + return x_next + else + return typeof(x)(val) + end + end + + function cl_nl_loglik(param_vec, y, sigma_e) + T_el = eltype(param_vec) + p = (; rho = param_vec[1], alpha = param_vec[2]) + prob = StateSpaceProblem( + nl_f!!, nothing, [zero(T_el)], (0, length(y)), p; + n_shocks = 0, n_obs = 0, + observables = y, + observables_noise = Diagonal([T_el(sigma_e^2)]), + ) + return solve(prob, ConditionalLikelihood()).logpdf + end + + @testset "primal sanity" begin + @test isfinite(cl_nl_loglik([0.8, 0.05], y_nl, sigma_e_nl)) + end + + @testset "gradient w.r.t. (rho, alpha)" begin + f(p_vec) = cl_nl_loglik(p_vec, y_nl, sigma_e_nl) + x0 = [0.8, 0.05] + @test ForwardDiff.gradient(f, x0) ≈ grad(_fdm_cl_fd, f, x0)[1] rtol = 1.0e-4 + end +end diff --git a/test/direct_iteration.jl b/test/direct_iteration.jl new file mode 100644 index 0000000..c9f9e6f --- /dev/null +++ b/test/direct_iteration.jl @@ -0,0 +1,555 @@ +using DifferenceEquations, Distributions, LinearAlgebra, Test, Random, DelimitedFiles, DiffEqBase +using DifferenceEquations: init, solve! + +# --- Helper: quadratic callbacks --- + +function make_quadratic_callbacks(A_0, A_1, A_2, B, C_0, C_1, C_2, u0) + n_x = length(u0) + n_obs = length(C_0) + u_f = copy(u0) # tracks linear-part state, initialized to u0 + u_f_new = similar(u0) # workspace for updating u_f + + function f!!(x_next, x, w, p, t) + # Compute new linear-part: u_f_new = A_1 * u_f + B * w + mul!(u_f_new, A_1, u_f) + mul!(u_f_new, B, w, 1.0, 1.0) + + # Full transition: x_next = A_0 + A_1 * x + quad(A_2, u_f) + B * w + copyto!(x_next, A_0) + mul!(x_next, A_1, x, 1.0, 1.0) + @inbounds for i in 1:n_x + x_next[i] += dot(u_f, view(A_2, i, :, :), u_f) + end + mul!(x_next, B, w, 1.0, 1.0) + + # Advance u_f for next step + copyto!(u_f, u_f_new) + + return x_next + end + + function g!!(y, x, p, t) + # y = C_0 + C_1 * x + quad(C_2, u_f) + copyto!(y, C_0) + mul!(y, C_1, x, 1.0, 1.0) + @inbounds for i in 1:n_obs + y[i] += dot(u_f, view(C_2, i, :, :), u_f) + end + return y + end + + return f!!, g!! +end + +# --- RBC model matrices (linear) --- + +A_rbc = [ + 0.9568351489231076 6.209371005755285; + 3.0153731819288737e-18 0.20000000000000007 +] +B_rbc = reshape([0.0; -0.01], 2, 1) +C_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] +D_rbc = abs2.([0.1, 0.1]) +u0_rbc = zeros(2) + +observables_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_observables.csv"), ',' +)' |> collect +observables_rbc = [observables_rbc_matrix[:, t] for t in 1:size(observables_rbc_matrix, 2)] +noise_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), ',' +)' |> collect +noise_rbc = [noise_rbc_matrix[:, t] for t in 1:size(noise_rbc_matrix, 2)] + +# --- Linear callbacks match LinearStateSpaceProblem --- + +@testset "Generic linear matches LinearStateSpaceProblem — with observations and noise" begin + Random.seed!(1234) + sol_linear = solve(LinearStateSpaceProblem(A_rbc, B_rbc, u0_rbc, (0, 5); C = C_rbc)) + + linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next + end + linear_g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y + end + p = (; A = A_rbc, B = B_rbc, C = C_rbc) + + Random.seed!(1234) + sol_generic = solve( + StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, 5), p; + n_shocks = 1, n_obs = 2 + ) + ) + + @test sol_linear.u ≈ sol_generic.u + @test sol_linear.z ≈ sol_generic.z + @test sol_linear.W ≈ sol_generic.W + @test sol_linear.logpdf == 0.0 + @test sol_generic.logpdf == 0.0 +end + +@testset "Generic linear matches — with explicit noise and observables" begin + T = 5 + obs = observables_rbc[1:T] + nse = noise_rbc[1:T] + + sol_linear = solve( + LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, T); C = C_rbc, + observables_noise = Diagonal(D_rbc), noise = nse, observables = obs + ) + ) + + linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next + end + linear_g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y + end + p = (; A = A_rbc, B = B_rbc, C = C_rbc) + + sol_generic = solve( + StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), noise = nse, observables = obs + ) + ) + + @test sol_linear.u ≈ sol_generic.u + @test sol_linear.z ≈ sol_generic.z + @test sol_linear.logpdf ≈ sol_generic.logpdf +end + +# --- No observation process --- + +@testset "Generic no observation" begin + Random.seed!(1234) + linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next + end + p = (; A = A_rbc, B = B_rbc) + + sol = solve( + StateSpaceProblem( + linear_f!!, nothing, [1.0, 0.5], (0, 5), p; + n_shocks = 1, n_obs = 0 + ) + ) + @test sol.z === nothing + @test length(sol.u) == 6 + + # Compare to LinearStateSpaceProblem with C=nothing + Random.seed!(1234) + sol_linear = solve( + LinearStateSpaceProblem( + A_rbc, B_rbc, [1.0, 0.5], (0, 5); C = nothing + ) + ) + # Must use same seed → same random noise + Random.seed!(1234) + sol_generic = solve( + StateSpaceProblem( + linear_f!!, nothing, [1.0, 0.5], (0, 5), p; + n_shocks = 1, n_obs = 0 + ) + ) + @test sol_linear.u ≈ sol_generic.u +end + +# --- No noise (n_shocks = 0) --- + +@testset "Generic no noise" begin + linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + return x_next + end + linear_g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y + end + p = (; A = A_rbc, C = C_rbc) + + sol = solve( + StateSpaceProblem( + linear_f!!, linear_g!!, [1.0, 0.5], (0, 5), p; + n_shocks = 0, n_obs = 2 + ) + ) + + @test sol.W === nothing + @test length(sol.u) == 6 + @test length(sol.z) == 6 + + # Compare to LinearStateSpaceProblem with B=nothing + sol_linear = solve( + LinearStateSpaceProblem( + A_rbc, nothing, [1.0, 0.5], (0, 5); C = C_rbc + ) + ) + @test sol_linear.u ≈ sol.u + @test sol_linear.z ≈ sol.z +end + +# --- Observation noise --- + +@testset "Generic with observation noise" begin + T = 20 + B_no_noise = reshape([0.0; 0.0], 2, 1) + u0 = [1.0, 0.5] + + linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next + end + linear_g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y + end + p = (; A = A_rbc, B = B_no_noise, C = C_rbc) + + sol_no_noise = solve( + StateSpaceProblem( + linear_f!!, linear_g!!, u0, (0, T), p; + n_shocks = 1, n_obs = 2 + ) + ) + + sol_obs_noise = solve( + StateSpaceProblem( + linear_f!!, linear_g!!, u0, (0, T), p; + n_shocks = 1, n_obs = 2, observables_noise = Diagonal(D_rbc) + ) + ) + + # Tiny observation noise → nearly deterministic + sol_tiny = solve( + StateSpaceProblem( + linear_f!!, linear_g!!, u0, (0, T), p; + n_shocks = 1, n_obs = 2, observables_noise = Diagonal([1.0e-16, 1.0e-16]) + ) + ) + @test maximum(maximum.(sol_tiny.z - sol_no_noise.z)) < 1.0e-7 + @test maximum(maximum.(sol_tiny.z - sol_no_noise.z)) > 0.0 +end + +# --- Quadratic callbacks: RBC model --- + +A_0_rbc = [-7.824904812740593e-5, 0.0] +A_1_rbc = [0.9568351489231076 6.209371005755285; 3.0153731819288737e-18 0.20000000000000007] +A_2_rbc = cat( + [-0.00019761505863889124 0.03375055315837927; 0.0 0.0], + [0.03375055315837913 3.128758481817603; 0.0 0.0]; dims = 3 +) +B_2_rbc = reshape([0.0; -0.01], 2, 1) +C_0_rbc = [7.824904812740593e-5, 0.0] +C_1_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] +C_2_rbc = cat( + [-0.00018554166974717046 0.0025652363153049716; 0.0 0.0], + [0.002565236315304951 0.3132705036896446; 0.0 0.0]; dims = 3 +) +D_2_rbc = abs2.([0.1, 0.1]) +u0_2_rbc = zeros(2) + +observables_2_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_observables.csv"), ',' +)' |> collect +observables_2_rbc = [observables_2_rbc_matrix[:, t] for t in 1:size(observables_2_rbc_matrix, 2)] + +@testset "Quadratic RBC basic inference, simulated noise" begin + f!!, g!! = make_quadratic_callbacks( + A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc + ) + prob = StateSpaceProblem( + f!!, g!!, u0_2_rbc, (0, length(observables_2_rbc)); + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_2_rbc), observables = observables_2_rbc + ) + sol = solve(prob) + @test sol.logpdf isa Number +end + +@testset "Quadratic RBC simulation, no observations" begin + T = 20 + f!!, g!! = make_quadratic_callbacks( + A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc + ) + prob = StateSpaceProblem( + f!!, g!!, u0_2_rbc, (0, T); + n_shocks = 1, n_obs = 2 + ) + sol = solve(prob) + @test length(sol.u) == T + 1 + @test length(sol.z) == T + 1 + @test sol.logpdf == 0.0 +end + +@testset "Quadratic RBC deterministic with observation noise" begin + T = 20 + B_no_noise = reshape([0.0; 0.0], 2, 1) + u0 = [1.0, 0.5] + + f_nn!!, g_nn!! = make_quadratic_callbacks( + A_0_rbc, A_1_rbc, A_2_rbc, B_no_noise, C_0_rbc, C_1_rbc, C_2_rbc, u0 + ) + sol_no_noise = solve( + StateSpaceProblem( + f_nn!!, g_nn!!, u0, (0, T); + n_shocks = 1, n_obs = 2 + ) + ) + + f_on!!, g_on!! = make_quadratic_callbacks( + A_0_rbc, A_1_rbc, A_2_rbc, B_no_noise, C_0_rbc, C_1_rbc, C_2_rbc, u0 + ) + sol_obs_noise = solve( + StateSpaceProblem( + f_on!!, g_on!!, u0, (0, T); + n_shocks = 1, n_obs = 2, observables_noise = Diagonal(D_2_rbc) + ) + ) + + f_ti!!, g_ti!! = make_quadratic_callbacks( + A_0_rbc, A_1_rbc, A_2_rbc, B_no_noise, C_0_rbc, C_1_rbc, C_2_rbc, u0 + ) + sol_tiny = solve( + StateSpaceProblem( + f_ti!!, g_ti!!, u0, (0, T); + n_shocks = 1, n_obs = 2, observables_noise = Diagonal([1.0e-16, 1.0e-16]) + ) + ) + @test maximum(maximum.(sol_tiny.z - sol_no_noise.z)) < 1.0e-7 + @test maximum(maximum.(sol_tiny.z - sol_no_noise.z)) > 0.0 +end + +# --- Quadratic likelihood regression values --- + +function quadratic_joint_likelihood( + A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, observables, D; + kwargs... + ) + f!!, g!! = make_quadratic_callbacks(A_0, A_1, A_2, B, C_0, C_1, C_2, u0) + problem = StateSpaceProblem( + f!!, g!!, u0, (0, length(observables)); + n_shocks = size(B, 2), n_obs = length(C_0), + observables_noise = Diagonal(D), noise, observables, + kwargs... + ) + return solve(problem).logpdf +end + +noise_2_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), ',' +)' |> collect +noise_2_rbc = [noise_2_rbc_matrix[:, t] for t in 1:size(noise_2_rbc_matrix, 2)] +T_rbc = 5 +observables_2_rbc_short = observables_2_rbc[1:T_rbc] +noise_2_rbc_short = noise_2_rbc[1:T_rbc] + +@testset "Quadratic RBC basic inference with known noise" begin + f!!, g!! = make_quadratic_callbacks( + A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc + ) + prob = StateSpaceProblem( + f!!, g!!, u0_2_rbc, (0, length(observables_2_rbc_short)); + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_2_rbc), noise = noise_2_rbc_short, + observables = observables_2_rbc_short + ) + DiffEqBase.get_concrete_problem(prob, false) + sol = solve(prob) +end + +@testset "Quadratic RBC joint likelihood" begin + @test quadratic_joint_likelihood( + A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, + u0_2_rbc, noise_2_rbc_short, observables_2_rbc_short, D_2_rbc + ) ≈ -690.81094364573 +end + +# --- FVGQ quadratic data --- + +A_0_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A_0.csv"), ',') +A_0_FVGQ = vec(A_0_raw) +A_1_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A_1.csv"), ',') +A_2_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A_2.csv"), ',') +A_2_FVGQ = reshape(A_2_raw, length(A_0_FVGQ), length(A_0_FVGQ), length(A_0_FVGQ)) +B_2_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_B.csv"), ',') +C_0_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C_0.csv"), ',') +C_0_FVGQ = vec(C_0_raw) +C_1_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C_1.csv"), ',') +C_2_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C_2.csv"), ',') +C_2_FVGQ = reshape(C_2_raw, length(C_0_FVGQ), length(A_0_FVGQ), length(A_0_FVGQ)) +D_2_FVGQ = ones(6) * 1.0e-3 +u0_2_FVGQ = zeros(size(A_1_FVGQ, 1)) + +observables_2_FVGQ_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_observables.csv"), ',' +)' |> collect +observables_2_FVGQ = [observables_2_FVGQ_matrix[:, t] for t in 1:size(observables_2_FVGQ_matrix, 2)] +noise_2_FVGQ_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_noise.csv"), ',' +)' |> collect +noise_2_FVGQ = [noise_2_FVGQ_matrix[:, t] for t in 1:size(noise_2_FVGQ_matrix, 2)] + +@testset "Quadratic FVGQ joint likelihood" begin + @test quadratic_joint_likelihood( + A_0_FVGQ, A_1_FVGQ, A_2_FVGQ, B_2_FVGQ, C_0_FVGQ, C_1_FVGQ, C_2_FVGQ, + u0_2_FVGQ, noise_2_FVGQ, observables_2_FVGQ, D_2_FVGQ + ) ≈ -1.4728927648336522e7 +end + +# --- Workspace (init/solve!) tests --- + +@testset "solve!() matches solve() — generic linear with noise + obs + obs_noise" begin + T = 5 + obs = observables_rbc[1:T] + nse = noise_rbc[1:T] + + sol_direct = solve( + LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, T); C = C_rbc, + observables_noise = Diagonal(D_rbc), noise = nse, observables = obs + ) + ) + + linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next + end + linear_g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y + end + p = (; A = A_rbc, B = B_rbc, C = C_rbc) + + prob_gen = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), noise = nse, observables = obs + ) + ws = init(prob_gen, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.logpdf ≈ sol_direct.logpdf +end + +@testset "solve!() matches solve() — quadratic RBC with noise + obs + obs_noise" begin + f!!, g!! = make_quadratic_callbacks( + A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc + ) + prob = StateSpaceProblem( + f!!, g!!, u0_2_rbc, (0, length(observables_2_rbc_short)); + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_2_rbc), noise = noise_2_rbc_short, + observables = observables_2_rbc_short + ) + sol_direct = solve(prob) + + f!!2, g!!2 = make_quadratic_callbacks( + A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc + ) + prob2 = StateSpaceProblem( + f!!2, g!!2, u0_2_rbc, (0, length(observables_2_rbc_short)); + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_2_rbc), noise = noise_2_rbc_short, + observables = observables_2_rbc_short + ) + ws = init(prob2, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.logpdf ≈ sol_direct.logpdf +end + +@testset "solve!() matches solve() — generic no observation (n_obs=0)" begin + linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next + end + p = (; A = A_rbc, B = B_rbc) + + Random.seed!(1234) + sol_direct = solve( + StateSpaceProblem( + linear_f!!, nothing, [1.0, 0.5], (0, 5), p; + n_shocks = 1, n_obs = 0 + ) + ) + + prob_gen = StateSpaceProblem( + linear_f!!, nothing, [1.0, 0.5], (0, 5), p; + n_shocks = 1, n_obs = 0 + ) + Random.seed!(1234) + ws = init(prob_gen, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.z === nothing + @test sol_ws.u ≈ sol_direct.u +end + +@testset "solve!() matches solve() — generic no noise (n_shocks=0)" begin + linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + return x_next + end + linear_g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y + end + p = (; A = A_rbc, C = C_rbc) + + prob_gen = StateSpaceProblem( + linear_f!!, linear_g!!, [1.0, 0.5], (0, 5), p; + n_shocks = 0, n_obs = 2 + ) + sol_direct = solve(prob_gen) + ws = init(prob_gen, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.W === nothing + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z +end + +@testset "solve!() repeated — idempotent generic results" begin + T = 5 + obs = observables_rbc[1:T] + nse = noise_rbc[1:T] + + linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next + end + linear_g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y + end + p = (; A = A_rbc, B = B_rbc, C = C_rbc) + + prob_gen = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), noise = nse, observables = obs + ) + ws = init(prob_gen, DirectIteration()) + sol1 = solve!(ws) + sol2 = solve!(ws) + @test sol1.u ≈ sol2.u + @test sol1.z ≈ sol2.z + @test sol1.logpdf ≈ sol2.logpdf +end diff --git a/test/enzyme_test_utils.jl b/test/enzyme_test_utils.jl new file mode 100644 index 0000000..839cb22 --- /dev/null +++ b/test/enzyme_test_utils.jl @@ -0,0 +1,74 @@ +# Shared utilities for Enzyme AD tests + +using LinearAlgebra: LowerTriangular, Symmetric, cholesky + +""" + vech_length(n) + +Number of elements in the lower triangle of an n×n matrix. +""" +vech_length(n) = n * (n + 1) ÷ 2 + +""" + vech(L::AbstractMatrix) + +Extract lower-triangular elements of L into a vector (column-major order). +""" +function vech(L::AbstractMatrix) + n = size(L, 1) + v = zeros(eltype(L), vech_length(n)) + k = 1 + for j in 1:n + for i in j:n + v[k] = L[i, j] + k += 1 + end + end + return v +end + +""" + unvech(v, n) + +Reconstruct an n×n `LowerTriangular` matrix from a vech vector. +""" +function unvech(v, n) + L = zeros(eltype(v), n, n) + k = 1 + for j in 1:n + for i in j:n + L[i, j] = v[k] + k += 1 + end + end + return LowerTriangular(L) +end + +""" + make_posdef_from_vech(v, n) + +Construct a guaranteed positive-definite matrix from a vech vector. +Computes L = unvech(v, n), then returns L * L' as a plain Matrix +(not Symmetric, to avoid type instability with Enzyme AD). +""" +function make_posdef_from_vech(v, n) + L = unvech(v, n) + # Use Matrix(L) * Matrix(L') to avoid LowerTriangular BLAS dispatch + # which Enzyme cannot differentiate (trmm! has no derivative rule). + L_mat = Matrix(L) + return L_mat * L_mat' +end + +""" + make_vech_for(M::AbstractMatrix) + +Given a positive-definite matrix M, compute its Cholesky L factor and return vech(L). +Round-trips: make_posdef_from_vech(make_vech_for(M), n) ≈ Symmetric(M). +""" +function make_vech_for(M::AbstractMatrix) + F = cholesky(Symmetric(M)) + return vech(F.L) +end + +# Note: fdm_gradient removed — Enzyme tests use test_forward/test_reverse from +# EnzymeTestUtils, which handle FD internally via FiniteDifferences.jl. diff --git a/test/forwarddiff_test_utils.jl b/test/forwarddiff_test_utils.jl new file mode 100644 index 0000000..c584987 --- /dev/null +++ b/test/forwarddiff_test_utils.jl @@ -0,0 +1,29 @@ +# Shared utilities for ForwardDiff AD tests + +""" + promote_array(::Type{T}, x) + +Convert array `x` to element type `T`. No-op if already the right type. +""" +promote_array(::Type{T}, x::AbstractArray{T}) where {T} = x +promote_array(::Type{T}, x::AbstractArray) where {T} = T.(x) + +""" + fdm_gradient(f, x; h=1e-7) + +Central finite-difference gradient of scalar function `f` at point `x`. +""" +function fdm_gradient(f, x; h = 1.0e-7) + n = length(x) + grad = zeros(n) + xp = copy(x) + xm = copy(x) + for i in 1:n + xp[i] = x[i] + h + xm[i] = x[i] - h + grad[i] = (f(xp) - f(xm)) / (2h) + xp[i] = x[i] + xm[i] = x[i] + end + return grad +end diff --git a/test/gradient_comparison.jl b/test/gradient_comparison.jl new file mode 100644 index 0000000..571a1fa --- /dev/null +++ b/test/gradient_comparison.jl @@ -0,0 +1,384 @@ +# Apples-to-apples gradient comparison: ForwardDiff vs Enzyme BatchDuplicated vs Enzyme Reverse +# All methods compute the SAME quantity: full gradient of loglik w.r.t. vec(A) (N² components). + +using LinearAlgebra, Test, ForwardDiff, Enzyme, Random +using Enzyme: make_zero, make_zero! +using DifferenceEquations +using DifferenceEquations: init, solve!, StateSpaceWorkspace, fill_zero!! + +include("forwarddiff_test_utils.jl") + +# ============================================================================= +# Kalman problem setup +# ============================================================================= + +const N_gc = 2 +const M_gc = 2 +const K_gc = 2 +const T_gc = 3 +const CHUNK_gc = 2 # batch size for Enzyme BatchDuplicated + +const A_gc = [0.8 0.1; -0.1 0.7] +const B_gc = [0.1 0.0; 0.0 0.1] +const C_gc = [1.0 0.0; 0.0 1.0] +const R_gc = [0.01 0.0; 0.0 0.01] +const mu_0_gc = zeros(N_gc) +const Sigma_0_gc = Matrix{Float64}(I, N_gc, N_gc) +const y_gc = [[0.5, 0.3], [0.2, 0.1], [0.8, 0.4]] + +# Enzyme workspace (pre-allocated Float64) +function _make_gc_workspace() + prob = LinearStateSpaceProblem( + A_gc, B_gc, zeros(N_gc), (0, T_gc); C = C_gc, + u0_prior_mean = mu_0_gc, u0_prior_var = Sigma_0_gc, + observables_noise = R_gc, observables = y_gc + ) + ws = init(prob, KalmanFilter()) + return ws.output, ws.cache +end + +# ============================================================================= +# 1. ForwardDiff gradient w.r.t. vec(A) +# ============================================================================= + +function _kf_loglik_fd(A_vec) + T_el = eltype(A_vec) + A = reshape(A_vec, N_gc, N_gc) + prob = LinearStateSpaceProblem( + A, promote_array(T_el, B_gc), + zeros(T_el, N_gc), (0, T_gc); + C = promote_array(T_el, C_gc), + u0_prior_mean = promote_array(T_el, mu_0_gc), + u0_prior_var = promote_array(T_el, Sigma_0_gc), + observables_noise = promote_array(T_el, R_gc), + observables = y_gc + ) + sol = solve(prob, KalmanFilter()) + return sol.logpdf +end + +# ============================================================================= +# 2. Enzyme BatchDuplicated forward — full gradient via chunked forward passes +# ============================================================================= + +function _kf_loglik_enzyme!(A, B, C, mu_0, Sigma_0, R, y, sol_out, cache) + prob = LinearStateSpaceProblem( + A, B, zeros(eltype(A), size(A, 1)), (0, length(y)); C, + u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol_out, cache) + return solve!(ws).logpdf +end + +function enzyme_batched_forward_gradient_kf!( + grad_out, A, B, C, mu_0, Sigma_0, R, y, + sol_out, cache, chunk_size, + dAs, dBs, dCs, dmu0s, dSig0s, dRs, dys, dsols, dcaches + ) + N_params = length(vec(A)) + for chunk_start in 1:chunk_size:N_params + chunk_end = min(chunk_start + chunk_size - 1, N_params) + actual = chunk_end - chunk_start + 1 + + # Zero all shadows + for k in 1:chunk_size + fill_zero!!(dAs[k]); fill_zero!!(dBs[k]); fill_zero!!(dCs[k]) + fill_zero!!(dmu0s[k]); fill_zero!!(dSig0s[k]); fill_zero!!(dRs[k]) + for t in eachindex(dys[k]) + dys[k][t] = fill_zero!!(dys[k][t]) + end + make_zero!(dsols[k]); make_zero!(dcaches[k]) + end + + # Seed directions: standard basis vectors for vec(A) + for k in 1:actual + dAs[k][chunk_start + k - 1] = 1.0 + end + + result = autodiff( + Forward, _kf_loglik_enzyme!, + BatchDuplicated(A, dAs), + BatchDuplicated(B, dBs), + BatchDuplicated(C, dCs), + BatchDuplicated(mu_0, dmu0s), + BatchDuplicated(Sigma_0, dSig0s), + BatchDuplicated(R, dRs), + BatchDuplicated(y, dys), + BatchDuplicated(sol_out, dsols), + BatchDuplicated(cache, dcaches) + ) + + # Result is ((d1, d2, ...),) for scalar return + derivs = values(result[1]) + for k in 1:actual + grad_out[chunk_start + k - 1] = derivs[k] + end + end + return grad_out +end + +# ============================================================================= +# 3. Enzyme Reverse — full gradient, extract dA +# ============================================================================= + +function enzyme_reverse_gradient_kf!( + A, B, C, mu_0, Sigma_0, R, y, + sol_out, cache, dA, dB, dC, dmu_0, dSigma_0, dR, dy, dsol_out, dcache + ) + make_zero!(dsol_out); make_zero!(dcache) + fill_zero!!(dA); fill_zero!!(dB); fill_zero!!(dC) + fill_zero!!(dmu_0); fill_zero!!(dSigma_0); fill_zero!!(dR) + @inbounds for i in eachindex(dy) + dy[i] = fill_zero!!(dy[i]) + end + + autodiff( + Reverse, _kf_loglik_enzyme!, Active, + Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC), + Duplicated(mu_0, dmu_0), Duplicated(Sigma_0, dSigma_0), + Duplicated(R, dR), Duplicated(y, dy), + Duplicated(sol_out, dsol_out), Duplicated(cache, dcache) + ) + return vec(dA) +end + +# ============================================================================= +# Tests +# ============================================================================= + +@testset "Gradient comparison - Kalman loglik w.r.t. vec(A)" begin + A_vec = vec(copy(A_gc)) + + # Finite differences (baseline) + grad_fin = fdm_gradient(_kf_loglik_fd, A_vec) + + # ForwardDiff + grad_fd = ForwardDiff.gradient(_kf_loglik_fd, A_vec) + + # Enzyme BatchDuplicated forward + sol_out_bf, cache_bf = _make_gc_workspace() + N_params = length(A_vec) + dAs = ntuple(_ -> make_zero(A_gc), CHUNK_gc) + dBs = ntuple(_ -> make_zero(B_gc), CHUNK_gc) + dCs = ntuple(_ -> make_zero(C_gc), CHUNK_gc) + dmu0s = ntuple(_ -> make_zero(mu_0_gc), CHUNK_gc) + dSig0s = ntuple(_ -> make_zero(Sigma_0_gc), CHUNK_gc) + dRs = ntuple(_ -> make_zero(R_gc), CHUNK_gc) + dys = ntuple(_ -> [make_zero(y_gc[1]) for _ in 1:T_gc], CHUNK_gc) + dsols = ntuple(_ -> make_zero(sol_out_bf), CHUNK_gc) + dcaches = ntuple(_ -> make_zero(cache_bf), CHUNK_gc) + + grad_enzyme_fwd = zeros(N_params) + enzyme_batched_forward_gradient_kf!( + grad_enzyme_fwd, + A_gc, B_gc, C_gc, mu_0_gc, Sigma_0_gc, R_gc, y_gc, + sol_out_bf, cache_bf, CHUNK_gc, + dAs, dBs, dCs, dmu0s, dSig0s, dRs, dys, dsols, dcaches + ) + + # Enzyme Reverse + sol_out_rv, cache_rv = _make_gc_workspace() + dA_rv = make_zero(A_gc); dB_rv = make_zero(B_gc); dC_rv = make_zero(C_gc) + dmu0_rv = make_zero(mu_0_gc); dSig0_rv = make_zero(Sigma_0_gc); dR_rv = make_zero(R_gc) + dy_rv = [make_zero(y_gc[1]) for _ in 1:T_gc] + dsol_rv = make_zero(sol_out_rv); dcache_rv = make_zero(cache_rv) + + grad_enzyme_rev = enzyme_reverse_gradient_kf!( + A_gc, B_gc, C_gc, mu_0_gc, Sigma_0_gc, R_gc, y_gc, + sol_out_rv, cache_rv, dA_rv, dB_rv, dC_rv, dmu0_rv, dSig0_rv, dR_rv, + dy_rv, dsol_rv, dcache_rv + ) + + @testset "all methods finite" begin + @test all(isfinite, grad_fin) + @test all(isfinite, grad_fd) + @test all(isfinite, grad_enzyme_fwd) + @test all(isfinite, grad_enzyme_rev) + end + + @testset "ForwardDiff matches finite differences" begin + @test grad_fd ≈ grad_fin rtol = 1.0e-4 + end + + @testset "Enzyme BatchDuplicated forward matches finite differences" begin + @test grad_enzyme_fwd ≈ grad_fin rtol = 1.0e-4 + end + + @testset "Enzyme reverse matches finite differences" begin + @test grad_enzyme_rev ≈ grad_fin rtol = 1.0e-4 + end + + @testset "ForwardDiff matches Enzyme reverse (high precision)" begin + @test grad_fd ≈ grad_enzyme_rev rtol = 1.0e-10 + end + + @testset "Enzyme BatchDuplicated forward matches Enzyme reverse (high precision)" begin + @test grad_enzyme_fwd ≈ grad_enzyme_rev rtol = 1.0e-10 + end +end + +# ============================================================================= +# DirectIteration variant +# ============================================================================= + +const A_di_gc = [0.8 0.1; -0.1 0.7] +const B_di_gc = [0.1 0.0; 0.0 0.1] +const C_di_gc = [1.0 0.0; 0.0 1.0] +const H_di_gc = [0.1 0.0; 0.0 0.1] +const u0_di_gc = [0.1, -0.1] +const noise_di_gc = [[0.1, -0.1], [0.2, 0.05], [0.0, 0.1]] +const y_di_gc = [[0.5, 0.3], [0.2, 0.1], [0.8, 0.4]] + +function _make_di_gc_workspace() + R = H_di_gc * H_di_gc' + prob = LinearStateSpaceProblem( + A_di_gc, B_di_gc, u0_di_gc, (0, T_gc); + C = C_di_gc, observables_noise = R, observables = y_di_gc, noise = noise_di_gc + ) + ws = init(prob, DirectIteration()) + return ws.output, ws.cache +end + +function _di_loglik_fd_gc(A_vec) + T_el = eltype(A_vec) + A = reshape(A_vec, N_gc, N_gc) + H = promote_array(T_el, H_di_gc) + R = H * H' + prob = LinearStateSpaceProblem( + A, promote_array(T_el, B_di_gc), + promote_array(T_el, u0_di_gc), (0, T_gc); + C = promote_array(T_el, C_di_gc), + observables_noise = R, + observables = y_di_gc, noise = noise_di_gc + ) + sol = solve(prob, DirectIteration()) + return sol.logpdf +end + +function _di_loglik_enzyme!(A, B, C, u0, noise, y, H, sol_out, cache) + R = H * H' + prob = LinearStateSpaceProblem( + A, B, u0, (0, length(y)); + C, observables_noise = R, observables = y, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol_out, cache) + return solve!(ws).logpdf +end + +function enzyme_batched_forward_gradient_di!( + grad_out, A, B, C, u0, noise, y, H, + sol_out, cache, chunk_size, + dAs, dBs, dCs, du0s, dnoises, dys, dHs, dsols, dcaches + ) + N_params = length(vec(A)) + for chunk_start in 1:chunk_size:N_params + chunk_end = min(chunk_start + chunk_size - 1, N_params) + actual = chunk_end - chunk_start + 1 + + for k in 1:chunk_size + fill_zero!!(dAs[k]); fill_zero!!(dBs[k]); fill_zero!!(dCs[k]) + fill_zero!!(du0s[k]); fill_zero!!(dHs[k]) + for t in eachindex(dnoises[k]) + dnoises[k][t] = fill_zero!!(dnoises[k][t]) + end + for t in eachindex(dys[k]) + dys[k][t] = fill_zero!!(dys[k][t]) + end + make_zero!(dsols[k]); make_zero!(dcaches[k]) + end + for k in 1:actual + dAs[k][chunk_start + k - 1] = 1.0 + end + + result = autodiff( + Forward, _di_loglik_enzyme!, + BatchDuplicated(A, dAs), + BatchDuplicated(B, dBs), + BatchDuplicated(C, dCs), + BatchDuplicated(u0, du0s), + BatchDuplicated(noise, dnoises), + BatchDuplicated(y, dys), + BatchDuplicated(H, dHs), + BatchDuplicated(sol_out, dsols), + BatchDuplicated(cache, dcaches) + ) + + derivs = values(result[1]) + for k in 1:actual + grad_out[chunk_start + k - 1] = derivs[k] + end + end + return grad_out +end + +@testset "Gradient comparison - DI loglik w.r.t. vec(A)" begin + A_vec = vec(copy(A_di_gc)) + N_params = length(A_vec) + + grad_fin = fdm_gradient(_di_loglik_fd_gc, A_vec) + grad_fd = ForwardDiff.gradient(_di_loglik_fd_gc, A_vec) + + # Enzyme BatchDuplicated forward + sol_out_bf, cache_bf = _make_di_gc_workspace() + dAs = ntuple(_ -> make_zero(A_di_gc), CHUNK_gc) + dBs = ntuple(_ -> make_zero(B_di_gc), CHUNK_gc) + dCs = ntuple(_ -> make_zero(C_di_gc), CHUNK_gc) + du0s = ntuple(_ -> make_zero(u0_di_gc), CHUNK_gc) + dnoises = ntuple(_ -> [make_zero(noise_di_gc[1]) for _ in 1:T_gc], CHUNK_gc) + dys = ntuple(_ -> [make_zero(y_di_gc[1]) for _ in 1:T_gc], CHUNK_gc) + dHs = ntuple(_ -> make_zero(H_di_gc), CHUNK_gc) + dsols = ntuple(_ -> make_zero(sol_out_bf), CHUNK_gc) + dcaches = ntuple(_ -> make_zero(cache_bf), CHUNK_gc) + + grad_enzyme_fwd = zeros(N_params) + enzyme_batched_forward_gradient_di!( + grad_enzyme_fwd, + A_di_gc, B_di_gc, C_di_gc, u0_di_gc, noise_di_gc, y_di_gc, H_di_gc, + sol_out_bf, cache_bf, CHUNK_gc, + dAs, dBs, dCs, du0s, dnoises, dys, dHs, dsols, dcaches + ) + + # Enzyme Reverse + sol_out_rv, cache_rv = _make_di_gc_workspace() + dA_rv = make_zero(A_di_gc); dB_rv = make_zero(B_di_gc); dC_rv = make_zero(C_di_gc) + du0_rv = make_zero(u0_di_gc); dH_rv = make_zero(H_di_gc) + dnoise_rv = [make_zero(noise_di_gc[1]) for _ in 1:T_gc] + dy_rv = [make_zero(y_di_gc[1]) for _ in 1:T_gc] + dsol_rv = make_zero(sol_out_rv); dcache_rv = make_zero(cache_rv) + + autodiff( + Reverse, _di_loglik_enzyme!, Active, + Duplicated(A_di_gc, dA_rv), Duplicated(B_di_gc, dB_rv), + Duplicated(C_di_gc, dC_rv), Duplicated(u0_di_gc, du0_rv), + Duplicated(noise_di_gc, dnoise_rv), + Duplicated(y_di_gc, dy_rv), + Duplicated(H_di_gc, dH_rv), + Duplicated(sol_out_rv, dsol_rv), Duplicated(cache_rv, dcache_rv) + ) + grad_enzyme_rev = vec(dA_rv) + + @testset "all methods finite" begin + @test all(isfinite, grad_fin) + @test all(isfinite, grad_fd) + @test all(isfinite, grad_enzyme_fwd) + @test all(isfinite, grad_enzyme_rev) + end + + @testset "ForwardDiff matches finite differences" begin + @test grad_fd ≈ grad_fin rtol = 1.0e-4 + end + + @testset "Enzyme BatchDuplicated forward matches finite differences" begin + @test grad_enzyme_fwd ≈ grad_fin rtol = 1.0e-4 + end + + @testset "Enzyme reverse matches finite differences" begin + @test grad_enzyme_rev ≈ grad_fin rtol = 1.0e-4 + end + + @testset "all AD methods agree (high precision)" begin + @test grad_fd ≈ grad_enzyme_rev rtol = 1.0e-10 + @test grad_enzyme_fwd ≈ grad_enzyme_rev rtol = 1.0e-10 + end +end diff --git a/test/jet/jet_tests.jl b/test/jet/jet_tests.jl index 24d84d8..dff8057 100644 --- a/test/jet/jet_tests.jl +++ b/test/jet/jet_tests.jl @@ -14,7 +14,7 @@ using Test tspan = (0, 10) noise = randn(2, 10) - prob = LinearStateSpaceProblem(A, B, u0, tspan; noise = noise) + prob = LinearStateSpaceProblem(A, B, u0, tspan; noise) rep = JET.report_call(solve, (typeof(prob), typeof(DirectIteration()))) @test length(JET.get_reports(rep)) == 0 end @@ -27,17 +27,17 @@ using Test u0 = [1.0, 0.5] tspan = (0, 10) observables = randn(1, 10) - observables_noise = [0.1] + observables_noise = Diagonal([0.1]) u0_prior_mean = [0.0, 0.0] u0_prior_var = [1.0 0.0; 0.0 1.0] prob = LinearStateSpaceProblem( A, B, u0, tspan; - C = C, - u0_prior_mean = u0_prior_mean, - u0_prior_var = u0_prior_var, - observables_noise = observables_noise, - observables = observables + C, + u0_prior_mean, + u0_prior_var, + observables_noise, + observables ) rep = JET.report_call(solve, (typeof(prob), typeof(KalmanFilter()))) @test length(JET.get_reports(rep)) == 0 @@ -51,7 +51,7 @@ using Test u0 = [1.0, 0.5] tspan = (0, 10) - prob = LinearStateSpaceProblem(A, B, u0, tspan; C = C) + prob = LinearStateSpaceProblem(A, B, u0, tspan; C) rep = JET.report_call(solve, (typeof(prob), typeof(DirectIteration()))) @test length(JET.get_reports(rep)) == 0 end diff --git a/test/kalman.jl b/test/kalman.jl new file mode 100644 index 0000000..260fa84 --- /dev/null +++ b/test/kalman.jl @@ -0,0 +1,433 @@ +using DifferenceEquations, Distributions, LinearAlgebra, Test, DelimitedFiles, DiffEqBase +using DifferenceEquations: init, solve! + +# --- Helpers --- + +function solve_kalman(A, B, C, u0_prior_mean, u0_prior_var, observables, D; kwargs...) + problem = LinearStateSpaceProblem( + A, B, u0_prior_mean, (0, length(observables)); C, + observables_noise = D, + u0_prior_mean, u0_prior_var, + noise = nothing, observables, kwargs... + ) + return solve(problem) +end + +function unvech_5(v) + return LowerTriangular( + hcat( + v[1:5], + [ + zeros(1); + v[6:9] + ], + [ + zeros(2); + v[10:12] + ], + [ + zeros(3); + v[13:14] + ], + [ + zeros(4); + v[15] + ] + ) + ) +end + +function solve_kalman_cov(A, B, C, u0_mean, u0_variance_vech, observables, D; kwargs...) + u0_variance_cholesky = unvech_5(u0_variance_vech) + u0_variance = u0_variance_cholesky * u0_variance_cholesky' + problem = LinearStateSpaceProblem( + A, B, zeros(length(u0_mean)), + (0, length(observables)); C, + observables_noise = D, + u0_prior_mean = u0_mean, u0_prior_var = u0_variance, + noise = nothing, observables, kwargs... + ) + return solve(problem) +end + +get_matrix(R::AbstractVector) = Diagonal(R) +get_matrix(R::AbstractMatrix) = R + +function solve_manual(observables, A, B, C, R_raw, u0_mean, u0_variance, tspan) + T = tspan[2] + @assert tspan[1] == 0 + @assert length(observables) == T + + # Gaussian prior + B_prod = B * B' + R = get_matrix(R_raw) + + u = Vector{Vector{Float64}}(undef, T + 1) # prior mean + P = Vector{Matrix{Float64}}(undef, T + 1) # prior variance + z = Vector{Vector{Float64}}(undef, T + 1) # mean observation + + u[1] = u0_mean + P[1] = u0_variance + z[1] = C * u[1] + loglik = 0.0 + for i in 2:(T + 1) + # Kalman iteration + u[i] = A * u[i - 1] + P[i] = A * P[i - 1] * A' + B_prod + z[i] = C * u[i] + + CP_i = C * P[i] + V_temp = CP_i * C' + R + V = Symmetric((V_temp + V_temp') / 2) + loglik += logpdf(MvNormal(z[i], V), observables[i - 1]) + K = CP_i' / V # gain + u[i] += K * (observables[i - 1] - z[i]) + P[i] -= K * CP_i + end + return z, u, P, loglik +end + +function solve_manual_cov_lik(A, B, C, u0_mean, u0_variance_vech, observables, R_raw, tspan) + T = tspan[2] + @assert tspan[1] == 0 + @assert length(observables) == T + + # Gaussian prior — u0 prior taken from params + u0_variance_cholesky = unvech_5(u0_variance_vech) + u0_variance = u0_variance_cholesky * u0_variance_cholesky' + B_prod = B * B' + R = get_matrix(R_raw) + + u = u0_mean + P = u0_variance + z = C * u + loglik = 0.0 + for i in 2:(T + 1) + # Kalman iteration + u = A * u + P = A * P * A' + B_prod + z = C * u + + CP_i = C * P + V_temp = CP_i * C' + R + V = (V_temp + V_temp') / 2 + loglik += logpdf(MvNormal(z, V), observables[i - 1]) + K = CP_i' / V # gain + u += K * (observables[i - 1] - z) + P -= K * CP_i + end + return loglik +end + +function kalman_likelihood(A, B, C, u0, observables, D; kwargs...) + problem = LinearStateSpaceProblem( + A, B, u0, (0, length(observables)); C, + observables_noise = Diagonal(D), + u0_prior_mean = u0, + u0_prior_var = diagm(ones(length(u0))), + noise = nothing, observables, kwargs... + ) + return solve(problem).logpdf +end + +# --- Kalman test data (5x5 model) --- + +A_kalman = [ + 0.0495388 0.0109918 0.0960529 0.0767147 0.0404643; + 0.020344 0.0627784 0.00865501 0.0394004 0.0601155; + 0.0260677 0.039467 0.0344606 0.033846 0.00224089; + 0.0917289 0.081082 0.0341586 0.0591207 0.0411927; + 0.0837549 0.0515705 0.0429467 0.0209615 0.014668 +] +B_kalman = [ + 0.589064 0.97337 2.32677; + 0.864922 0.695811 0.618615; + 2.07924 1.11661 0.721113; + 0.995325 1.8416 2.30442; + 1.76884 1.56082 0.749023 +] +C_kalman = [ + 0.0979797 0.114992 0.0964536 0.110065 0.0946794; + 0.110095 0.0856981 0.0841296 0.0981172 0.0811817; + 0.109134 0.103406 0.112622 0.0925896 0.112384; + 0.0848231 0.0821602 0.099332 0.113586 0.115105 +] +D_kalman = abs2.(ones(4) * 0.1) +u0_mean_kalman = zeros(5) +u0_var_kalman = diagm(ones(length(u0_mean_kalman))) + +observables_kalman_matrix = readdlm( + joinpath( + pkgdir(DifferenceEquations), + "test/data/Kalman_observables.csv" + ), ',' +)' |> collect +observables_kalman = [observables_kalman_matrix[:, t] for t in 1:size(observables_kalman_matrix, 2)] +T_kalman = 200 + +D_offdiag = [ + 0.01 0.0 0.0 0.0; + 0.0 0.02 0.005 0.01; + 0.0 0.005 0.03 0.0; + 0.0 0.01 0.0 0.04 +] + +u0_mean = [0.0, 0.0, 0.0, 0.0, 0.0] +u0_var_vech = [ + 1.1193770675024004, -0.1755391543370492, -0.8351442110561855, + 0.6799242624030147, + -0.7627861222280011, 0.1346800868329039, 0.46537792458084976, + -0.16223737917345768, 0.1772417632124954, 0.2722945202387173, + -0.3971349857502508, -0.1474011998331263, 0.18113754883619412, + 0.13433861105247683, 0.029171596025489813, +] + +R = [ + 0.01 0.0 0.0 0.0; + 0.0 0.02 0.005 0.01; + 0.0 0.005 0.03 0.0; + 0.0 0.01 0.0 0.04 +] + +# --- RBC model data --- + +A_rbc = [ + 0.9568351489231076 6.209371005755285; + 3.0153731819288737e-18 0.20000000000000007 +] +B_rbc = reshape([0.0; -0.01], 2, 1) +C_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] +D_rbc = abs2.([0.1, 0.1]) +u0_rbc = zeros(2) + +observables_rbc_matrix = readdlm( + joinpath( + pkgdir(DifferenceEquations), + "test/data/RBC_observables.csv" + ), + ',' +)' |> collect +noise_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), + ',' +)' |> + collect +T_rbc = 5 +observables_rbc = [observables_rbc_matrix[:, t] for t in 1:T_rbc] +noise_rbc = [noise_rbc_matrix[:, t] for t in 1:T_rbc] + +# --- FVGQ model data --- + +A_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A.csv"), ',') +B_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_B.csv"), ',') +C_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C.csv"), ',') +D_FVGQ = ones(6) * 1.0e-3 + +observables_FVGQ_matrix = readdlm( + joinpath( + pkgdir(DifferenceEquations), + "test/data/FVGQ20_observables.csv" + ), ',' +)' |> collect +observables_FVGQ = [observables_FVGQ_matrix[:, t] for t in 1:size(observables_FVGQ_matrix, 2)] + +noise_FVGQ_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_noise.csv"), + ',' +)' |> + collect +noise_FVGQ = [noise_FVGQ_matrix[:, t] for t in 1:size(noise_FVGQ_matrix, 2)] +u0_FVGQ = zeros(size(A_FVGQ, 1)) + +# --- Tests --- + +@testset "Kalman filter — non-square matrices" begin + z, u, P, loglik = solve_manual( + observables_kalman, A_kalman, B_kalman, C_kalman, + D_kalman, + u0_mean_kalman, u0_var_kalman, [0, T_kalman] + ) + sol = solve_kalman( + A_kalman, B_kalman, C_kalman, u0_mean_kalman, u0_var_kalman, + observables_kalman, Diagonal(D_kalman) + ) + @inferred solve_kalman( + A_kalman, B_kalman, C_kalman, u0_mean_kalman, u0_var_kalman, + observables_kalman, + Diagonal(D_kalman) + ) + @test sol.logpdf ≈ loglik + @test sol.logpdf ≈ 329.7550738722514 + @test sol.z ≈ z + @test sol.u ≈ u + @test sol.P ≈ P +end + +@testset "Kalman filter — off-diagonal D" begin + z, u, P, loglik = solve_manual( + observables_kalman, A_kalman, B_kalman, C_kalman, + D_offdiag, + u0_mean_kalman, u0_var_kalman, [0, T_kalman] + ) + sol = solve_kalman( + A_kalman, B_kalman, C_kalman, u0_mean_kalman, u0_var_kalman, + observables_kalman, D_offdiag + ) + @inferred solve_kalman( + A_kalman, B_kalman, C_kalman, u0_mean_kalman, u0_var_kalman, + observables_kalman, + D_offdiag + ) + @test sol.logpdf ≈ loglik + @test sol.logpdf ≈ 124.86949661078718 + @test sol.z ≈ z + @test sol.u ≈ u + @test sol.P ≈ P +end + +@testset "Kalman filter — covariance prior likelihood" begin + loglik = solve_manual_cov_lik( + A_kalman, B_kalman, C_kalman, u0_mean, u0_var_vech, + observables_kalman, + R, [0, T_kalman] + ) + sol = solve_kalman_cov( + A_kalman, B_kalman, C_kalman, u0_mean, u0_var_vech, + observables_kalman, + R + ) + @test sol.logpdf ≈ loglik +end + +@testset "Kalman inference — RBC" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); + C = C_rbc, + observables_noise = Diagonal(D_rbc), observables = observables_rbc, + u0_prior_mean = u0_rbc, + u0_prior_var = diagm(ones(length(u0_rbc))) + ) + @inferred LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); + C = C_rbc, + observables_noise = Diagonal(D_rbc), + observables = observables_rbc, + u0_prior_mean = u0_rbc, + u0_prior_var = diagm(ones(length(u0_rbc))) + ) + + sol = solve(prob) + @inferred solve(prob) + + prob_concrete = DiffEqBase.get_concrete_problem(prob, false) + @inferred DiffEqBase.get_concrete_problem(prob, false) + + kalman_likelihood(A_rbc, B_rbc, C_rbc, u0_rbc, observables_rbc, D_rbc) + @inferred kalman_likelihood(A_rbc, B_rbc, C_rbc, u0_rbc, observables_rbc, D_rbc) +end + +@testset "Kalman likelihood — RBC" begin + @test kalman_likelihood(A_rbc, B_rbc, C_rbc, u0_rbc, observables_rbc, D_rbc) ≈ + -607.3698273765538 + @inferred kalman_likelihood(A_rbc, B_rbc, C_rbc, u0_rbc, observables_rbc, D_rbc) +end + +@testset "Kalman likelihood — FVGQ" begin + @test kalman_likelihood( + A_FVGQ, B_FVGQ, C_FVGQ, u0_FVGQ, observables_FVGQ, + D_FVGQ + ) ≈ + 2253.0905386483046 +end + +@testset "Kalman failure — ill-conditioned A" begin + A = [1.0e20 0.0; 1.0e20 0.0] + u0_prior_var = diagm(1.0e10 * ones(length(u0_rbc))) + prob = LinearStateSpaceProblem( + A, B_rbc, u0_rbc, (0, length(observables_rbc)); + C = C_rbc, + observables_noise = Diagonal(D_rbc), observables = observables_rbc, + u0_prior_mean = u0_rbc, u0_prior_var + ) + @test_throws Exception solve(prob) +end + +# --- Workspace (init/solve!) tests --- + +@testset "solve!() matches solve() — basic Kalman (5x5, non-square)" begin + z_ref, u_ref, P_ref, loglik_ref = solve_manual( + observables_kalman, A_kalman, B_kalman, C_kalman, + D_kalman, u0_mean_kalman, u0_var_kalman, [0, T_kalman] + ) + prob = LinearStateSpaceProblem( + A_kalman, B_kalman, u0_mean_kalman, (0, length(observables_kalman)); + C = C_kalman, observables_noise = Diagonal(D_kalman), + u0_prior_mean = u0_mean_kalman, u0_prior_var = u0_var_kalman, + noise = nothing, observables = observables_kalman + ) + ws = init(prob, KalmanFilter()) + sol_ws = solve!(ws) + @test sol_ws.logpdf ≈ loglik_ref + @test sol_ws.logpdf ≈ 329.7550738722514 + @test sol_ws.z ≈ z_ref + @test sol_ws.u ≈ u_ref + @test sol_ws.P ≈ P_ref +end + +@testset "solve!() matches solve() — off-diagonal D" begin + sol_direct = solve_kalman( + A_kalman, B_kalman, C_kalman, u0_mean_kalman, u0_var_kalman, + observables_kalman, D_offdiag + ) + prob = LinearStateSpaceProblem( + A_kalman, B_kalman, u0_mean_kalman, (0, length(observables_kalman)); + C = C_kalman, observables_noise = D_offdiag, + u0_prior_mean = u0_mean_kalman, u0_prior_var = u0_var_kalman, + noise = nothing, observables = observables_kalman + ) + ws = init(prob, KalmanFilter()) + sol_ws = solve!(ws) + @test sol_ws.logpdf ≈ sol_direct.logpdf + @test sol_ws.logpdf ≈ 124.86949661078718 + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.P ≈ sol_direct.P +end + +@testset "solve!() matches solve() — covariance prior likelihood" begin + sol_direct = solve_kalman_cov( + A_kalman, B_kalman, C_kalman, u0_mean, u0_var_vech, + observables_kalman, R + ) + u0_variance_cholesky = unvech_5(u0_var_vech) + u0_variance = u0_variance_cholesky * u0_variance_cholesky' + prob = LinearStateSpaceProblem( + A_kalman, B_kalman, zeros(length(u0_mean)), + (0, length(observables_kalman)); + C = C_kalman, observables_noise = R, + u0_prior_mean = u0_mean, u0_prior_var = u0_variance, + noise = nothing, observables = observables_kalman + ) + ws = init(prob, KalmanFilter()) + sol_ws = solve!(ws) + @test sol_ws.logpdf ≈ sol_direct.logpdf + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.P ≈ sol_direct.P +end + +@testset "solve!() repeated — idempotent Kalman results" begin + prob = LinearStateSpaceProblem( + A_kalman, B_kalman, u0_mean_kalman, (0, length(observables_kalman)); + C = C_kalman, observables_noise = Diagonal(D_kalman), + u0_prior_mean = u0_mean_kalman, u0_prior_var = u0_var_kalman, + noise = nothing, observables = observables_kalman + ) + ws = init(prob, KalmanFilter()) + sol1 = solve!(ws) + sol2 = solve!(ws) + @test sol1.u ≈ sol2.u + @test sol1.z ≈ sol2.z + @test sol1.P ≈ sol2.P + @test sol1.logpdf ≈ sol2.logpdf +end diff --git a/test/kalman_enzyme.jl b/test/kalman_enzyme.jl new file mode 100644 index 0000000..997af96 --- /dev/null +++ b/test/kalman_enzyme.jl @@ -0,0 +1,281 @@ +# Enzyme AD tests for KalmanFilter +# prob passed as Duplicated — observables get zero shadow automatically. +# GC disabled to avoid Enzyme reverse-mode GC corruption (#2355). + +GC.gc() +GC.enable(false) + +using LinearAlgebra, Test, Enzyme, EnzymeTestUtils, StaticArrays, Random +using DifferenceEquations +using DifferenceEquations: init, solve!, StateSpaceWorkspace +using FiniteDifferences: central_fdm + +include("enzyme_test_utils.jl") # vech helpers only + +# max_range needed: FD perturbation of observables_noise inside prob can push +# the matrix non-positive-definite, causing DomainError in logdet_chol. +const _fdm_kf = central_fdm(5, 1; max_range = 1.0e-3) + +# --- Test setup --- + +const N_kf = 3 +const M_kf = 2 +const K_kf = 2 +const L_kf = 2 +const T_kf = 5 + +Random.seed!(42) +A_raw = randn(N_kf, N_kf) +const A_kf = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) +const B_kf = 0.1 * randn(N_kf, K_kf) +const C_kf = randn(M_kf, N_kf) +const H_kf = 0.1 * randn(M_kf, L_kf) +const R_kf = H_kf * H_kf' +const mu_0_kf = zeros(N_kf) +const Sigma_0_kf = Matrix{Float64}(I, N_kf, N_kf) + +Random.seed!(123) +const x0_kf = mu_0_kf + cholesky(Sigma_0_kf).L * randn(N_kf) +const noise_kf = [randn(K_kf) for _ in 1:T_kf] +const obs_noise_kf = [randn(L_kf) for _ in 1:T_kf] +const sim_sol_kf = solve( + LinearStateSpaceProblem( + A_kf, B_kf, x0_kf, (0, T_kf); C = C_kf, noise = noise_kf + ) +) +const y_kf = [sim_sol_kf.z[t + 1] + H_kf * obs_noise_kf[t] for t in 1:T_kf] + +# --- Helpers --- + +function make_kalman_prob(A, B, C, R, mu_0, Sigma_0, y) + return LinearStateSpaceProblem( + A, B, zeros(eltype(A), size(A, 1)), (0, length(y)); + C, u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) +end + +# --- Wrappers — prob as single Duplicated arg --- + +function kalman_solve_prob!(prob, sol, cache) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol, cache) + solve!(ws) + return (sol.u, sol.P, sol.z) +end + +function kalman_loglik_prob(prob, sol, cache)::Float64 + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol, cache) + return solve!(ws).logpdf +end + +# Vech: separate args (y stays Duplicated — remake doesn't work with Enzyme shadows) +function kalman_solve_vech!( + A, B, C, mu_0, sigma_0_vech, r_vech, y, sol, cache, + n_state, n_obs + ) + Sigma_0 = make_posdef_from_vech(sigma_0_vech, n_state) + R = make_posdef_from_vech(r_vech, n_obs) + prob = LinearStateSpaceProblem( + A, B, zeros(eltype(A), size(A, 1)), (0, length(y)); + C, u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol, cache) + solve!(ws) + return (sol.u, sol.P, sol.z) +end + +function kalman_loglik_vech( + A, B, C, mu_0, sigma_0_vech, r_vech, y, sol, cache, + n_state, n_obs + )::Float64 + Sigma_0 = make_posdef_from_vech(sigma_0_vech, n_state) + R = make_posdef_from_vech(r_vech, n_obs) + prob = LinearStateSpaceProblem( + A, B, zeros(eltype(A), size(A, 1)), (0, length(y)); + C, u0_prior_mean = mu_0, u0_prior_var = Sigma_0, + observables_noise = R, observables = y + ) + ws = StateSpaceWorkspace(prob, KalmanFilter(), sol, cache) + return solve!(ws).logpdf +end + +# --- Basic sanity test --- + +@testset "Kalman loglik via solve!() - sanity" begin + prob = make_kalman_prob(A_kf, B_kf, C_kf, R_kf, mu_0_kf, Sigma_0_kf, y_kf) + ws = init(prob, KalmanFilter()) + loglik = kalman_loglik_prob(prob, ws.output, ws.cache) + @test isfinite(loglik) + @test loglik < 0 + + loglik2 = kalman_loglik_prob(prob, ws.output, ws.cache) + @test loglik ≈ loglik2 rtol = 1.0e-12 +end + +# --- Forward — prob as Duplicated (small model, N=M=K=L=2, T=2) --- + +@testset "EnzymeTestUtils - Kalman forward (prob Duplicated)" begin + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0]; R_s = [0.01 0.0; 0.0 0.01] + mu_0_s = zeros(2); Sigma_0_s = Matrix{Float64}(I, 2, 2) + y_s = [[0.5, 0.3], [0.2, 0.1]] + prob = make_kalman_prob(A_s, B_s, C_s, R_s, mu_0_s, Sigma_0_s, y_s) + ws = init(prob, KalmanFilter()) + + test_forward( + kalman_solve_prob!, Const, + (prob, Duplicated), + (ws.output, Duplicated), (ws.cache, Duplicated); + fdm = _fdm_kf, + ) +end + +# --- Reverse via vech (all Duplicated) --- + +@testset "EnzymeTestUtils - Kalman reverse via vech (all Duplicated)" begin + _fdm_vech = central_fdm(5, 1) + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0]; R_s = [0.01 0.0; 0.0 0.01] + mu_0_s = zeros(2); Sigma_0_s = Matrix{Float64}(I, 2, 2) + y_s = [[0.5, 0.3], [0.2, 0.1]] + sigma_0_v = make_vech_for(Sigma_0_s) + r_v = make_vech_for(R_s) + prob = make_kalman_prob(A_s, B_s, C_s, R_s, mu_0_s, Sigma_0_s, y_s) + ws = init(prob, KalmanFilter()) + + test_reverse( + kalman_loglik_vech, Active, + (copy(A_s), Duplicated), (copy(B_s), Duplicated), + (copy(C_s), Duplicated), (copy(mu_0_s), Duplicated), + (copy(sigma_0_v), Duplicated), (copy(r_v), Duplicated), + ([copy(y) for y in y_s], Duplicated), + (deepcopy(ws.output), Duplicated), (deepcopy(ws.cache), Duplicated), + (2, Const), (2, Const); + fdm = _fdm_vech, + ) +end + +# --- Forward — rectangular B (N!=K) — validates mul_aat!! workaround --- + +@testset "EnzymeTestUtils - Kalman rectangular B forward (prob Duplicated)" begin + A_r = [ + 0.3 0.1 0.0 0.05 0.02; -0.1 0.3 0.05 0.0 0.01; + 0.02 -0.05 0.3 0.1 0.0; 0.0 0.02 -0.1 0.3 0.05; + 0.01 0.0 0.02 -0.05 0.3 + ] + B_r = 0.1 * [1.0 0.5; 0.3 -0.2; 0.7 0.1; -0.4 0.6; 0.2 -0.3] + C_r = [1.0 0.0 0.5 0.0 0.0; 0.0 1.0 0.0 0.5 0.0; 0.0 0.0 1.0 0.0 0.5] + R_r = 0.01 * Matrix{Float64}(I, 3, 3) + mu_0_r = zeros(5); Sigma_0_r = Matrix{Float64}(I, 5, 5) + y_r = [[0.5, 0.3, 0.1], [0.2, -0.1, 0.4], [0.8, 0.4, -0.2]] + prob = make_kalman_prob(A_r, B_r, C_r, R_r, mu_0_r, Sigma_0_r, y_r) + ws = init(prob, KalmanFilter()) + + test_forward( + kalman_solve_prob!, Const, + (prob, Duplicated), + (ws.output, Duplicated), (ws.cache, Duplicated); + fdm = _fdm_kf, + ) +end + +# --- Reverse — rectangular B via vech --- + +@testset "EnzymeTestUtils - Kalman rectangular B reverse via vech (all Duplicated)" begin + _fdm_vech = central_fdm(5, 1) + N_r, M_r = 5, 3 + A_r = [ + 0.3 0.1 0.0 0.05 0.02; -0.1 0.3 0.05 0.0 0.01; + 0.02 -0.05 0.3 0.1 0.0; 0.0 0.02 -0.1 0.3 0.05; + 0.01 0.0 0.02 -0.05 0.3 + ] + B_r = 0.1 * [1.0 0.5; 0.3 -0.2; 0.7 0.1; -0.4 0.6; 0.2 -0.3] + C_r = [1.0 0.0 0.5 0.0 0.0; 0.0 1.0 0.0 0.5 0.0; 0.0 0.0 1.0 0.0 0.5] + R_r = 0.01 * Matrix{Float64}(I, M_r, M_r) + mu_0_r = zeros(N_r); Sigma_0_r = Matrix{Float64}(I, N_r, N_r) + y_r = [[0.5, 0.3, 0.1], [0.2, -0.1, 0.4], [0.8, 0.4, -0.2]] + sigma_0_v = make_vech_for(Sigma_0_r) + r_v = make_vech_for(R_r) + prob = make_kalman_prob(A_r, B_r, C_r, R_r, mu_0_r, Sigma_0_r, y_r) + ws = init(prob, KalmanFilter()) + + test_reverse( + kalman_loglik_vech, Active, + (copy(A_r), Duplicated), (copy(B_r), Duplicated), + (copy(C_r), Duplicated), (copy(mu_0_r), Duplicated), + (copy(sigma_0_v), Duplicated), (copy(r_v), Duplicated), + ([copy(y) for y in y_r], Duplicated), + (deepcopy(ws.output), Duplicated), (deepcopy(ws.cache), Duplicated), + (N_r, Const), (M_r, Const); + fdm = _fdm_vech, + ) +end + +# --- Non-diagonal R via vech (genuinely off-diagonal) --- + +@testset "EnzymeTestUtils - Kalman non-diagonal R forward (vech)" begin + _fdm_vech = central_fdm(5, 1) + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0] + R_offdiag = [0.02 0.005; 0.005 0.01] + r_v = make_vech_for(R_offdiag) + mu_0_s = zeros(2); Sigma_0_s = Matrix{Float64}(I, 2, 2) + sigma_0_v = make_vech_for(Sigma_0_s) + y_s = [[0.5, 0.3], [0.2, 0.1]] + prob = make_kalman_prob(A_s, B_s, C_s, R_offdiag, mu_0_s, Sigma_0_s, y_s) + ws = init(prob, KalmanFilter()) + + test_forward( + kalman_solve_vech!, Const, + (copy(A_s), Duplicated), (copy(B_s), Duplicated), + (copy(C_s), Duplicated), (copy(mu_0_s), Duplicated), + (copy(sigma_0_v), Duplicated), (copy(r_v), Duplicated), + ([copy(y) for y in y_s], Duplicated), + (ws.output, Duplicated), (ws.cache, Duplicated), + (2, Const), (2, Const); + fdm = _fdm_vech, + ) +end + +@testset "EnzymeTestUtils - Kalman non-diagonal R reverse (vech)" begin + _fdm_vech = central_fdm(5, 1) + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0] + R_offdiag = [0.02 0.005; 0.005 0.01] + r_v = make_vech_for(R_offdiag) + mu_0_s = zeros(2); Sigma_0_s = Matrix{Float64}(I, 2, 2) + sigma_0_v = make_vech_for(Sigma_0_s) + y_s = [[0.5, 0.3], [0.2, 0.1]] + prob = make_kalman_prob(A_s, B_s, C_s, R_offdiag, mu_0_s, Sigma_0_s, y_s) + ws = init(prob, KalmanFilter()) + + test_reverse( + kalman_loglik_vech, Active, + (copy(A_s), Duplicated), (copy(B_s), Duplicated), + (copy(C_s), Duplicated), (copy(mu_0_s), Duplicated), + (copy(sigma_0_v), Duplicated), (copy(r_v), Duplicated), + ([copy(y) for y in y_s], Duplicated), + (deepcopy(ws.output), Duplicated), (deepcopy(ws.cache), Duplicated), + (2, Const), (2, Const); + fdm = _fdm_vech, + ) +end + +# --- Regression test --- + +@testset "Kalman loglik - regression test" begin + A_reg = [0.9 0.1; -0.1 0.9]; B_reg = [0.1 0.0; 0.0 0.1] + C_reg = [1.0 0.0; 0.0 1.0]; R_reg = [0.01 0.0; 0.0 0.01] + mu_0_reg = [0.0, 0.0]; Sigma_0_reg = [1.0 0.0; 0.0 1.0] + y_reg = [[0.5, -0.3], [0.8, -0.1], [0.6, 0.2]] + prob = make_kalman_prob(A_reg, B_reg, C_reg, R_reg, mu_0_reg, Sigma_0_reg, y_reg) + ws = init(prob, KalmanFilter()) + + loglik = kalman_loglik_prob(prob, ws.output, ws.cache) + @test isfinite(loglik) + @test loglik < 0 + @test length(ws.output.u) == 4 +end + +GC.enable(true) diff --git a/test/kalman_forwarddiff.jl b/test/kalman_forwarddiff.jl new file mode 100644 index 0000000..f5f7173 --- /dev/null +++ b/test/kalman_forwarddiff.jl @@ -0,0 +1,168 @@ +# ForwardDiff AD tests for Kalman filter +# Tests gradient correctness against central finite differences. + +using LinearAlgebra, Test, ForwardDiff, StaticArrays, Random +using DifferenceEquations + +include("forwarddiff_test_utils.jl") + +# ============================================================================= +# Problem setup +# ============================================================================= + +const N_kf_fd = 3 +const M_kf_fd = 2 +const K_kf_fd = 2 +const T_kf_fd = 5 + +Random.seed!(42) +A_raw_kf_fd = randn(N_kf_fd, N_kf_fd) +const A_kf_fd = 0.5 * A_raw_kf_fd / maximum(abs.(eigvals(A_raw_kf_fd))) +const B_kf_fd = 0.1 * randn(N_kf_fd, K_kf_fd) +const C_kf_fd = randn(M_kf_fd, N_kf_fd) +const H_kf_fd = 0.1 * randn(M_kf_fd, M_kf_fd) +const R_kf_fd = H_kf_fd * H_kf_fd' + 0.01 * I +const mu_0_kf_fd = zeros(N_kf_fd) +const Sigma_0_kf_fd = Matrix{Float64}(I, N_kf_fd, N_kf_fd) + +Random.seed!(123) +const x0_kf_fd = randn(N_kf_fd) +const noise_sim_kf_fd = [randn(K_kf_fd) for _ in 1:T_kf_fd] +const sim_sol_kf_fd = solve( + LinearStateSpaceProblem( + A_kf_fd, B_kf_fd, x0_kf_fd, (0, T_kf_fd); C = C_kf_fd, noise = noise_sim_kf_fd + ) +) +const y_kf_fd = [sim_sol_kf_fd.z[t + 1] + H_kf_fd * randn(M_kf_fd) for t in 1:T_kf_fd] + +# ============================================================================= +# Mutable arrays — ForwardDiff gradient tests +# ============================================================================= + +function kalman_loglik_fd(A, B, C, mu_0, Sigma_0, R, y) + T_el = promote_type( + eltype(A), eltype(B), eltype(C), + eltype(mu_0), eltype(Sigma_0), eltype(R) + ) + prob = LinearStateSpaceProblem( + promote_array(T_el, A), promote_array(T_el, B), + zeros(T_el, size(A, 1)), (0, length(y)); + C = promote_array(T_el, C), + u0_prior_mean = promote_array(T_el, mu_0), + u0_prior_var = promote_array(T_el, Sigma_0), + observables_noise = promote_array(T_el, R), + observables = y + ) + sol = solve(prob, KalmanFilter()) + return sol.logpdf +end + +@testset "ForwardDiff - Kalman Filter (mutable)" begin + @testset "primal sanity" begin + loglik_val = kalman_loglik_fd( + A_kf_fd, B_kf_fd, C_kf_fd, + mu_0_kf_fd, Sigma_0_kf_fd, R_kf_fd, y_kf_fd + ) + @test isfinite(loglik_val) + @test loglik_val < 0 + end + + @testset "gradient w.r.t. A" begin + f = a_vec -> kalman_loglik_fd( + reshape(a_vec, N_kf_fd, N_kf_fd), + B_kf_fd, C_kf_fd, mu_0_kf_fd, Sigma_0_kf_fd, R_kf_fd, y_kf_fd + ) + x0 = vec(copy(A_kf_fd)) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end + + @testset "gradient w.r.t. B" begin + f = b_vec -> kalman_loglik_fd( + A_kf_fd, reshape(b_vec, N_kf_fd, K_kf_fd), + C_kf_fd, mu_0_kf_fd, Sigma_0_kf_fd, R_kf_fd, y_kf_fd + ) + x0 = vec(copy(B_kf_fd)) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end + + @testset "gradient w.r.t. C" begin + f = c_vec -> kalman_loglik_fd( + A_kf_fd, B_kf_fd, + reshape(c_vec, M_kf_fd, N_kf_fd), mu_0_kf_fd, Sigma_0_kf_fd, R_kf_fd, y_kf_fd + ) + x0 = vec(copy(C_kf_fd)) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end + + @testset "gradient w.r.t. mu_0" begin + f = m_vec -> kalman_loglik_fd( + A_kf_fd, B_kf_fd, C_kf_fd, + m_vec, Sigma_0_kf_fd, R_kf_fd, y_kf_fd + ) + x0 = copy(mu_0_kf_fd) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end +end + +# ============================================================================= +# StaticArrays — ForwardDiff gradient tests +# ============================================================================= + +const A_kf_fd_s = SMatrix{N_kf_fd, N_kf_fd}(A_kf_fd) +const B_kf_fd_s = SMatrix{N_kf_fd, K_kf_fd}(B_kf_fd) +const C_kf_fd_s = SMatrix{M_kf_fd, N_kf_fd}(C_kf_fd) +const R_kf_fd_s = SMatrix{M_kf_fd, M_kf_fd}(R_kf_fd) +const mu_0_kf_fd_s = SVector{N_kf_fd}(mu_0_kf_fd) +const Sigma_0_kf_fd_s = SMatrix{N_kf_fd, N_kf_fd}(Sigma_0_kf_fd) +const y_kf_fd_s = [SVector{M_kf_fd}(yi) for yi in y_kf_fd] + +function kalman_loglik_fd_static( + A_vec, B, C, mu_0, Sigma_0, R, y, + ::Val{N}, ::Val{M}, ::Val{K} + ) where {N, M, K} + T_el = eltype(A_vec) + A = SMatrix{N, N}(reshape(A_vec, N, N)) + prob = LinearStateSpaceProblem( + A, SMatrix{N, K}(T_el.(B)), + SVector{N}(zeros(T_el, N)), (0, length(y)); + C = SMatrix{M, N}(T_el.(C)), + u0_prior_mean = SVector{N}(T_el.(mu_0)), + u0_prior_var = SMatrix{N, N}(T_el.(Sigma_0)), + observables_noise = SMatrix{M, M}(T_el.(R)), + observables = y + ) + sol = solve(prob, KalmanFilter()) + return sol.logpdf +end + +@testset "ForwardDiff - Kalman Filter (static)" begin + @testset "gradient w.r.t. A" begin + f = a_vec -> kalman_loglik_fd_static( + a_vec, B_kf_fd_s, C_kf_fd_s, + mu_0_kf_fd_s, Sigma_0_kf_fd_s, R_kf_fd_s, y_kf_fd_s, + Val(N_kf_fd), Val(M_kf_fd), Val(K_kf_fd) + ) + x0 = collect(vec(Matrix(A_kf_fd))) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end + + @testset "gradient w.r.t. C" begin + f = c_vec -> begin + T_el = eltype(c_vec) + prob = LinearStateSpaceProblem( + SMatrix{N_kf_fd, N_kf_fd}(T_el.(A_kf_fd)), + SMatrix{N_kf_fd, K_kf_fd}(T_el.(B_kf_fd)), + SVector{N_kf_fd}(zeros(T_el, N_kf_fd)), (0, length(y_kf_fd_s)); + C = SMatrix{M_kf_fd, N_kf_fd}(reshape(c_vec, M_kf_fd, N_kf_fd)), + u0_prior_mean = SVector{N_kf_fd}(T_el.(mu_0_kf_fd)), + u0_prior_var = SMatrix{N_kf_fd, N_kf_fd}(T_el.(Sigma_0_kf_fd)), + observables_noise = SMatrix{M_kf_fd, M_kf_fd}(T_el.(R_kf_fd)), + observables = y_kf_fd_s + ) + sol = solve(prob, KalmanFilter()) + return sol.logpdf + end + x0 = collect(vec(Matrix(C_kf_fd))) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end +end diff --git a/test/kalman_likelihood.jl b/test/kalman_likelihood.jl deleted file mode 100644 index 20e3087..0000000 --- a/test/kalman_likelihood.jl +++ /dev/null @@ -1,365 +0,0 @@ -using ChainRulesCore, ChainRulesTestUtils, DifferenceEquations, Distributions, - LinearAlgebra, Test, - Zygote -using DelimitedFiles -using DiffEqBase -using FiniteDiff: finite_difference_gradient - -# inv_vech in dssm repo manually with slices instead of given code -# in diffeq turn the cholesky pdef check off in fvgq in linear.jl - -function solve_kalman(A, B, C, u0_prior_mean, u0_prior_var, observables, D; kwargs...) - problem = LinearStateSpaceProblem( - A, B, u0_prior_mean, (0, size(observables, 2)); C, - observables_noise = D, - u0_prior_mean, u0_prior_var, - noise = nothing, observables, kwargs... - ) - return solve(problem) -end - -function unvech_5(v) - return LowerTriangular( - hcat( - v[1:5], - [ - zeros(1); - v[6:9] - ], - [ - zeros(2); - v[10:12] - ], - [ - zeros(3); - v[13:14] - ], - [ - zeros(4); - v[15] - ] - ) - ) -end - -function solve_kalman_cov(A, B, C, u0_mean, u0_variance_vech, observables, D; kwargs...) - # manually inverse-vech the u0_variance_vech back into a matrix - u0_variance_cholesky = unvech_5(u0_variance_vech) - u0_variance = u0_variance_cholesky * u0_variance_cholesky' - problem = LinearStateSpaceProblem( - A, B, zeros(length(u0_mean)), - (0, size(observables, 2)); C, - observables_noise = D, - u0_prior_mean = u0_mean, u0_prior_var = u0_variance, - noise = nothing, observables, kwargs... - ) - return solve(problem) -end - -get_matrix(R::AbstractVector) = Diagonal(R) -get_matrix(R::AbstractMatrix) = R -function solve_manual(observables, A, B, C, R_raw, u0_mean, u0_variance, tspan) - # hardcoded right now for tspan = (0, T) for T+1 points - T = tspan[2] - @assert tspan[1] == 0 - @assert size(observables)[2] == T # i.e. we do not calculate the likelihood of the initial condition - - # Gaussian Prior - B_prod = B * B' - R = get_matrix(R_raw) - - # TODO: when saveall = false, etc. don't allocate everything, or at least don't save it - u = Zygote.Buffer(Vector{Vector{Float64}}(undef, T + 1)) # prior mean - P = Zygote.Buffer(Vector{Matrix{Float64}}(undef, T + 1)) # prior variance - z = Zygote.Buffer(Vector{Vector{Float64}}(undef, T + 1)) # mean observation - - u[1] = u0_mean - P[1] = u0_variance - z[1] = C * u[1] - loglik = 0.0 - for i in 2:(T + 1) - # Kalman iteration - u[i] = A * u[i - 1] - P[i] = A * P[i - 1] * A' + B_prod - z[i] = C * u[i] - - CP_i = C * P[i] - V_temp = CP_i * C' + R - V = Symmetric((V_temp + V_temp') / 2) - loglik += logpdf(MvNormal(z[i], V), observables[:, i - 1]) - K = CP_i' / V # gain - u[i] += K * (observables[:, i - 1] - z[i]) - P[i] -= K * CP_i - end - return copy(z), copy(u), copy(P), loglik -end - -function solve_manual_cov_lik(A, B, C, u0_mean, u0_variance_vech, observables, R_raw, tspan) - # hardcoded right now for tspan = (0, T) for T+1 points - T = tspan[2] - @assert tspan[1] == 0 - @assert size(observables)[2] == T # i.e. we do not calculate the likelihood of the initial condition - - # Gaussian Prior - # u0 prior taken from params - u0_variance_cholesky = unvech_5(u0_variance_vech) - u0_variance = u0_variance_cholesky * u0_variance_cholesky' - B_prod = B * B' - R = get_matrix(R_raw) - - u = u0_mean - P = u0_variance - z = C * u - loglik = 0.0 - for i in 2:(T + 1) - # Kalman iteration - u = A * u - P = A * P * A' + B_prod - z = C * u - - CP_i = C * P - V_temp = CP_i * C' + R - V = (V_temp + V_temp') / 2 - loglik += logpdf(MvNormal(z, V), observables[:, i - 1]) - K = CP_i' / V # gain - u += K * (observables[:, i - 1] - z) - P -= K * CP_i - end - return loglik -end - -A_kalman = [ - 0.0495388 0.0109918 0.0960529 0.0767147 0.0404643; - 0.020344 0.0627784 0.00865501 0.0394004 0.0601155; - 0.0260677 0.039467 0.0344606 0.033846 0.00224089; - 0.0917289 0.081082 0.0341586 0.0591207 0.0411927; - 0.0837549 0.0515705 0.0429467 0.0209615 0.014668 -] -B_kalman = [ - 0.589064 0.97337 2.32677; - 0.864922 0.695811 0.618615; - 2.07924 1.11661 0.721113; - 0.995325 1.8416 2.30442; - 1.76884 1.56082 0.749023 -] -C_kalman = [ - 0.0979797 0.114992 0.0964536 0.110065 0.0946794; - 0.110095 0.0856981 0.0841296 0.0981172 0.0811817; - 0.109134 0.103406 0.112622 0.0925896 0.112384; - 0.0848231 0.0821602 0.099332 0.113586 0.115105 -] -D_kalman = abs2.(ones(4) * 0.1) -u0_mean_kalman = zeros(5) -u0_var_kalman = diagm(ones(length(u0_mean_kalman))) - -observables_kalman = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/Kalman_observables.csv" - ), ',' -)' |> collect -T = 200 - -@testset "basic test, non-square matrices" begin - z, u, - P, - loglik = solve_manual( - observables_kalman, A_kalman, B_kalman, C_kalman, - D_kalman, - u0_mean_kalman, u0_var_kalman, [0, T] - ) - sol = solve_kalman( - A_kalman, B_kalman, C_kalman, u0_mean_kalman, u0_var_kalman, - observables_kalman, D_kalman - ) - @inferred solve_kalman( - A_kalman, B_kalman, C_kalman, u0_mean_kalman, u0_var_kalman, - observables_kalman, - D_kalman - ) - @test sol.logpdf ≈ loglik - @test sol.logpdf ≈ 329.7550738722514 - @test sol.z ≈ z - @test sol.u ≈ u - @test sol.P ≈ P - gradient( - (args...) -> solve_kalman( - args..., u0_var_kalman, observables_kalman, - D_kalman - ).logpdf, - A_kalman, - B_kalman, - C_kalman, u0_mean_kalman - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> solve_kalman( - args..., u0_var_kalman, observables_kalman, - D_kalman - ).logpdf, A_kalman, - B_kalman, C_kalman, - u0_mean_kalman; rrule_f = rrule_via_ad, check_inferred = false - ) -end - -D_offdiag = [ - 0.01 0.0 0.0 0.0; - 0.0 0.02 0.005 0.01; - 0.0 0.005 0.03 0.0; - 0.0 0.01 0.0 0.04 -] -@testset "off-diagonal D" begin - z, u, - P, - loglik = solve_manual( - observables_kalman, A_kalman, B_kalman, C_kalman, - D_offdiag, - u0_mean_kalman, u0_var_kalman, [0, T] - ) - sol = solve_kalman( - A_kalman, B_kalman, C_kalman, u0_mean_kalman, u0_var_kalman, - observables_kalman, D_offdiag - ) - @inferred solve_kalman( - A_kalman, B_kalman, C_kalman, u0_mean_kalman, u0_var_kalman, - observables_kalman, - D_offdiag - ) - @test sol.logpdf ≈ loglik - @test sol.logpdf ≈ 124.86949661078718 - @test sol.z ≈ z - @test sol.u ≈ u - @test sol.P ≈ P - gradient( - (args...) -> solve_kalman( - args..., u0_var_kalman, observables_kalman, - D_offdiag - ).logpdf, - A_kalman, - B_kalman, - C_kalman, u0_mean_kalman - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> solve_kalman( - args..., u0_var_kalman, observables_kalman, - D_offdiag - ).logpdf, A_kalman, - B_kalman, C_kalman, - u0_mean_kalman; rrule_f = rrule_via_ad, check_inferred = false - ) -end - -@testset "direct rrule" begin - z, u, - P, - loglik = solve_manual( - observables_kalman, A_kalman, B_kalman, C_kalman, - D_kalman, - u0_mean_kalman, u0_var_kalman, [0, T] - ) - problem = LinearStateSpaceProblem( - A_kalman, B_kalman, u0_mean_kalman, - (0, size(observables_kalman, 2)); - C = C_kalman, observables_noise = D_kalman, - u0_prior_mean = u0_mean_kalman, - u0_prior_var = u0_var_kalman, - observables = observables_kalman - ) - - sol, pb = ChainRulesCore.rrule(DiffEqBase.solve, problem, KalmanFilter()) - @test sol.logpdf ≈ loglik - @test sol.logpdf ≈ 329.7550738722514 - @test sol.z ≈ z - @test sol.u ≈ u - @test sol.P ≈ P -end - -u0_mean = [0.0, 0.0, 0.0, 0.0, 0.0] -# [0.46278392661230217, -0.35157252508544934, -0.33952978655645105, -0.3486954393399204, 0.6934920135433433] -u0_var_vech = [ - 1.1193770675024004, -0.1755391543370492, -0.8351442110561855, - 0.6799242624030147, - -0.7627861222280011, 0.1346800868329039, 0.46537792458084976, - -0.16223737917345768, 0.1772417632124954, 0.2722945202387173, - -0.3971349857502508, -0.1474011998331263, 0.18113754883619412, - 0.13433861105247683, 0.029171596025489813, -] -@testset "covariance prior" begin - sol = solve_kalman_cov( - A_kalman, B_kalman, C_kalman, u0_mean, u0_var_vech, - observables_kalman, - D_offdiag - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> solve_kalman_cov(args..., observables_kalman, D_offdiag).logpdf, - A_kalman, - B_kalman, C_kalman, - u0_mean, u0_var_vech; rrule_f = rrule_via_ad, check_inferred = false - ) - grad_values = gradient( - (args...) -> solve_kalman_cov(args...).logpdf, A_kalman, - B_kalman, - C_kalman, - u0_mean, - u0_var_vech, observables_kalman, - D_offdiag - ) - - @test grad_values[1] ≈ - finite_difference_gradient( - A -> solve_kalman_cov( - A, B_kalman, C_kalman, u0_mean, - u0_var_vech, - observables_kalman, - D_offdiag - ).logpdf, - A_kalman - ) rtol = 1.0e-7 - - # try this with non-zero mean - @test grad_values[4] ≈ - finite_difference_gradient( - u0_mean_vec -> solve_kalman_cov( - A_kalman, B_kalman, C_kalman, - u0_mean_vec, - u0_var_vech, - observables_kalman, - D_offdiag - ).logpdf, - u0_mean - ) rtol = 1.0e-6 - - @test grad_values[5] ≈ - finite_difference_gradient( - u0_var -> solve_kalman_cov( - A_kalman, B_kalman, C_kalman, - u0_mean, - u0_var, - observables_kalman, - D_offdiag - ).logpdf, - u0_var_vech - ) rtol = 1.4e-5 -end - -R = [ - 0.01 0.0 0.0 0.0; - 0.0 0.02 0.005 0.01; - 0.0 0.005 0.03 0.0; - 0.0 0.01 0.0 0.04 -] -@testset "covariance prior likelihood" begin - loglik = solve_manual_cov_lik( - A_kalman, B_kalman, C_kalman, u0_mean, u0_var_vech, - observables_kalman, - R, [0, T] - ) - sol = solve_kalman_cov( - A_kalman, B_kalman, C_kalman, u0_mean, u0_var_vech, - observables_kalman, - R - ) - @test sol.logpdf ≈ loglik -end diff --git a/test/linear_direct_iteration.jl b/test/linear_direct_iteration.jl new file mode 100644 index 0000000..e1b5e8b --- /dev/null +++ b/test/linear_direct_iteration.jl @@ -0,0 +1,462 @@ +using DifferenceEquations, Distributions, LinearAlgebra, Test, Random +using DelimitedFiles +using DiffEqBase +using DifferenceEquations: init, solve! + +# --- RBC Model Data --- + +A_rbc = [ + 0.9568351489231076 6.209371005755285; + 3.0153731819288737e-18 0.20000000000000007 +] +B_rbc = reshape([0.0; -0.01], 2, 1) # make sure B is a matrix +C_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] +D_rbc = abs2.([0.1, 0.1]) +u0_rbc = zeros(2) + +observables_rbc_matrix = readdlm( + joinpath( + pkgdir(DifferenceEquations), + "test/data/RBC_observables.csv" + ), + ',' +)' |> collect +observables_rbc = [observables_rbc_matrix[:, t] for t in 1:size(observables_rbc_matrix, 2)] + +noise_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), + ',' +)' |> collect + +T_rbc = 5 +observables_rbc_5 = [observables_rbc_matrix[:, t] for t in 1:T_rbc] +noise_rbc_5 = [noise_rbc_matrix[:, t] for t in 1:T_rbc] + +# --- Joint Likelihood Helper --- + +function joint_likelihood_1(A, B, C, u0, noise, observables, D; kwargs...) + problem = LinearStateSpaceProblem( + A, B, u0, (0, length(observables)); C, + observables_noise = Diagonal(D), + noise, observables, kwargs... + ) + return solve(problem).logpdf +end + +# --- Simulation Tests --- + +@testset "simulation with noise, observations, and observation noise" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); + C = C_rbc, + observables_noise = Diagonal(D_rbc), observables = observables_rbc, + syms = [:a, :b] + ) + + sol = solve(prob) + @inferred solve(prob) +end + +@testset "simulation with noise, no observations, no observation noise" begin + T = 20 + prob = LinearStateSpaceProblem(A_rbc, B_rbc, u0_rbc, (0, T); C = C_rbc, syms = [:a, :b]) + + sol = solve(prob) + @inferred solve(prob) +end + +@testset "simulation with noise and C, no observation noise" begin + Random.seed!(1234) + sol = solve(LinearStateSpaceProblem(A_rbc, B_rbc, u0_rbc, (0, 5); C = C_rbc)) + @test sol.u ≈ + [ + [0.0, 0.0], [0.0, 0.003597289068234817], + [0.02233690243961772, -0.010152627110638895], + [-0.04166869504075366, 0.0021653707472607075], + [-0.026424481689999797, -0.006756025225207251], + [-0.06723454002062011, -0.00555367682297924], + ] + @test sol.z ≈ + [ + [0.0, 0.0], [0.0024270440446074832, 0.0], + [-0.004710049663169753, 0.02233690243961772], + [-0.002530764810543453, -0.04166869504075366], + [-0.007089573167553201, -0.026424481689999797], + [-0.010187822270025022, -0.06723454002062011], + ] + @test sol.W ≈ + [[-0.3597289068234817], [1.0872084924285859], [-0.4195896169388487], [0.7189099374659392], [0.4202471777937789]] + @test sol.logpdf == 0.0 +end + +@testset "simulation with noise, C, and observation noise" begin + Random.seed!(1234) + sol = solve( + LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, 5); C = C_rbc, + observables_noise = Diagonal(D_rbc) + ) + ) + @test sol.u ≈ + [ + [0.0, 0.0], [0.0, 0.003597289068234817], + [0.02233690243961772, -0.010152627110638895], + [-0.04166869504075366, 0.0021653707472607075], + [-0.026424481689999797, -0.006756025225207251], + [-0.06723454002062011, -0.00555367682297924], + ] + @test sol.z ≈ + [ + [-0.06856709022761191, 0.20547630560640365], + [0.034916316989299055, -0.030490125519643224], + [0.0414594477647271, -0.06215886919798015], + [0.08614040809827415, -0.040311314885592704], + [0.0034755874208198837, -0.08053882074804589], + [-0.07921183287013331, -0.16087605412196193], + ] + @test sol.W ≈ + [[-0.3597289068234817], [1.0872084924285859], [-0.4195896169388487], [0.7189099374659392], [0.4202471777937789]] + @test sol.logpdf == 0.0 +end + +@testset "no noise (B=zeros) vs observation noise" begin + T = 20 + B_no_noise = zeros(2, 2) + u0 = [1.0, 0.5] + prob_no_noise = LinearStateSpaceProblem( + A_rbc, B_no_noise, u0, (0, T); C = C_rbc, + syms = [:a, :b] + ) + + sol_no_noise = solve(prob_no_noise) + + prob_obs_noise = LinearStateSpaceProblem( + A_rbc, B_no_noise, u0, (0, T); C = C_rbc, + syms = [:a, :b], observables_noise = Diagonal(D_rbc) + ) + sol_obs_noise = solve(prob_obs_noise) + @inferred solve(prob_obs_noise) + + sol_tiny_obs_noise = solve( + LinearStateSpaceProblem( + A_rbc, B_no_noise, u0, (0, T); + C = C_rbc, + syms = [:a, :b], + observables_noise = Diagonal([1.0e-16, 1.0e-16]) + ) + ) + @test maximum(maximum.(sol_tiny_obs_noise.z - sol_no_noise.z)) < 1.0e-7 + @test maximum(maximum.(sol_tiny_obs_noise.z - sol_no_noise.z)) > 0.0 +end + +@testset "B=nothing matches B=zeros" begin + T = 5 + B_no_noise = zeros(2, 2) + u0 = [1.0, 0.5] + sol_no_noise = solve( + LinearStateSpaceProblem( + A_rbc, B_no_noise, u0, (0, T); C = C_rbc, + syms = [:a, :b] + ) + ) + + prob = LinearStateSpaceProblem( + A_rbc, nothing, u0, (0, T); C = C_rbc, + syms = [:a, :b] + ) + + sol_nothing_noise = solve(prob) + @inferred solve(prob) + + @test sol_no_noise.z ≈ sol_nothing_noise.z + @test sol_no_noise.u ≈ sol_nothing_noise.u + @test sol_nothing_noise.W === nothing +end + +@testset "C=nothing, no observation process" begin + Random.seed!(1234) + T = 5 + u0 = [1.0, 0.5] + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0, (0, T); C = nothing, + syms = [:a, :b] + ) + sol = solve(prob) + @inferred solve(prob) + + @test sol.z === nothing + @test sol.u ≈ [ + [1.0, 0.5], [4.06152065180075, 0.10359728906823484], + [4.5294797207351944, 0.009847372889361128], + [4.395111394835915, 0.006165370747260727], + [4.243680140369242, -0.005956025225207233], + [4.023519148749289, -0.005393676822979223], + ] + @test sol.W ≈ + [[-0.3597289068234817], [1.0872084924285859], [-0.4195896169388487], [0.7189099374659392], [0.4202471777937789]] + @test sol.logpdf == 0.0 +end + +# --- Joint Likelihood Tests --- + +@testset "joint likelihood inference" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc_5)); + C = C_rbc, + observables_noise = Diagonal(D_rbc), noise = noise_rbc_5, + observables = observables_rbc_5 + ) + @inferred LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc_5)); + C = C_rbc, observables_noise = Diagonal(D_rbc), + noise = noise_rbc_5, + observables = observables_rbc_5 + ) + + sol = solve(prob) + @inferred solve(prob) + + DiffEqBase.get_concrete_problem(prob, false) + @inferred DiffEqBase.get_concrete_problem(prob, false) + + joint_likelihood_1(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc_5, observables_rbc_5, D_rbc) + @inferred joint_likelihood_1( + A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc_5, observables_rbc_5, + D_rbc + ) +end + +@testset "linear RBC joint likelihood value" begin + @test joint_likelihood_1( + A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc_5, observables_rbc_5, + D_rbc + ) ≈ + -690.9407412360038 + @inferred joint_likelihood_1( + A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc_5, observables_rbc_5, + D_rbc + ) +end + +# --- FVGQ Data --- + +A_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A.csv"), ',') +B_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_B.csv"), ',') +C_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C.csv"), ',') +D_FVGQ = ones(6) * 1.0e-3 + +observables_FVGQ_matrix = readdlm( + joinpath( + pkgdir(DifferenceEquations), + "test/data/FVGQ20_observables.csv" + ), ',' +)' |> collect +observables_FVGQ = [observables_FVGQ_matrix[:, t] for t in 1:size(observables_FVGQ_matrix, 2)] + +noise_FVGQ_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_noise.csv"), + ',' +)' |> collect +noise_FVGQ = [noise_FVGQ_matrix[:, t] for t in 1:size(noise_FVGQ_matrix, 2)] +u0_FVGQ = zeros(size(A_FVGQ, 1)) + +@testset "linear FVGQ joint likelihood" begin + @test joint_likelihood_1( + A_FVGQ, B_FVGQ, C_FVGQ, u0_FVGQ, noise_FVGQ, observables_FVGQ, + D_FVGQ + ) ≈ -1.4613614369686982e6 + @inferred joint_likelihood_1( + A_FVGQ, B_FVGQ, C_FVGQ, u0_FVGQ, noise_FVGQ, + observables_FVGQ, + D_FVGQ + ) +end + +# --- Primal Edge-Case Checks --- + +@testset "z_sum primal" begin + function z_sum(A, B, C, u0, noise, observables, D; kwargs...) + problem = LinearStateSpaceProblem( + A, B, u0, (0, length(observables)); C, + observables_noise = Diagonal(D), + noise, observables, kwargs... + ) + sol = solve(problem) + return sol.z[5][1] + sol.z[3][2] + end + @test z_sum(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc_5, observables_rbc_5, D_rbc) ≈ + -0.09008162336682057 +end + +@testset "u_sum primal" begin + function u_sum(A, B, C, u0, noise, observables, D; kwargs...) + problem = LinearStateSpaceProblem( + A, B, u0, (0, length(observables)); C, + observables_noise = Diagonal(D), + noise, observables, kwargs... + ) + sol = solve(problem) + return sol.u[3][1] + sol.u[3][2] + end + @test u_sum(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc_5, observables_rbc_5, D_rbc) ≈ + -0.08780558376240931 +end + +@testset "no_observables_sum primal" begin + function no_observables_sum(A, B, C, u0, noise; kwargs...) + problem = LinearStateSpaceProblem( + A, B, u0, (0, length(noise)); C, noise, + kwargs... + ) + sol = solve(problem) + return sol.W[2][1] + sol.W[4][1] + sol.z[2][2] + end + @test no_observables_sum(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc_5) ≈ + -0.08892781958364693 +end + +@testset "no_noise primal (B=nothing, C present)" begin + function no_noise(A, C, u0; kwargs...) + problem = LinearStateSpaceProblem(A, nothing, u0, (0, 5); C, kwargs...) + sol = solve(problem) + return sol.z[2][2] + end + u_nonzero = [1.1, 0.2] + @test no_noise(A_rbc, C_rbc, u_nonzero) ≈ 2.2943928649664755 +end + +@testset "no_observation_equation primal (B=nothing, C=nothing)" begin + function no_observation_equation(A, u0; kwargs...) + problem = LinearStateSpaceProblem(A, nothing, u0, (0, 5); kwargs...) + sol = solve(problem) + return sol.u[2][2] + sol.u[4][1] + end + u_nonzero = [1.1, 0.2] + @test no_observation_equation(A_rbc, u_nonzero) ≈ 2.4279222804056597 +end + +@testset "no_observation_equation_noise primal (B present, C=nothing)" begin + function no_observation_equation_noise(A, B, u0; kwargs...) + Random.seed!(1234) + problem = LinearStateSpaceProblem(A, B, u0, (0, 5); kwargs...) + sol = solve(problem) + return sol.u[2][2] + sol.u[4][1] + end + u_nonzero = [1.1, 0.2] + @test no_observation_equation_noise(A_rbc, B_rbc, u_nonzero) ≈ 2.3898508744331406 +end + +@testset "last_state with impulse noise" begin + function last_state_pass_noise(A, B, C, u0, noise) + problem = LinearStateSpaceProblem( + A, B, u0, (0, length(noise)); C, noise, + observables_noise = nothing, observables = nothing + ) + sol = solve(problem) + return sol.u[end][2] + end + T_imp = 20 + impulse_noise = [[i == 1 ? 1.0 : 0.0] for i in 1:T_imp] + u_nonzero = [0.1, 0.2] + val = last_state_pass_noise(A_rbc, B_rbc, C_rbc, u_nonzero, impulse_noise) + @test isfinite(val) +end + +@testset "last_observable with impulse noise" begin + function last_observable_pass_noise(A, B, C, u0, noise) + problem = LinearStateSpaceProblem( + A, B, u0, (0, length(noise)); C, noise, + observables_noise = nothing, observables = nothing + ) + sol = solve(problem) + return sol.z[end][2] + end + T_imp = 20 + impulse_noise = [[i == 1 ? 1.0 : 0.0] for i in 1:T_imp] + u_nonzero = [0.1, 0.2] + val = last_observable_pass_noise(A_rbc, B_rbc, C_rbc, u_nonzero, impulse_noise) + @test isfinite(val) +end + +# --- Workspace (init/solve!) tests --- + +@testset "solve!() matches solve() — simulation with noise, C, and obs_noise" begin + Random.seed!(1234) + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, 5); C = C_rbc, + observables_noise = Diagonal(D_rbc) + ) + Random.seed!(1234) + sol_direct = solve(prob) + Random.seed!(1234) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.logpdf ≈ sol_direct.logpdf +end + +@testset "solve!() matches solve() — joint likelihood (noise + obs + obs_noise)" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc_5)); + C = C_rbc, observables_noise = Diagonal(D_rbc), + noise = noise_rbc_5, observables = observables_rbc_5 + ) + sol_direct = solve(prob) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.logpdf ≈ sol_direct.logpdf +end + +@testset "solve!() matches solve() — no observables (noise + C, no obs/obs_noise)" begin + Random.seed!(1234) + prob = LinearStateSpaceProblem(A_rbc, B_rbc, u0_rbc, (0, 5); C = C_rbc) + Random.seed!(1234) + sol_direct = solve(prob) + Random.seed!(1234) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.logpdf ≈ sol_direct.logpdf +end + +@testset "solve!() matches solve() — no noise (B=nothing, C present)" begin + u_nonzero = [1.1, 0.2] + prob = LinearStateSpaceProblem(A_rbc, nothing, u_nonzero, (0, 5); C = C_rbc) + sol_direct = solve(prob) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.W === nothing + @test sol_ws.logpdf ≈ sol_direct.logpdf +end + +@testset "solve!() matches solve() — no observation equation (B=nothing, C=nothing)" begin + u_nonzero = [1.1, 0.2] + prob = LinearStateSpaceProblem(A_rbc, nothing, u_nonzero, (0, 5)) + sol_direct = solve(prob) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z === nothing + @test sol_ws.W === nothing + @test sol_ws.logpdf ≈ sol_direct.logpdf +end + +@testset "solve!() repeated — idempotent results" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc_5)); + C = C_rbc, observables_noise = Diagonal(D_rbc), + noise = noise_rbc_5, observables = observables_rbc_5 + ) + ws = init(prob, DirectIteration()) + sol1 = solve!(ws) + sol2 = solve!(ws) + @test sol1.u ≈ sol2.u + @test sol1.z ≈ sol2.z + @test sol1.logpdf ≈ sol2.logpdf +end diff --git a/test/linear_direct_iteration_enzyme.jl b/test/linear_direct_iteration_enzyme.jl new file mode 100644 index 0000000..8c8a06b --- /dev/null +++ b/test/linear_direct_iteration_enzyme.jl @@ -0,0 +1,434 @@ +# Enzyme AD tests for DirectIteration +# prob passed as Duplicated — observables get zero shadow automatically. +# GC disabled to avoid Enzyme reverse-mode GC corruption (#2355). + +GC.gc() +GC.enable(false) + +using LinearAlgebra, Test, Enzyme, EnzymeTestUtils, StaticArrays, Random +using DifferenceEquations +using DifferenceEquations: init, solve!, StateSpaceWorkspace +using FiniteDifferences: central_fdm + +include("enzyme_test_utils.jl") # vech helpers only + +# max_range needed: FD perturbation of observables_noise inside prob can push +# the matrix non-positive-definite, causing DomainError in logdet_chol. +const _fdm_di = central_fdm(5, 1; max_range = 1.0e-3) + +# --- Test setup --- + +const N_di = 3; const M_di = 2; const K_di = 2; const L_di = 2; const T_di = 5 + +Random.seed!(42) +A_raw_di = randn(N_di, N_di) +const A_di = 0.5 * A_raw_di / maximum(abs.(eigvals(A_raw_di))) +const B_di = 0.1 * randn(N_di, K_di) +const C_di = randn(M_di, N_di) +const H_di = 0.1 * randn(M_di, L_di) +const u0_di = zeros(N_di) + +Random.seed!(123) +const noise_di = [randn(K_di) for _ in 1:T_di] +const obs_noise_di = [randn(L_di) for _ in 1:T_di] +const sim_sol_di = solve( + LinearStateSpaceProblem( + A_di, B_di, u0_di, (0, T_di); C = C_di, noise = noise_di + ) +) +const y_di = [sim_sol_di.z[t + 1] + H_di * obs_noise_di[t] for t in 1:T_di] + +# --- Helpers --- + +function make_di_prob(A, B, C, u0, noise, y, H) + R = H * H' + return LinearStateSpaceProblem( + A, B, u0, (0, length(y)); + C, observables_noise = R, observables = y, noise + ) +end + +function make_di_sol_cache(A, B, C, u0, noise, y, H) + ws = init(make_di_prob(A, B, C, u0, noise, y, H), DirectIteration()) + return ws.output, ws.cache +end + +# --- Wrappers — prob as single Duplicated arg --- + +function di_solve_prob!(prob, sol, cache) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return (sol.u, sol.z) +end + +function di_loglik_prob(prob, sol, cache)::Float64 + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + return solve!(ws).logpdf +end + +# Scalar wrappers for reverse mode (prob pattern) +function di_z_sum_prob(prob, sol, cache)::Float64 + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return sol.z[2][1] + sol.z[3][2] +end + +function di_u_sum_prob(prob, sol, cache)::Float64 + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return sol.u[2][1] + sol.u[3][2] +end + +# Vech: separate args (y stays Duplicated — remake doesn't work with Enzyme shadows) +function di_solve_vech!(A, B, C, u0, noise, y, r_v, n_obs, sol, cache) + prob = LinearStateSpaceProblem( + A, B, u0, (0, length(y)); + C, observables_noise = make_posdef_from_vech(r_v, n_obs), observables = y, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return (sol.u, sol.z) +end + +function di_loglik_vech(A, B, C, u0, noise, y, r_v, n_obs, sol, cache)::Float64 + prob = LinearStateSpaceProblem( + A, B, u0, (0, length(y)); + C, observables_noise = make_posdef_from_vech(r_v, n_obs), observables = y, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + return solve!(ws).logpdf +end + +# --- Sanity test --- + +@testset "DirectIteration loglik via solve!() - sanity" begin + prob = make_di_prob(A_di, B_di, C_di, u0_di, noise_di, y_di, H_di) + ws = init(prob, DirectIteration()) + loglik = di_loglik_prob(prob, ws.output, ws.cache) + @test isfinite(loglik) + + loglik2 = di_loglik_prob(prob, ws.output, ws.cache) + @test loglik ≈ loglik2 rtol = 1.0e-12 +end + +# --- Forward — prob as Duplicated --- + +@testset "EnzymeTestUtils - DirectIteration forward (prob Duplicated)" begin + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0]; H_s = [0.1 0.0; 0.0 0.1] + u0_s = zeros(2); noise_s = [[0.1, -0.1], [0.2, 0.05]] + y_s = [[0.5, 0.3], [0.2, 0.1]] + prob = make_di_prob(A_s, B_s, C_s, u0_s, noise_s, y_s, H_s) + ws = init(prob, DirectIteration()) + + test_forward( + di_solve_prob!, Const, + (prob, Duplicated), + (ws.output, Duplicated), (ws.cache, Duplicated); + fdm = _fdm_di, + ) +end + +# --- Reverse — prob as Duplicated (logpdf) --- + +@testset "EnzymeTestUtils - DirectIteration reverse (prob Duplicated, logpdf)" begin + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0]; H_s = [0.1 0.0; 0.0 0.1] + u0_s = zeros(2); noise_s = [[0.1, -0.1], [0.2, 0.05]] + y_s = [[0.5, 0.3], [0.2, 0.1]] + prob = make_di_prob(A_s, B_s, C_s, u0_s, noise_s, y_s, H_s) + ws = init(prob, DirectIteration()) + + test_reverse( + di_loglik_prob, Active, + (prob, Duplicated), + (deepcopy(ws.output), Duplicated), (deepcopy(ws.cache), Duplicated); + fdm = _fdm_di, + ) +end + +# --- Forward — rectangular H (prob as Duplicated) --- + +@testset "EnzymeTestUtils - DirectIteration rectangular H forward (prob Duplicated)" begin + A_r = [0.5 0.1 0.0; -0.1 0.5 0.05; 0.02 -0.05 0.5] + B_r = 0.1 * [1.0 0.5; 0.3 -0.2; 0.7 0.1] + C_r = [1.0 0.0 0.5; 0.0 1.0 0.0] + H_r = 0.1 * [1.0 0.5 0.3; -0.2 0.7 0.1] + u0_r = zeros(3); noise_r = [[0.1, -0.1], [0.2, 0.05]] + y_r = [[0.5, 0.3], [0.2, -0.1]] + prob = make_di_prob(A_r, B_r, C_r, u0_r, noise_r, y_r, H_r) + ws = init(prob, DirectIteration()) + + test_forward( + di_solve_prob!, Const, + (prob, Duplicated), + (ws.output, Duplicated), (ws.cache, Duplicated); + fdm = _fdm_di, + ) +end + +# --- Non-diagonal R via vech parameterization --- + +@testset "EnzymeTestUtils - DirectIteration non-diagonal R forward (vech)" begin + _fdm_vech = central_fdm(5, 1) + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0] + u0_s = zeros(2); noise_s = [[0.1, -0.1], [0.2, 0.05]] + y_s = [[0.5, 0.3], [0.2, 0.1]] + R_offdiag = [0.02 0.005; 0.005 0.01] + r_v = make_vech_for(R_offdiag) + sol, cache = make_di_sol_cache( + A_s, B_s, C_s, u0_s, noise_s, y_s, + [sqrt(0.02) 0.0; 0.0 sqrt(0.01)] + ) + + test_forward( + di_solve_vech!, Const, + (copy(A_s), Duplicated), (copy(B_s), Duplicated), + (copy(C_s), Duplicated), (copy(u0_s), Duplicated), + ([copy(n) for n in noise_s], Duplicated), + ([copy(y) for y in y_s], Duplicated), + (copy(r_v), Duplicated), (2, Const), + (sol, Duplicated), (cache, Duplicated); + fdm = _fdm_vech, + ) +end + +@testset "EnzymeTestUtils - DirectIteration non-diagonal R reverse (vech)" begin + _fdm_vech = central_fdm(5, 1) + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0] + u0_s = zeros(2); noise_s = [[0.1, -0.1], [0.2, 0.05]] + y_s = [[0.5, 0.3], [0.2, 0.1]] + R_offdiag = [0.02 0.005; 0.005 0.01] + r_v = make_vech_for(R_offdiag) + sol, cache = make_di_sol_cache( + A_s, B_s, C_s, u0_s, noise_s, y_s, + [sqrt(0.02) 0.0; 0.0 sqrt(0.01)] + ) + + test_reverse( + di_loglik_vech, Active, + (copy(A_s), Duplicated), (copy(B_s), Duplicated), + (copy(C_s), Duplicated), (copy(u0_s), Duplicated), + ([copy(n) for n in noise_s], Duplicated), + ([copy(y) for y in y_s], Duplicated), + (copy(r_v), Duplicated), (2, Const), + (deepcopy(sol), Duplicated), (deepcopy(cache), Duplicated); + fdm = _fdm_vech, + ) +end + +# --- Regression test --- + +@testset "DirectIteration loglik - regression test" begin + A_reg = [0.9 0.1; -0.1 0.9]; B_reg = [0.1 0.0; 0.0 0.1] + C_reg = [1.0 0.0; 0.0 1.0]; H_reg = [0.1 0.0; 0.0 0.1] + u0_reg = [0.0, 0.0]; noise_reg = [[0.1, -0.1], [0.2, 0.05], [0.0, 0.1]] + y_reg = [[0.5, -0.3], [0.8, -0.1], [0.6, 0.2]] + prob = make_di_prob(A_reg, B_reg, C_reg, u0_reg, noise_reg, y_reg, H_reg) + ws = init(prob, DirectIteration()) + + loglik = di_loglik_prob(prob, ws.output, ws.cache) + @test isfinite(loglik) + + loglik2 = di_loglik_prob(prob, ws.output, ws.cache) + @test loglik ≈ loglik2 rtol = 1.0e-12 +end + +# --- Edge-case helpers --- + +function _alloc_u(u0, T) + return [similar(u0) for _ in 1:T] +end +function _alloc_uz(u0, C, T) + M = size(C, 1) + return [similar(u0) for _ in 1:T], [zeros(eltype(u0), M) for _ in 1:T] +end +function _alloc_noise_cache(B, T) + return [Vector{eltype(B)}(undef, size(B, 2)) for _ in 1:(T - 1)] +end + +# No observables: B+C present, no obs/obs_noise +function di_no_obs_solve!(A, B, C, u0, noise, u_out, z_out, noise_cache) + prob = LinearStateSpaceProblem(A, B, u0, (0, length(noise)); C, noise) + sol = (; u = u_out, z = z_out) + cache = (; + noise = noise_cache, R = nothing, R_chol = nothing, + innovation = nothing, innovation_solved = nothing, + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return (u_out, z_out) +end + +# No noise: B=nothing, C present +function di_no_noise_solve!(A, C, u0, u_out, z_out, T) + prob = LinearStateSpaceProblem(A, nothing, u0, (0, T); C) + sol = (; u = u_out, z = z_out) + cache = (; + noise = nothing, R = nothing, R_chol = nothing, + innovation = nothing, innovation_solved = nothing, + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return (u_out, z_out) +end + +# No observation equation: B=nothing, C=nothing +function di_no_obs_eq_solve!(A, u0, u_out, T) + prob = LinearStateSpaceProblem(A, nothing, u0, (0, T)) + sol = (; u = u_out, z = nothing) + cache = (; + noise = nothing, R = nothing, R_chol = nothing, + innovation = nothing, innovation_solved = nothing, + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return u_out +end + +# Noise but no observation equation: B present, C=nothing +function di_noise_no_obs_eq_solve!(A, B, u0, noise, u_out, noise_cache) + prob = LinearStateSpaceProblem(A, B, u0, (0, length(noise)); noise) + sol = (; u = u_out, z = nothing) + cache = (; + noise = noise_cache, R = nothing, R_chol = nothing, + innovation = nothing, innovation_solved = nothing, + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return u_out +end + +# Impulse response: B+C present, long trajectory +function di_impulse_solve!(A, B, C, u0, noise, u_out, z_out, noise_cache) + prob = LinearStateSpaceProblem(A, B, u0, (0, length(noise)); C, noise) + sol = (; u = u_out, z = z_out) + cache = (; + noise = noise_cache, R = nothing, R_chol = nothing, + innovation = nothing, innovation_solved = nothing, + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return (u_out, z_out) +end + +# --- Edge-case forward tests --- + +@testset "EnzymeTestUtils - DirectIteration no observables forward" begin + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0] + u0_s = zeros(2); noise_s = [[0.1, -0.1], [0.2, 0.05]] + T_a = length(noise_s) + 1 + u_out, z_out = _alloc_uz(u0_s, C_s, T_a) + nc = _alloc_noise_cache(B_s, T_a) + + test_forward( + di_no_obs_solve!, Const, + (copy(A_s), Duplicated), (copy(B_s), Duplicated), + (copy(C_s), Duplicated), (copy(u0_s), Duplicated), + ([copy(n) for n in noise_s], Duplicated), + (u_out, Duplicated), (z_out, Duplicated), (nc, Duplicated); + fdm = _fdm_di, + ) +end + +@testset "EnzymeTestUtils - DirectIteration no noise forward" begin + A_s = [0.8 0.1; -0.1 0.7]; C_s = [1.0 0.0; 0.0 1.0] + u0_s = [0.5, -0.3]; T_val = 3 + u_out, z_out = _alloc_uz(u0_s, C_s, T_val + 1) + + test_forward( + di_no_noise_solve!, Const, + (copy(A_s), Duplicated), (copy(C_s), Duplicated), + (copy(u0_s), Duplicated), + (u_out, Duplicated), (z_out, Duplicated), + (T_val, Const); + fdm = _fdm_di, + ) +end + +@testset "EnzymeTestUtils - DirectIteration no observation equation forward" begin + A_s = [0.8 0.1; -0.1 0.7] + u0_s = [0.5, -0.3]; T_val = 3 + u_out = _alloc_u(u0_s, T_val + 1) + + test_forward( + di_no_obs_eq_solve!, Const, + (copy(A_s), Duplicated), (copy(u0_s), Duplicated), + (u_out, Duplicated), (T_val, Const); + fdm = _fdm_di, + ) +end + +@testset "EnzymeTestUtils - DirectIteration noise no observation equation forward" begin + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + u0_s = zeros(2); noise_s = [[0.1, -0.1], [0.2, 0.05]] + T_d = length(noise_s) + 1 + u_out = _alloc_u(u0_s, T_d) + nc = _alloc_noise_cache(B_s, T_d) + + test_forward( + di_noise_no_obs_eq_solve!, Const, + (copy(A_s), Duplicated), (copy(B_s), Duplicated), + (copy(u0_s), Duplicated), + ([copy(n) for n in noise_s], Duplicated), + (u_out, Duplicated), (nc, Duplicated); + fdm = _fdm_di, + ) +end + +@testset "EnzymeTestUtils - DirectIteration impulse response forward" begin + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0]; u0_s = zeros(2) + noise_s = [[1.0, 0.0]]; append!(noise_s, [[0.0, 0.0] for _ in 2:10]) + T_e = length(noise_s) + 1 + u_out, z_out = _alloc_uz(u0_s, C_s, T_e) + nc = _alloc_noise_cache(B_s, T_e) + + test_forward( + di_impulse_solve!, Const, + (copy(A_s), Duplicated), (copy(B_s), Duplicated), + (copy(C_s), Duplicated), (copy(u0_s), Duplicated), + ([copy(n) for n in noise_s], Duplicated), + (u_out, Duplicated), (z_out, Duplicated), (nc, Duplicated); + fdm = _fdm_di, + ) +end + +# --- Reverse: z_sum and u_sum (prob as Duplicated) --- + +@testset "EnzymeTestUtils - DirectIteration z_sum reverse (prob Duplicated)" begin + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0]; H_s = [0.1 0.0; 0.0 0.1] + u0_s = zeros(2); noise_s = [[0.1, -0.1], [0.2, 0.05]] + y_s = [[0.5, 0.3], [0.2, 0.1]] + prob = make_di_prob(A_s, B_s, C_s, u0_s, noise_s, y_s, H_s) + ws = init(prob, DirectIteration()) + + test_reverse( + di_z_sum_prob, Active, + (prob, Duplicated), + (deepcopy(ws.output), Duplicated), (deepcopy(ws.cache), Duplicated); + fdm = _fdm_di, + ) +end + +@testset "EnzymeTestUtils - DirectIteration u_sum reverse (prob Duplicated)" begin + A_s = [0.8 0.1; -0.1 0.7]; B_s = [0.1 0.0; 0.0 0.1] + C_s = [1.0 0.0; 0.0 1.0]; H_s = [0.1 0.0; 0.0 0.1] + u0_s = zeros(2); noise_s = [[0.1, -0.1], [0.2, 0.05]] + y_s = [[0.5, 0.3], [0.2, 0.1]] + prob = make_di_prob(A_s, B_s, C_s, u0_s, noise_s, y_s, H_s) + ws = init(prob, DirectIteration()) + + test_reverse( + di_u_sum_prob, Active, + (prob, Duplicated), + (deepcopy(ws.output), Duplicated), (deepcopy(ws.cache), Duplicated); + fdm = _fdm_di, + ) +end + +GC.enable(true) diff --git a/test/linear_direct_iteration_forwarddiff.jl b/test/linear_direct_iteration_forwarddiff.jl new file mode 100644 index 0000000..048db61 --- /dev/null +++ b/test/linear_direct_iteration_forwarddiff.jl @@ -0,0 +1,184 @@ +# ForwardDiff AD tests for DirectIteration (loglik path) +# Tests gradient correctness against central finite differences. + +using LinearAlgebra, Test, ForwardDiff, StaticArrays, Random +using DifferenceEquations + +include("forwarddiff_test_utils.jl") + +# ============================================================================= +# Problem setup +# ============================================================================= + +const N_di_fd = 2 +const M_di_fd = 2 +const K_di_fd = 2 +const T_di_fd = 5 + +const A_di_fd = [0.8 0.1; -0.1 0.7] +const B_di_fd = [0.1 0.0; 0.0 0.1] +const C_di_fd = [1.0 0.0; 0.0 1.0] +const H_di_fd = [0.1 0.0; 0.0 0.1] +const u0_di_fd = zeros(N_di_fd) + +Random.seed!(42) +const noise_di_fd = [randn(K_di_fd) for _ in 1:T_di_fd] +const sim_sol_di_fd = solve( + LinearStateSpaceProblem( + A_di_fd, B_di_fd, u0_di_fd, (0, T_di_fd); C = C_di_fd, noise = noise_di_fd + ) +) +const y_di_fd = [sim_sol_di_fd.z[t + 1] + H_di_fd * randn(M_di_fd) for t in 1:T_di_fd] + +# ============================================================================= +# Mutable arrays — ForwardDiff gradient tests +# ============================================================================= + +function di_loglik_fd(A, B, C, u0, noise, y, H) + T_el = promote_type(eltype(A), eltype(B), eltype(C), eltype(u0), eltype(H)) + R = promote_array(T_el, H) * promote_array(T_el, H)' + prob = LinearStateSpaceProblem( + promote_array(T_el, A), promote_array(T_el, B), + promote_array(T_el, u0), (0, length(y)); + C = promote_array(T_el, C), + observables_noise = R, + observables = y, noise = noise + ) + sol = solve(prob, DirectIteration()) + return sol.logpdf +end + +@testset "ForwardDiff - DirectIteration loglik (mutable)" begin + @testset "primal sanity" begin + loglik_val = di_loglik_fd( + A_di_fd, B_di_fd, C_di_fd, u0_di_fd, + noise_di_fd, y_di_fd, H_di_fd + ) + @test isfinite(loglik_val) + end + + @testset "gradient w.r.t. A" begin + f = a_vec -> di_loglik_fd( + reshape(a_vec, N_di_fd, N_di_fd), + B_di_fd, C_di_fd, u0_di_fd, noise_di_fd, y_di_fd, H_di_fd + ) + x0 = vec(copy(A_di_fd)) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end + + @testset "gradient w.r.t. u0" begin + f = u_vec -> di_loglik_fd( + A_di_fd, B_di_fd, C_di_fd, + u_vec, noise_di_fd, y_di_fd, H_di_fd + ) + x0 = [0.1, -0.1] + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end + + @testset "gradient w.r.t. H" begin + f = h_vec -> di_loglik_fd( + A_di_fd, B_di_fd, C_di_fd, u0_di_fd, + noise_di_fd, y_di_fd, reshape(h_vec, M_di_fd, M_di_fd) + ) + x0 = vec(copy(H_di_fd)) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end +end + +# ============================================================================= +# Non-diagonal R — ForwardDiff gradient tests +# ============================================================================= + +const H_di_fd_offdiag = [0.1 0.05; 0.02 0.08] + +function di_loglik_fd_offdiag(A, B, C, u0, noise, y, H) + T_el = promote_type(eltype(A), eltype(B), eltype(C), eltype(u0), eltype(H)) + H_d = promote_array(T_el, H) + R = H_d * H_d' + prob = LinearStateSpaceProblem( + promote_array(T_el, A), promote_array(T_el, B), + promote_array(T_el, u0), (0, length(y)); + C = promote_array(T_el, C), + observables_noise = R, + observables = y, noise = noise + ) + sol = solve(prob, DirectIteration()) + return sol.logpdf +end + +@testset "ForwardDiff - DirectIteration loglik non-diagonal R (mutable)" begin + @testset "primal sanity" begin + loglik_val = di_loglik_fd_offdiag( + A_di_fd, B_di_fd, C_di_fd, u0_di_fd, + noise_di_fd, y_di_fd, H_di_fd_offdiag + ) + @test isfinite(loglik_val) + end + + @testset "gradient w.r.t. H (off-diagonal)" begin + f = h_vec -> di_loglik_fd_offdiag( + A_di_fd, B_di_fd, C_di_fd, u0_di_fd, + noise_di_fd, y_di_fd, reshape(h_vec, M_di_fd, M_di_fd) + ) + x0 = vec(copy(H_di_fd_offdiag)) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end + + @testset "gradient w.r.t. A (with off-diagonal R)" begin + f = a_vec -> di_loglik_fd_offdiag( + reshape(a_vec, N_di_fd, N_di_fd), + B_di_fd, C_di_fd, u0_di_fd, noise_di_fd, y_di_fd, H_di_fd_offdiag + ) + x0 = vec(copy(A_di_fd)) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end +end + +# ============================================================================= +# StaticArrays — ForwardDiff gradient tests +# ============================================================================= + +const noise_di_fd_s = [SVector{K_di_fd}(n) for n in noise_di_fd] +const y_di_fd_s = [SVector{M_di_fd}(yi) for yi in y_di_fd] + +@testset "ForwardDiff - DirectIteration loglik (static)" begin + @testset "gradient w.r.t. A" begin + f = a_vec -> begin + T_el = eltype(a_vec) + A_d = SMatrix{N_di_fd, N_di_fd}(reshape(a_vec, N_di_fd, N_di_fd)) + B_d = SMatrix{N_di_fd, K_di_fd}(T_el.(B_di_fd)) + C_d = SMatrix{M_di_fd, N_di_fd}(T_el.(C_di_fd)) + H_d = SMatrix{M_di_fd, M_di_fd}(T_el.(H_di_fd)) + prob = LinearStateSpaceProblem( + A_d, B_d, + SVector{N_di_fd}(zeros(T_el, N_di_fd)), (0, length(y_di_fd_s)); + C = C_d, observables_noise = H_d * H_d', + observables = y_di_fd_s, noise = noise_di_fd_s + ) + sol = solve(prob, DirectIteration()) + return sol.logpdf + end + x0 = collect(vec(Matrix(A_di_fd))) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end + + @testset "gradient w.r.t. H" begin + f = h_vec -> begin + T_el = eltype(h_vec) + A_d = SMatrix{N_di_fd, N_di_fd}(T_el.(A_di_fd)) + B_d = SMatrix{N_di_fd, K_di_fd}(T_el.(B_di_fd)) + C_d = SMatrix{M_di_fd, N_di_fd}(T_el.(C_di_fd)) + H_d = SMatrix{M_di_fd, M_di_fd}(reshape(h_vec, M_di_fd, M_di_fd)) + prob = LinearStateSpaceProblem( + A_d, B_d, + SVector{N_di_fd}(zeros(T_el, N_di_fd)), (0, length(y_di_fd_s)); + C = C_d, observables_noise = H_d * H_d', + observables = y_di_fd_s, noise = noise_di_fd_s + ) + sol = solve(prob, DirectIteration()) + return sol.logpdf + end + x0 = collect(vec(Matrix(H_di_fd))) + @test ForwardDiff.gradient(f, x0) ≈ fdm_gradient(f, x0) rtol = 1.0e-4 + end +end diff --git a/test/linear_gradients.jl b/test/linear_gradients.jl deleted file mode 100644 index bfeb114..0000000 --- a/test/linear_gradients.jl +++ /dev/null @@ -1,243 +0,0 @@ -using ChainRulesTestUtils, DifferenceEquations, Distributions, LinearAlgebra, Test, Zygote, - Random, ChainRulesCore -using DelimitedFiles -using DiffEqBase -using FiniteDiff: finite_difference_gradient - -# Matrices from RBC -A_rbc = [ - 0.9568351489231076 6.209371005755285; - 3.0153731819288737e-18 0.20000000000000007 -] -B_rbc = reshape([0.0; -0.01], 2, 1) # make sure B is a matrix -C_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] -D_rbc = abs2.([0.1, 0.1]) -u0_rbc = zeros(2) -observables_rbc = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/RBC_observables.csv" - ), - ',' -)' |> collect -noise_rbc = readdlm( - joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), - ',' -)' |> - collect -# Data and Noise -T = 5 -observables_rbc = observables_rbc[:, 1:T] -noise_rbc = noise_rbc[:, 1:T] - -function z_sum(A, B, C, u0, noise, observables, D; kwargs...) - problem = LinearStateSpaceProblem( - A, B, u0, (0, size(observables, 2)); C, - observables_noise = D, - noise, observables, kwargs... - ) - sol = solve(problem) # since noise provided, uses DirectIteration - return sol.z[5][1] + sol.z[3][2] -end -@testset "mean_z test" begin - @test z_sum(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, D_rbc) ≈ - -0.09008162336682057 - gradient( - (args...) -> z_sum(args..., observables_rbc, D_rbc), A_rbc, B_rbc, C_rbc, - u0_rbc, noise_rbc - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> z_sum(args..., observables_rbc, D_rbc), A_rbc, B_rbc, - C_rbc, u0_rbc, noise_rbc; rrule_f = rrule_via_ad, check_inferred = false - ) -end -function u_sum(A, B, C, u0, noise, observables, D; kwargs...) - problem = LinearStateSpaceProblem( - A, B, u0, (0, size(observables, 2)); C, - observables_noise = D, - noise, observables, kwargs... - ) - sol = solve(problem) - u = sol.u # Zygote bug, must use separate name, also passes Nothing for Δsol so requires workarounds - return u[3][1] + u[3][2] - # BROKEN? ZYGOTE BUG? Seems to give the wrong Δsol type when calling the pullback - # return sol.u[3][1] + sol.u[3][2] #+ sol[3][1] + sol[3][2] + sol[2,1] -end -@testset "u test" begin - @test u_sum(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, D_rbc) ≈ - -0.08780558376240931 - gradient( - (args...) -> u_sum(args..., observables_rbc, D_rbc), A_rbc, B_rbc, C_rbc, - u0_rbc, noise_rbc - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> u_sum(args..., observables_rbc, D_rbc), A_rbc, B_rbc, - C_rbc, u0_rbc, noise_rbc; rrule_f = rrule_via_ad, check_inferred = false - ) -end -function W_sum(A, B, C, u0, noise, observables, D; kwargs...) - problem = LinearStateSpaceProblem( - A, B, u0, (0, size(observables, 2)); C, - observables_noise = D, - noise, observables, kwargs... - ) - sol = solve(problem) - return sol.W[1, 2] + sol.W[1, 4] + sol.z[2][2] -end -@testset "W test" begin - gradient( - (args...) -> W_sum(args..., observables_rbc, D_rbc), A_rbc, B_rbc, C_rbc, - u0_rbc, noise_rbc - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> W_sum(args..., observables_rbc, D_rbc), A_rbc, B_rbc, - C_rbc, u0_rbc, noise_rbc; rrule_f = rrule_via_ad, check_inferred = false - ) -end - -# Versions without observations -function no_observables_sum(A, B, C, u0, noise; kwargs...) - problem = LinearStateSpaceProblem( - A, B, u0, (0, size(noise_rbc, 2)); C, noise, - kwargs... - ) - sol = solve(problem) - return sol.W[1, 2] + sol.W[1, 4] + sol.z[2][2] -end -@testset "no observables gradient" begin - @test no_observables_sum(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc) ≈ - -0.08892781958364693 - gradient( - (args...) -> no_observables_sum(args...), A_rbc, B_rbc, C_rbc, - u0_rbc, noise_rbc - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> no_observables_sum(args...), A_rbc, B_rbc, - C_rbc, u0_rbc, noise_rbc; rrule_f = rrule_via_ad, check_inferred = false - ) -end -function no_noise(A, C, u0; kwargs...) - problem = LinearStateSpaceProblem(A, nothing, u0, (0, 5); C, kwargs...) - sol = solve(problem) - # u = sol.u # bugs with u - return sol.z[2][2] # + u[2][2] -end -@testset "no noise" begin - u_nonzero = [1.1, 0.2] - @test no_noise(A_rbc, C_rbc, u_nonzero) ≈ 2.2943928649664755 - gradient( - (args...) -> no_noise(args...), A_rbc, C_rbc, - u_nonzero - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> no_noise(args...), A_rbc, C_rbc, u_nonzero; - rrule_f = rrule_via_ad, - check_inferred = false - ) -end - -function no_observation_equation(A, u0; kwargs...) - problem = LinearStateSpaceProblem(A, nothing, u0, (0, 5); kwargs...) - sol = solve(problem) - u = sol.u # bugs with u - return u[2][2] + u[4][1] -end -@testset "no observation equation" begin - u_nonzero = [1.1, 0.2] - @test no_observation_equation(A_rbc, u_nonzero) ≈ 2.4279222804056597 - gradient( - (args...) -> no_observation_equation(args...), A_rbc, - u_nonzero - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> no_observation_equation(args...), A_rbc, u_nonzero; - rrule_f = rrule_via_ad, - check_inferred = false - ) -end - -# Hack to set seeds within equation for finite-diff reproducibility -# Makes it ignore the derivative -setseed(x) = Random.seed!(x) -function ChainRulesCore.rrule(::typeof(setseed), x) - Random.seed!(x) - pb(ȳ) = (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()) - return nothing, pb -end - -function no_observation_equation_noise(A, B, u0; kwargs...) - setseed(1234) # hack for reproducibility with finite diff - problem = LinearStateSpaceProblem(A, B, u0, (0, 5); kwargs...) - sol = solve(problem) - u = sol.u # bugs with u - return u[2][2] + u[4][1] -end -@testset "no observation equation" begin - u_nonzero = [1.1, 0.2] - @test no_observation_equation_noise(A_rbc, B_rbc, u_nonzero) ≈ 2.3898508744331406 - gradient( - (args...) -> no_observation_equation_noise(args...), A_rbc, B_rbc, - u_nonzero - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> no_observation_equation_noise(args...), A_rbc, B_rbc, u_nonzero; - rrule_f = rrule_via_ad, - check_inferred = false - ) -end - -function last_state_pass_noise(A, B, C, u0, noise) - problem = LinearStateSpaceProblem( - A, B, u0, (0, size(noise, 2)); C, noise, - observables_noise = nothing, observables = nothing - ) - sol = solve(problem) - return sol.u[end][2] -end - -@testset "last state with noise, no observable noise" begin - T = 20 - noise = Matrix([1.0; zeros(T - 1)]') # impulse - u_nonzero = [0.1, 0.2] - last_state_pass_noise(A_rbc, B_rbc, C_rbc, u_nonzero, noise) - gradient(last_state_pass_noise, A_rbc, B_rbc, C_rbc, u_nonzero, noise) - test_rrule( - Zygote.ZygoteRuleConfig(), - (u_nonzero) -> last_state_pass_noise(A_rbc, B_rbc, C_rbc, u_nonzero, noise), - u_nonzero; - rrule_f = rrule_via_ad, - check_inferred = false - ) -end -function last_observable_pass_noise(A, B, C, u0, noise) - problem = LinearStateSpaceProblem( - A, B, u0, (0, size(noise, 2)); C, noise, - observables_noise = nothing, observables = nothing - ) - sol = solve(problem) - return sol.z[end][2] -end -@testset "last observable with noise, no observable noise" begin - T = 20 - noise = Matrix([1.0; zeros(T - 1)]') # impulse - u_nonzero = [0.1, 0.2] - last_observable_pass_noise(A_rbc, B_rbc, C_rbc, u_nonzero, noise) - gradient(last_observable_pass_noise, A_rbc, B_rbc, C_rbc, u_nonzero, noise) - test_rrule( - Zygote.ZygoteRuleConfig(), - (u_nonzero) -> last_observable_pass_noise( - A_rbc, B_rbc, C_rbc, u_nonzero, - noise - ), - u_nonzero; - rrule_f = rrule_via_ad, - check_inferred = false - ) -end diff --git a/test/linear_likelihood.jl b/test/linear_likelihood.jl deleted file mode 100644 index a41f963..0000000 --- a/test/linear_likelihood.jl +++ /dev/null @@ -1,247 +0,0 @@ -using ChainRulesTestUtils, DifferenceEquations, Distributions, LinearAlgebra, Test, Zygote -using DelimitedFiles -using DiffEqBase -using FiniteDiff: finite_difference_gradient - -function joint_likelihood_1(A, B, C, u0, noise, observables, D; kwargs...) - problem = LinearStateSpaceProblem( - A, B, u0, (0, size(observables, 2)); C, - observables_noise = D, - noise, observables, kwargs... - ) - return solve(problem).logpdf -end - -# CRTU has problems with generating random MvNormal, so just testing diagonals -function kalman_likelihood(A, B, C, u0, observables, D; kwargs...) - problem = LinearStateSpaceProblem( - A, B, u0, (0, size(observables, 2)); C, - observables_noise = D, - u0_prior_mean = u0, - u0_prior_var = diagm(ones(length(u0))), - noise = nothing, observables, kwargs... - ) - return solve(problem).logpdf -end - -# Matrices from RBC -A_rbc = [ - 0.9568351489231076 6.209371005755285; - 3.0153731819288737e-18 0.20000000000000007 -] -B_rbc = reshape([0.0; -0.01], 2, 1) # make sure B is a matrix -C_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] -D_rbc = abs2.([0.1, 0.1]) -u0_rbc = zeros(2) - -observables_rbc = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/RBC_observables.csv" - ), - ',' -)' |> collect -noise_rbc = readdlm( - joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), - ',' -)' |> - collect -# Data and Noise -T = 5 -observables_rbc = observables_rbc[:, 1:T] -noise_rbc = noise_rbc[:, 1:T] - -@testset "basic inference" begin - prob = LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); - C = C_rbc, - observables_noise = D_rbc, noise = noise_rbc, - observables = observables_rbc - ) - @inferred LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); - C = C_rbc, observables_noise = D_rbc, - noise = noise_rbc, - observables = observables_rbc - ) - - sol = solve(prob) - @inferred solve(prob) - - DiffEqBase.get_concrete_problem(prob, false) - @inferred DiffEqBase.get_concrete_problem(prob, false) - - joint_likelihood_1(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, D_rbc) - @inferred joint_likelihood_1( - A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, - D_rbc - ) -end - -@testset "basic kalman inference" begin - prob = LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); - C = C_rbc, - observables_noise = D_rbc, observables = observables_rbc, - u0_prior_mean = u0_rbc, - u0_prior_var = diagm(ones(length(u0_rbc))) - ) - @inferred LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); - C = C_rbc, - observables_noise = D_rbc, - observables = observables_rbc, - u0_prior_mean = u0_rbc, - u0_prior_var = diagm(ones(length(u0_rbc))) - ) - - sol = solve(prob) - @inferred solve(prob) - - prob_concrete = DiffEqBase.get_concrete_problem(prob, false) - @inferred DiffEqBase.get_concrete_problem(prob, false) - - kalman_likelihood(A_rbc, B_rbc, C_rbc, u0_rbc, observables_rbc, D_rbc) - @inferred kalman_likelihood(A_rbc, B_rbc, C_rbc, u0_rbc, observables_rbc, D_rbc) -end - -gradient( - (args...) -> joint_likelihood_1(args..., observables_rbc, D_rbc), A_rbc, B_rbc, - C_rbc, - u0_rbc, noise_rbc -) - -@testset "linear rbc joint likelihood" begin - @test joint_likelihood_1( - A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, - D_rbc - ) ≈ - -690.9407412360038 - @inferred joint_likelihood_1( - A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, - D_rbc - ) # - gradient( - (args...) -> joint_likelihood_1(args..., observables_rbc, D_rbc), A_rbc, B_rbc, - C_rbc, - u0_rbc, noise_rbc - ) - - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> joint_likelihood_1(args..., observables_rbc, D_rbc), A_rbc, - B_rbc, - C_rbc, u0_rbc, noise_rbc; rrule_f = rrule_via_ad, check_inferred = false - ) -end - -gradient( - (args...) -> kalman_likelihood(args..., observables_rbc, D_rbc), A_rbc, B_rbc, - C_rbc, - u0_rbc -) - -@testset "linear rbc kalman likelihood" begin - @test kalman_likelihood(A_rbc, B_rbc, C_rbc, u0_rbc, observables_rbc, D_rbc) ≈ - -607.3698273765538 - @inferred kalman_likelihood(A_rbc, B_rbc, C_rbc, u0_rbc, observables_rbc, D_rbc) # would this catch inference problems in the solve? - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> kalman_likelihood(args..., observables_rbc, D_rbc), A_rbc, - B_rbc, C_rbc, - u0_rbc; rrule_f = rrule_via_ad, check_inferred = false - ) -end - -# Load FVGQ data for checks -A_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A.csv"), ',') -B_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_B.csv"), ',') -C_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C.csv"), ',') -D_FVGQ = ones(6) * 1.0e-3 - -observables_FVGQ = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/FVGQ20_observables.csv" - ), ',' -)' |> collect - -noise_FVGQ = readdlm( - joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_noise.csv"), - ',' -)' |> - collect -u0_FVGQ = zeros(size(A_FVGQ, 1)) - -@testset "linear FVGQ joint likelihood" begin - @test joint_likelihood_1( - A_FVGQ, B_FVGQ, C_FVGQ, u0_FVGQ, noise_FVGQ, observables_FVGQ, - D_FVGQ - ) ≈ -1.4613614369686982e6 - @inferred joint_likelihood_1( - A_FVGQ, B_FVGQ, C_FVGQ, u0_FVGQ, noise_FVGQ, - observables_FVGQ, - D_FVGQ - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> joint_likelihood_1(args..., observables_FVGQ, D_FVGQ), A_FVGQ, - B_FVGQ, - C_FVGQ, u0_FVGQ, noise_FVGQ; rrule_f = rrule_via_ad, check_inferred = false - ) -end - -@testset "linear FVGQ Kalman" begin - # Note: set rtol to be higher than the default case because of huge gradient numbers - # D_FVGQ = - # @test kalman_likelihood(A_FVGQ, B_FVGQ, C_FVGQ, u0_FVGQ, observables_FVGQ, abs2.(ones(6) * 1e-3)) ≈ - # -108.52706300389917 - @test kalman_likelihood( - A_FVGQ, B_FVGQ, C_FVGQ, u0_FVGQ, observables_FVGQ, - D_FVGQ - ) ≈ - 2253.0905386483046 - - gradient( - (args...) -> kalman_likelihood(args..., observables_FVGQ, D_FVGQ), A_FVGQ, - B_FVGQ, - C_FVGQ, u0_FVGQ - ) - - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> kalman_likelihood(args..., observables_FVGQ, D_FVGQ), - A_FVGQ, B_FVGQ, - C_FVGQ, u0_FVGQ; rrule_f = rrule_via_ad, check_inferred = false, rtol = 1.0e-8 - ) -end - -@testset "basic kalman failure" begin - A = [1.0e20 0.0; 1.0e20 0.0] - u0_prior_var = diagm(1.0e10 * ones(length(u0_rbc))) - prob = LinearStateSpaceProblem( - A, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); - C = C_rbc, - observables_noise = D_rbc, observables = observables_rbc, - u0_prior_mean = u0_rbc, u0_prior_var - ) - sol = solve(prob) - @test sol.logpdf ≈ -Inf - @test sol.retcode != :Success -end - -@testset "basic kalman failure gradient" begin - A = [1.0e20 0.0; 1.0e20 0.0] - u0_prior_var = diagm(1.0e10 * ones(length(u0_rbc))) - function fail_kalman(B_rbc) - prob = LinearStateSpaceProblem( - A, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); - C = C_rbc, - observables_noise = D_rbc, - observables = observables_rbc, - u0_prior_mean = u0_rbc, u0_prior_var - ) - return solve(prob).logpdf - end - @test gradient(fail_kalman, B_rbc)[1] ≈ [0.0; 0.0;;] # but hopefully gradients are ignored! -end diff --git a/test/linear_simulations.jl b/test/linear_simulations.jl deleted file mode 100644 index cb993f7..0000000 --- a/test/linear_simulations.jl +++ /dev/null @@ -1,202 +0,0 @@ -using ChainRulesTestUtils, DifferenceEquations, Distributions, LinearAlgebra, Test, Zygote, - Random -using DelimitedFiles -using DiffEqBase -using FiniteDiff: finite_difference_gradient - -# Matrices from RBC -A_rbc = [ - 0.9568351489231076 6.209371005755285; - 3.0153731819288737e-18 0.20000000000000007 -] -B_rbc = reshape([0.0; -0.01], 2, 1) # make sure B is a matrix -C_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] -D_rbc = abs2.([0.1, 0.1]) -u0_rbc = zeros(2) - -observables_rbc = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/RBC_observables.csv" - ), - ',' -)' |> collect -# Data and Noise -@testset "basic inference, simulated noise" begin - prob = LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); - C = C_rbc, - observables_noise = D_rbc, observables = observables_rbc, - syms = [:a, :b] - ) - @inferred LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); - C = C_rbc, observables_noise = D_rbc, - observables = observables_rbc, syms = [:a, :b] - ) - - sol = solve(prob) - @inferred solve(prob) - - # todo: add in regression tests -end - -@testset "basic inference, simulated noise, no observations, no observation noise" begin - T = 20 - prob = LinearStateSpaceProblem(A_rbc, B_rbc, u0_rbc, (0, T); C = C_rbc, syms = [:a, :b]) - @inferred LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, T); C = C_rbc, - syms = [:a, :b] - ) - - sol = solve(prob) - @inferred solve(prob) - - # todo: add in regression tests -end - -@testset "simulation with observations and noise, no observation noise" begin - Random.seed!(1234) - sol = solve(LinearStateSpaceProblem(A_rbc, B_rbc, u0_rbc, (0, 5); C = C_rbc)) - @test sol.u ≈ - [ - [0.0, 0.0], [0.0, 0.003597289068234817], - [0.02233690243961772, -0.010152627110638895], - [-0.04166869504075366, 0.0021653707472607075], - [-0.026424481689999797, -0.006756025225207251], - [-0.06723454002062011, -0.00555367682297924], - ] - @test sol.z ≈ - [ - [0.0, 0.0], [0.0024270440446074832, 0.0], - [-0.004710049663169753, 0.02233690243961772], - [-0.002530764810543453, -0.04166869504075366], - [-0.007089573167553201, -0.026424481689999797], - [-0.010187822270025022, -0.06723454002062011], - ] - @test sol.W ≈ - [-0.3597289068234817 1.0872084924285859 -0.4195896169388487 0.7189099374659392 0.4202471777937789] - @test sol.logpdf === nothing -end - -@testset "simulation with observations and noise, no observation noise" begin - Random.seed!(1234) - sol = solve( - LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, 5); C = C_rbc, - observables_noise = D_rbc - ) - ) - @test sol.u ≈ - [ - [0.0, 0.0], [0.0, 0.003597289068234817], - [0.02233690243961772, -0.010152627110638895], - [-0.04166869504075366, 0.0021653707472607075], - [-0.026424481689999797, -0.006756025225207251], - [-0.06723454002062011, -0.00555367682297924], - ] - @test sol.z ≈ - [ - [-0.06856709022761191, 0.20547630560640365], - [0.034916316989299055, -0.030490125519643224], - [0.0414594477647271, -0.06215886919798015], - [0.08614040809827415, -0.040311314885592704], - [0.0034755874208198837, -0.08053882074804589], - [-0.07921183287013331, -0.16087605412196193], - ] - @test sol.W ≈ - [-0.3597289068234817 1.0872084924285859 -0.4195896169388487 0.7189099374659392 0.4202471777937789] - @test sol.logpdf === nothing -end - -@testset "basic inference, no simulated noise, no observations with observation noise" begin - T = 20 - B_no_noise = zeros(2, 2) - u0 = [1.0, 0.5] - prob_no_noise = LinearStateSpaceProblem( - A_rbc, B_no_noise, u0, (0, T); C = C_rbc, - syms = [:a, :b] - ) - - sol_no_noise = solve(prob_no_noise) - - prob_obs_noise = LinearStateSpaceProblem( - A_rbc, B_no_noise, u0, (0, T); C = C_rbc, - syms = [:a, :b], observables_noise = D_rbc - ) - @inferred LinearStateSpaceProblem( - A_rbc, B_no_noise, u0, (0, T); C = C_rbc, - syms = [:a, :b], observables_noise = D_rbc - ) - sol_obs_noise = solve(prob_obs_noise) - @inferred solve(prob_obs_noise) - - # check that if the variance of the noise is tiny it is identical - sol_tiny_obs_noise = solve( - LinearStateSpaceProblem( - A_rbc, B_no_noise, u0, (0, T); - C = C_rbc, - syms = [:a, :b], - observables_noise = [1.0e-16, 1.0e-16] - ) - ) - @test maximum(maximum.(sol_tiny_obs_noise.z - sol_no_noise.z)) < 1.0e-7 # still some noise - @test maximum(maximum.(sol_tiny_obs_noise.z - sol_no_noise.z)) > 0.0 # but not zero -end - -@testset "basic inference, no noise, no observations and no with observation noise" begin - T = 5 - B_no_noise = zeros(2, 2) - u0 = [1.0, 0.5] - sol_no_noise = solve( - LinearStateSpaceProblem( - A_rbc, B_no_noise, u0, (0, T); C = C_rbc, - syms = [:a, :b] - ) - ) - - #Now literally pass in no noise in B with a nothing - prob = LinearStateSpaceProblem( - A_rbc, nothing, u0, (0, T); C = C_rbc, - syms = [:a, :b] - ) - @inferred LinearStateSpaceProblem( - A_rbc, nothing, u0, (0, T); C = C_rbc, - syms = [:a, :b] - ) - - sol_nothing_noise = solve(prob) - @inferred solve(prob) - - @test sol_no_noise.z ≈ sol_nothing_noise.z - @test sol_no_noise.u ≈ sol_nothing_noise.u - @test sol_nothing_noise.W === nothing -end - -@testset "no observation process" begin - Random.seed!(1234) - T = 5 - u0 = [1.0, 0.5] - prob = LinearStateSpaceProblem( - A_rbc, B_rbc, u0, (0, T); C = nothing, - syms = [:a, :b] - ) - @inferred LinearStateSpaceProblem( - A_rbc, B_rbc, u0, (0, T); C = nothing, - syms = [:a, :b] - ) - sol = solve(prob) - @inferred solve(prob) - - @test sol.z === nothing - @test sol.u ≈ [ - [1.0, 0.5], [4.06152065180075, 0.10359728906823484], - [4.5294797207351944, 0.009847372889361128], - [4.395111394835915, 0.006165370747260727], - [4.243680140369242, -0.005956025225207233], - [4.023519148749289, -0.005393676822979223], - ] - @test sol.W ≈ - [-0.3597289068234817 1.0872084924285859 -0.4195896169388487 0.7189099374659392 0.4202471777937789] - @test sol.logpdf === nothing -end diff --git a/test/quadratic_direct_iteration.jl b/test/quadratic_direct_iteration.jl new file mode 100644 index 0000000..0853479 --- /dev/null +++ b/test/quadratic_direct_iteration.jl @@ -0,0 +1,285 @@ +using DifferenceEquations, LinearAlgebra, Test, Random, DelimitedFiles, DiffEqBase +using DifferenceEquations: init, solve! + +# ============================================================================= +# Small random test data (N=2, K=1, M=2, T=5) +# ============================================================================= + +Random.seed!(99) +const N_q = 2; const K_q = 1; const M_q = 2; const T_q = 5 + +const A_0_sm = 0.01 * randn(N_q) +const A_1_sm_raw = randn(N_q, N_q) +const A_1_sm = 0.5 * A_1_sm_raw / maximum(abs.(eigvals(A_1_sm_raw))) +const A_2_sm = 0.01 * randn(N_q, N_q, N_q) +const B_sm = 0.1 * randn(N_q, K_q) +const C_0_sm = 0.01 * randn(M_q) +const C_1_sm = randn(M_q, N_q) +const C_2_sm = 0.01 * randn(M_q, N_q, N_q) +const D_sm = abs2.([0.1, 0.1]) +const u0_sm = zeros(N_q) + +Random.seed!(200) +const noise_sm = [randn(K_q) for _ in 1:T_q] + +# Pre-simulate observations for logpdf tests +Random.seed!(300) +const sim_unpruned = solve( + QuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, B_sm, u0_sm, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm, noise = noise_sm + ) +) +const obs_sm = [sim_unpruned.z[t + 1] + 0.05 * randn(M_q) for t in 1:T_q] + +# ============================================================================= +# Unpruned QuadraticStateSpaceProblem tests +# ============================================================================= + +@testset "Unpruned simulation (no obs) — finite and solve! matches solve" begin + Random.seed!(1234) + prob = QuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, B_sm, u0_sm, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm + ) + Random.seed!(1234) + sol = solve(prob) + @test all(all(isfinite, u) for u in sol.u) + @test all(all(isfinite, z) for z in sol.z) + @test sol.logpdf == 0.0 + + # solve! matches solve + Random.seed!(1234) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol.u + @test sol_ws.z ≈ sol.z + @test sol_ws.logpdf ≈ sol.logpdf +end + +@testset "Unpruned with observations + obs_noise — logpdf finite" begin + prob = QuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, B_sm, u0_sm, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm, + noise = noise_sm, observables = obs_sm, observables_noise = Diagonal(D_sm) + ) + sol = solve(prob) + @test isfinite(sol.logpdf) + @test sol.logpdf != 0.0 +end + +@testset "Unpruned no noise (B=nothing) — deterministic" begin + u0_det = [0.5, -0.3] + prob = QuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, nothing, u0_det, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm + ) + sol1 = solve(prob) + sol2 = solve(prob) + @test sol1.u ≈ sol2.u + @test sol1.z ≈ sol2.z + @test sol1.W === nothing + @test sol2.W === nothing +end + +@testset "Unpruned C=nothing — no observation process" begin + Random.seed!(1234) + prob = QuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, B_sm, u0_sm, (0, T_q) + ) + sol = solve(prob) + @test sol.z === nothing + @test all(all(isfinite, u) for u in sol.u) + @test sol.logpdf == 0.0 +end + +# ============================================================================= +# Pruned PrunedQuadraticStateSpaceProblem tests +# ============================================================================= + +# Pre-simulate observations for pruned logpdf tests +Random.seed!(400) +const sim_pruned = solve( + PrunedQuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, B_sm, u0_sm, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm, noise = noise_sm + ) +) +const obs_pruned_sm = [sim_pruned.z[t + 1] + 0.05 * randn(M_q) for t in 1:T_q] + +@testset "Pruned simulation (no obs) — finite and solve! matches solve" begin + Random.seed!(1234) + prob = PrunedQuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, B_sm, u0_sm, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm + ) + Random.seed!(1234) + sol = solve(prob) + @test all(all(isfinite, u) for u in sol.u) + @test all(all(isfinite, z) for z in sol.z) + @test sol.logpdf == 0.0 + + # solve! matches solve + Random.seed!(1234) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol.u + @test sol_ws.z ≈ sol.z + @test sol_ws.logpdf ≈ sol.logpdf +end + +@testset "Pruned with observations + obs_noise — logpdf finite" begin + prob = PrunedQuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, B_sm, u0_sm, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm, + noise = noise_sm, observables = obs_pruned_sm, observables_noise = Diagonal(D_sm) + ) + sol = solve(prob) + @test isfinite(sol.logpdf) + @test sol.logpdf != 0.0 +end + +@testset "Pruned no noise (B=nothing) — deterministic" begin + u0_det = [0.5, -0.3] + prob = PrunedQuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, nothing, u0_det, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm + ) + sol1 = solve(prob) + sol2 = solve(prob) + @test sol1.u ≈ sol2.u + @test sol1.z ≈ sol2.z + @test sol1.W === nothing +end + +@testset "Pruned C=nothing — no observation process" begin + Random.seed!(1234) + prob = PrunedQuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, B_sm, u0_sm, (0, T_q) + ) + sol = solve(prob) + @test sol.z === nothing + @test all(all(isfinite, u) for u in sol.u) + @test sol.logpdf == 0.0 +end + +# ============================================================================= +# Regression: PrunedQuadraticStateSpaceProblem matches old closure-based value +# ============================================================================= + +# RBC quadratic data (from test/direct_iteration.jl) +A_0_rbc = [-7.824904812740593e-5, 0.0] +A_1_rbc = [0.9568351489231076 6.209371005755285; 3.0153731819288737e-18 0.20000000000000007] +A_2_rbc = cat( + [-0.00019761505863889124 0.03375055315837927; 0.0 0.0], + [0.03375055315837913 3.128758481817603; 0.0 0.0]; dims = 3 +) +B_2_rbc = reshape([0.0; -0.01], 2, 1) +C_0_rbc = [7.824904812740593e-5, 0.0] +C_1_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] +C_2_rbc = cat( + [-0.00018554166974717046 0.0025652363153049716; 0.0 0.0], + [0.002565236315304951 0.3132705036896446; 0.0 0.0]; dims = 3 +) +D_2_rbc = abs2.([0.1, 0.1]) +u0_2_rbc = zeros(2) + +observables_2_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_observables.csv"), ',' +)' |> collect +observables_2_rbc = [observables_2_rbc_matrix[:, t] for t in 1:size(observables_2_rbc_matrix, 2)] +noise_2_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), ',' +)' |> collect +noise_2_rbc = [noise_2_rbc_matrix[:, t] for t in 1:size(noise_2_rbc_matrix, 2)] +T_rbc = 5 +observables_2_rbc_short = observables_2_rbc[1:T_rbc] +noise_2_rbc_short = noise_2_rbc[1:T_rbc] + +@testset "Pruned RBC regression — matches closure-based quadratic_joint_likelihood" begin + prob = PrunedQuadraticStateSpaceProblem( + A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, u0_2_rbc, + (0, length(observables_2_rbc_short)); + C_0 = C_0_rbc, C_1 = C_1_rbc, C_2 = C_2_rbc, + observables_noise = Diagonal(D_2_rbc), noise = noise_2_rbc_short, + observables = observables_2_rbc_short + ) + sol = solve(prob) + @test sol.logpdf ≈ -690.81094364573 +end + +@testset "Pruned RBC — solve! matches solve" begin + prob = PrunedQuadraticStateSpaceProblem( + A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, u0_2_rbc, + (0, length(observables_2_rbc_short)); + C_0 = C_0_rbc, C_1 = C_1_rbc, C_2 = C_2_rbc, + observables_noise = Diagonal(D_2_rbc), noise = noise_2_rbc_short, + observables = observables_2_rbc_short + ) + sol_direct = solve(prob) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.logpdf ≈ sol_direct.logpdf +end + +# ============================================================================= +# Workspace (init/solve!) additional tests +# ============================================================================= + +@testset "Unpruned solve!() repeated — idempotent" begin + prob = QuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, B_sm, u0_sm, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm, + noise = noise_sm, observables = obs_sm, observables_noise = Diagonal(D_sm) + ) + ws = init(prob, DirectIteration()) + sol1 = solve!(ws) + sol2 = solve!(ws) + @test sol1.u ≈ sol2.u + @test sol1.z ≈ sol2.z + @test sol1.logpdf ≈ sol2.logpdf +end + +@testset "Pruned solve!() repeated — idempotent" begin + prob = PrunedQuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, B_sm, u0_sm, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm, + noise = noise_sm, observables = obs_pruned_sm, observables_noise = Diagonal(D_sm) + ) + ws = init(prob, DirectIteration()) + sol1 = solve!(ws) + sol2 = solve!(ws) + @test sol1.u ≈ sol2.u + @test sol1.z ≈ sol2.z + @test sol1.logpdf ≈ sol2.logpdf +end + +@testset "Unpruned solve!() — no obs, B=nothing" begin + u0_det = [0.5, -0.3] + prob = QuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, nothing, u0_det, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm + ) + sol_direct = solve(prob) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.W === nothing +end + +@testset "Pruned solve!() — no obs, B=nothing" begin + u0_det = [0.5, -0.3] + prob = PrunedQuadraticStateSpaceProblem( + A_0_sm, A_1_sm, A_2_sm, nothing, u0_det, (0, T_q); + C_0 = C_0_sm, C_1 = C_1_sm, C_2 = C_2_sm + ) + sol_direct = solve(prob) + ws = init(prob, DirectIteration()) + sol_ws = solve!(ws) + @test sol_ws.u ≈ sol_direct.u + @test sol_ws.z ≈ sol_direct.z + @test sol_ws.W === nothing +end diff --git a/test/quadratic_direct_iteration_enzyme.jl b/test/quadratic_direct_iteration_enzyme.jl new file mode 100644 index 0000000..8503455 --- /dev/null +++ b/test/quadratic_direct_iteration_enzyme.jl @@ -0,0 +1,224 @@ +# Enzyme AD tests for Quadratic and PrunedQuadratic DirectIteration +# Forward: test_forward (EnzymeTestUtils) +# Reverse: test_reverse (EnzymeTestUtils) + +using LinearAlgebra, Test, Enzyme, EnzymeTestUtils, Random +using DifferenceEquations +using DifferenceEquations: init, solve!, StateSpaceWorkspace +using FiniteDifferences: central_fdm + +const _fdm_qe = central_fdm(5, 1) + +# ============================================================================= +# Small test data (N=2, K=1, M=2, T=2) +# ============================================================================= + +Random.seed!(77) +const N_qe = 2; const K_qe = 1; const M_qe = 2; const T_qe = 2 + +const A_0_qe = 0.01 * randn(N_qe) +const A_1_qe_raw = randn(N_qe, N_qe) +const A_1_qe = 0.5 * A_1_qe_raw / maximum(abs.(eigvals(A_1_qe_raw))) +const A_2_qe = 0.01 * randn(N_qe, N_qe, N_qe) +const B_qe = 0.1 * randn(N_qe, K_qe) +const C_0_qe = 0.01 * randn(M_qe) +const C_1_qe = randn(M_qe, N_qe) +const C_2_qe = 0.01 * randn(M_qe, N_qe, N_qe) +const u0_qe = zeros(N_qe) +const noise_qe = [0.1 * randn(K_qe) for _ in 1:T_qe] + +# ============================================================================= +# Helper: allocate sol/cache from init +# ============================================================================= + +function make_quad_sol_cache(A_0, A_1, A_2, B, u0, noise; C_0, C_1, C_2) + prob = QuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = init(prob, DirectIteration()) + return ws.output, ws.cache +end + +function make_pruned_sol_cache(A_0, A_1, A_2, B, u0, noise; C_0, C_1, C_2) + prob = PrunedQuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = init(prob, DirectIteration()) + return ws.output, ws.cache +end + +# ============================================================================= +# Unpruned wrapper functions +# ============================================================================= + +function quad_solve!(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol, cache) + prob = QuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return (sol.u, sol.z) +end + +function quad_scalar!(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol, cache)::Float64 + prob = QuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + return sum(solve!(ws).u[end]) +end + +# ============================================================================= +# Pruned wrapper functions +# ============================================================================= + +function pruned_solve!(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol, cache) + prob = PrunedQuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + solve!(ws) + return (sol.u, sol.z) +end + +function pruned_scalar!(A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, sol, cache)::Float64 + prob = PrunedQuadraticStateSpaceProblem( + A_0, A_1, A_2, B, u0, (0, length(noise)); + C_0, C_1, C_2, noise + ) + ws = StateSpaceWorkspace(prob, DirectIteration(), sol, cache) + return sum(solve!(ws).u[end]) +end + +# ============================================================================= +# Sanity tests +# ============================================================================= + +@testset "Unpruned quadratic solve! sanity" begin + sol, cache = make_quad_sol_cache( + A_0_qe, A_1_qe, A_2_qe, B_qe, u0_qe, noise_qe; + C_0 = C_0_qe, C_1 = C_1_qe, C_2 = C_2_qe + ) + val = quad_scalar!( + A_0_qe, A_1_qe, A_2_qe, B_qe, C_0_qe, C_1_qe, C_2_qe, + u0_qe, noise_qe, sol, cache + ) + @test isfinite(val) + + val2 = quad_scalar!( + A_0_qe, A_1_qe, A_2_qe, B_qe, C_0_qe, C_1_qe, C_2_qe, + u0_qe, noise_qe, sol, cache + ) + @test val ≈ val2 rtol = 1.0e-12 +end + +@testset "Pruned quadratic solve! sanity" begin + sol, cache = make_pruned_sol_cache( + A_0_qe, A_1_qe, A_2_qe, B_qe, u0_qe, noise_qe; + C_0 = C_0_qe, C_1 = C_1_qe, C_2 = C_2_qe + ) + val = pruned_scalar!( + A_0_qe, A_1_qe, A_2_qe, B_qe, C_0_qe, C_1_qe, C_2_qe, + u0_qe, noise_qe, sol, cache + ) + @test isfinite(val) + + val2 = pruned_scalar!( + A_0_qe, A_1_qe, A_2_qe, B_qe, C_0_qe, C_1_qe, C_2_qe, + u0_qe, noise_qe, sol, cache + ) + @test val ≈ val2 rtol = 1.0e-12 +end + +# ============================================================================= +# Unpruned forward (all Duplicated) +# ============================================================================= + +@testset "EnzymeTestUtils - Unpruned quadratic forward (all Duplicated)" begin + sol, cache = make_quad_sol_cache( + A_0_qe, A_1_qe, A_2_qe, B_qe, u0_qe, noise_qe; + C_0 = C_0_qe, C_1 = C_1_qe, C_2 = C_2_qe + ) + + test_forward( + quad_solve!, Const, + (copy(A_0_qe), Duplicated), (copy(A_1_qe), Duplicated), + (copy(A_2_qe), Duplicated), (copy(B_qe), Duplicated), + (copy(C_0_qe), Duplicated), (copy(C_1_qe), Duplicated), + (copy(C_2_qe), Duplicated), (copy(u0_qe), Duplicated), + ([copy(n) for n in noise_qe], Duplicated), + (sol, Duplicated), (cache, Duplicated); + fdm = _fdm_qe, + ) +end + +# ============================================================================= +# Unpruned reverse (test_reverse) +# ============================================================================= + +@testset "EnzymeTestUtils - Unpruned quadratic reverse" begin + sol, cache = make_quad_sol_cache( + A_0_qe, A_1_qe, A_2_qe, B_qe, u0_qe, noise_qe; + C_0 = C_0_qe, C_1 = C_1_qe, C_2 = C_2_qe + ) + + test_reverse( + quad_scalar!, Active, + (copy(A_0_qe), Duplicated), (copy(A_1_qe), Duplicated), + (copy(A_2_qe), Duplicated), (copy(B_qe), Duplicated), + (copy(C_0_qe), Duplicated), (copy(C_1_qe), Duplicated), + (copy(C_2_qe), Duplicated), (copy(u0_qe), Duplicated), + ([copy(n) for n in noise_qe], Duplicated), + (deepcopy(sol), Duplicated), (deepcopy(cache), Duplicated); + fdm = _fdm_qe, + ) +end + +# ============================================================================= +# Pruned forward (all Duplicated) +# ============================================================================= + +@testset "EnzymeTestUtils - Pruned quadratic forward (all Duplicated)" begin + sol, cache = make_pruned_sol_cache( + A_0_qe, A_1_qe, A_2_qe, B_qe, u0_qe, noise_qe; + C_0 = C_0_qe, C_1 = C_1_qe, C_2 = C_2_qe + ) + + test_forward( + pruned_solve!, Const, + (copy(A_0_qe), Duplicated), (copy(A_1_qe), Duplicated), + (copy(A_2_qe), Duplicated), (copy(B_qe), Duplicated), + (copy(C_0_qe), Duplicated), (copy(C_1_qe), Duplicated), + (copy(C_2_qe), Duplicated), (copy(u0_qe), Duplicated), + ([copy(n) for n in noise_qe], Duplicated), + (sol, Duplicated), (cache, Duplicated); + fdm = _fdm_qe, + ) +end + +# ============================================================================= +# Pruned reverse (test_reverse) +# ============================================================================= + +@testset "EnzymeTestUtils - Pruned quadratic reverse" begin + sol, cache = make_pruned_sol_cache( + A_0_qe, A_1_qe, A_2_qe, B_qe, u0_qe, noise_qe; + C_0 = C_0_qe, C_1 = C_1_qe, C_2 = C_2_qe + ) + + test_reverse( + pruned_scalar!, Active, + (copy(A_0_qe), Duplicated), (copy(A_1_qe), Duplicated), + (copy(A_2_qe), Duplicated), (copy(B_qe), Duplicated), + (copy(C_0_qe), Duplicated), (copy(C_1_qe), Duplicated), + (copy(C_2_qe), Duplicated), (copy(u0_qe), Duplicated), + ([copy(n) for n in noise_qe], Duplicated), + (deepcopy(sol), Duplicated), (deepcopy(cache), Duplicated); + fdm = _fdm_qe, + ) +end diff --git a/test/quadratic_likelihood.jl b/test/quadratic_likelihood.jl deleted file mode 100644 index 173192e..0000000 --- a/test/quadratic_likelihood.jl +++ /dev/null @@ -1,181 +0,0 @@ -using ChainRulesTestUtils, DifferenceEquations, Distributions, LinearAlgebra, Test, Zygote -using DelimitedFiles -using DiffEqBase -using FiniteDiff: finite_difference_gradient - -# joint case -function joint_likelihood_2( - A_0, A_1, A_2, B, C_0, C_1, C_2, u0, noise, observables, D; - kwargs... - ) - problem = QuadraticStateSpaceProblem( - A_0, A_1, A_2, B, u0, (0, size(observables, 2)); - C_0, C_1, - C_2, observables_noise = D, noise, observables, - kwargs... - ) - return solve(problem).logpdf -end - -# Matrices from RBC -A_0_rbc = [-7.824904812740593e-5, 0.0] -A_1_rbc = [0.9568351489231076 6.209371005755285; 3.0153731819288737e-18 0.20000000000000007] -A_2_rbc = cat( - [-0.00019761505863889124 0.03375055315837927; 0.0 0.0], - [0.03375055315837913 3.128758481817603; 0.0 0.0]; dims = 3 -) -B_2_rbc = reshape([0.0; -0.01], 2, 1) # make sure B is a matrix -C_0_rbc = [7.824904812740593e-5, 0.0] -C_1_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] -C_2_rbc = cat( - [-0.00018554166974717046 0.0025652363153049716; 0.0 0.0], - [0.002565236315304951 0.3132705036896446; 0.0 0.0]; dims = 3 -) -D_2_rbc = abs2.([0.1, 0.1]) -u0_2_rbc = zeros(2) - -observables_2_rbc = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/RBC_observables.csv" - ), - ',' -)' |> collect -noise_2_rbc = readdlm( - joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), - ',' -)' |> - collect - -# Data and Noise -T = 5 -observables_2_rbc = observables_2_rbc[:, 1:T] -noise_2_rbc = noise_2_rbc[:, 1:T] - -@testset "basic inference" begin - prob = QuadraticStateSpaceProblem( - A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, u0_2_rbc, - (0, size(observables_2_rbc, 2)); C_0 = C_0_rbc, - C_1 = C_1_rbc, - C_2 = C_2_rbc, observables_noise = D_2_rbc, - noise = noise_2_rbc, observables = observables_2_rbc - ) - @inferred QuadraticStateSpaceProblem( - A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, u0_2_rbc, - (0, size(observables_2_rbc, 2)); C_0 = C_0_rbc, - C_1 = C_1_rbc, C_2 = C_2_rbc, - observables_noise = D_2_rbc, - noise = noise_2_rbc, - observables = observables_2_rbc - ) - - DiffEqBase.get_concrete_problem(prob, false) - @inferred DiffEqBase.get_concrete_problem(prob, false) - - sol = solve(prob) - @inferred solve(prob) -end - -@testset "quadratic rbc joint likelihood" begin - @test joint_likelihood_2( - A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, - u0_2_rbc, noise_2_rbc, observables_2_rbc, D_2_rbc - ) ≈ - -690.81094364573 - @inferred joint_likelihood_2( - A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, - C_2_rbc, - u0_2_rbc, noise_2_rbc, observables_2_rbc, D_2_rbc - ) # would this catch inference problems in the solve? - gradient( - (args...) -> joint_likelihood_2(args..., observables_2_rbc, D_2_rbc), A_0_rbc, - A_1_rbc, - A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc, noise_2_rbc - ) - test_rrule( - Zygote.ZygoteRuleConfig(), - (args...) -> joint_likelihood_2(args..., observables_2_rbc, D_2_rbc), - A_0_rbc, - A_1_rbc, A_2_rbc, B_2_rbc, C_0_rbc, C_1_rbc, C_2_rbc, u0_2_rbc, noise_2_rbc; - rrule_f = rrule_via_ad, check_inferred = false - ) -end - -# Load FVGQ data for checks -A_0_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A_0.csv"), ',') -A_0_FVGQ = vec(A_0_raw) -A_1_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A_1.csv"), ',') -A_2_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_A_2.csv"), ',') -A_2_FVGQ = reshape(A_2_raw, length(A_0_FVGQ), length(A_0_FVGQ), length(A_0_FVGQ)) -B_2_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_B.csv"), ',') -C_0_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C_0.csv"), ',') -C_0_FVGQ = vec(C_0_raw) -C_1_FVGQ = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C_1.csv"), ',') -C_2_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_C_2.csv"), ',') -C_2_FVGQ = reshape(C_2_raw, length(C_0_FVGQ), length(A_0_FVGQ), length(A_0_FVGQ)) -# D_raw = readdlm(joinpath(pkgdir(DifferenceEquations), "FVGQ_D.csv"); header = false))) -D_2_FVGQ = ones(6) * 1.0e-3 -u0_2_FVGQ = zeros(size(A_1_FVGQ, 1)) - -observables_2_FVGQ = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/FVGQ20_observables.csv" - ), ',' -)' |> collect - -noise_2_FVGQ = readdlm( - joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ20_noise.csv"), - ',' -)' |> - collect - -@testset "quadratic FVGQ joint likelihood" begin - @test joint_likelihood_2( - A_0_FVGQ, A_1_FVGQ, A_2_FVGQ, B_2_FVGQ, C_0_FVGQ, C_1_FVGQ, - C_2_FVGQ, - u0_2_FVGQ, noise_2_FVGQ, observables_2_FVGQ, D_2_FVGQ - ) ≈ - -1.4728927648336522e7 - @inferred joint_likelihood_2( - A_0_FVGQ, A_1_FVGQ, A_2_FVGQ, B_2_FVGQ, C_0_FVGQ, C_1_FVGQ, - C_2_FVGQ, - u0_2_FVGQ, noise_2_FVGQ, observables_2_FVGQ, D_2_FVGQ - ) - gradient( - (args...) -> joint_likelihood_2(args..., observables_2_FVGQ, D_2_FVGQ), - A_0_FVGQ, - A_1_FVGQ, - A_2_FVGQ, B_2_FVGQ, C_0_FVGQ, C_1_FVGQ, C_2_FVGQ, u0_2_FVGQ, noise_2_FVGQ - ) - - test_rrule( - Zygote.ZygoteRuleConfig(), - ( - A_0_FVGQ, - C_1_FVGQ, - u0_2_FVGQ, - ) -> joint_likelihood_2( - A_0_FVGQ, - A_1_FVGQ, - A_2_FVGQ, - B_2_FVGQ, - C_0_FVGQ, - C_1_FVGQ, - C_2_FVGQ, - u0_2_FVGQ, - noise_2_FVGQ, - observables_2_FVGQ, - D_2_FVGQ - ), - A_0_FVGQ, C_1_FVGQ, u0_2_FVGQ; - rrule_f = rrule_via_ad, check_inferred = false - ) - - # A little slow to run all of them all every time. Important occasionally, though, since tests the gradient wrt the noise - # test_rrule(Zygote.ZygoteRuleConfig(), - # (args...) -> joint_likelihood_2(args..., observables_FVGQ, D_FVGQ), A_0_FVGQ, - # A_1_FVGQ, - # A_2_FVGQ, B_2_FVGQ, C_0_FVGQ, C_1_FVGQ, C_2_FVGQ, u0_2_FVGQ, noise_2_FVGQ; - # rrule_f = rrule_via_ad, check_inferred = false) -end diff --git a/test/quadratic_simulations.jl b/test/quadratic_simulations.jl deleted file mode 100644 index 047652b..0000000 --- a/test/quadratic_simulations.jl +++ /dev/null @@ -1,110 +0,0 @@ -using ChainRulesTestUtils, DifferenceEquations, Distributions, LinearAlgebra, Test, Zygote -using DelimitedFiles -using DiffEqBase -using FiniteDiff: finite_difference_gradient - -# Matrices from RBC -A_0_rbc = [-7.824904812740593e-5, 0.0] -A_1_rbc = [0.9568351489231076 6.209371005755285; 3.0153731819288737e-18 0.20000000000000007] -A_2_rbc = cat( - [-0.00019761505863889124 0.03375055315837927; 0.0 0.0], - [0.03375055315837913 3.128758481817603; 0.0 0.0]; dims = 3 -) -B_2_rbc = reshape([0.0; -0.01], 2, 1) # make sure B is a matrix -C_0_rbc = [7.824904812740593e-5, 0.0] -C_1_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] -C_2_rbc = cat( - [-0.00018554166974717046 0.0025652363153049716; 0.0 0.0], - [0.002565236315304951 0.3132705036896446; 0.0 0.0]; dims = 3 -) -D_2_rbc = abs2.([0.1, 0.1]) -u0_2_rbc = zeros(2) - -observables_2_rbc = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/RBC_observables.csv" - ), - ',' -)' |> collect - -# Data and Noise - -@testset "basic inference, simulated noise" begin - prob = QuadraticStateSpaceProblem( - A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, u0_2_rbc, - (0, size(observables_2_rbc, 2)); C_0 = C_0_rbc, - C_1 = C_1_rbc, - C_2 = C_2_rbc, observables_noise = D_2_rbc, - observables = observables_2_rbc - ) - @inferred QuadraticStateSpaceProblem( - A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, u0_2_rbc, - (0, size(observables_2_rbc, 2)); C_0 = C_0_rbc, - C_1 = C_1_rbc, C_2 = C_2_rbc, - observables_noise = D_2_rbc, - observables = observables_2_rbc - ) - - sol = solve(prob) - @inferred solve(prob) -end - -@testset "basic inference, simulated noise, no observations and no observation noise" begin - T = 20 - prob = QuadraticStateSpaceProblem( - A_0_rbc, A_1_rbc, A_2_rbc, B_2_rbc, u0_2_rbc, (0, T); - C_0 = C_0_rbc, C_1 = C_1_rbc, C_2 = C_2_rbc - ) - - sol = solve(prob) - @inferred solve(prob) - - # todo: add in regression tests -end - -@testset "basic inference, no simulated noise, no observations with observation noise" begin - T = 20 - B_no_noise = zeros(2, 2) - u0 = [1.0, 0.5] - B_no_noise = reshape([0.0; 0.0], 2, 1) - - prob_no_noise = QuadraticStateSpaceProblem( - A_0_rbc, A_1_rbc, A_2_rbc, B_no_noise, u0, - (0, T); - C_0 = C_0_rbc, C_1 = C_1_rbc, C_2 = C_2_rbc - ) - - sol_no_noise = solve(prob_no_noise) - - prob_obs_noise = QuadraticStateSpaceProblem( - A_0_rbc, A_1_rbc, A_2_rbc, B_no_noise, u0, - (0, T); - C_0 = C_0_rbc, C_1 = C_1_rbc, C_2 = C_2_rbc, - observables_noise = D_2_rbc - ) - @inferred QuadraticStateSpaceProblem( - A_0_rbc, A_1_rbc, A_2_rbc, B_no_noise, u0, (0, T); - C_0 = C_0_rbc, C_1 = C_1_rbc, C_2 = C_2_rbc, - observables_noise = D_2_rbc - ) - sol_obs_noise = solve(prob_obs_noise) - @inferred solve(prob_obs_noise) - - # check that if the variance of the noise is tiny it is identical - sol_tiny_obs_noise = solve( - QuadraticStateSpaceProblem( - A_0_rbc, A_1_rbc, A_2_rbc, - B_no_noise, u0, - (0, T); - C_0 = C_0_rbc, C_1 = C_1_rbc, - C_2 = C_2_rbc, - observables_noise = [ - 1.0e-16, - 1.0e-16, - ] - ) - ) - @test maximum(maximum.(sol_tiny_obs_noise.z - sol_no_noise.z)) < 1.0e-7 # still some noise - @test maximum(maximum.(sol_tiny_obs_noise.z - sol_no_noise.z)) > 0.0 # but not zero -end diff --git a/test/runtests.jl b/test/runtests.jl index eb894c4..c509a92 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,28 +5,38 @@ using Distributions using LinearAlgebra using Random -const GROUP = get(ENV, "GROUP", "All") - function activate_jet_env() Pkg.activate("jet") Pkg.develop(PackageSpec(path = dirname(@__DIR__))) return Pkg.instantiate() end -if GROUP == "All" || GROUP == "Core" - # include("matrix_vector_of_vectors.jl") # may add later to support noise inputs as vector of vectors - include("qa.jl") - include("explicit_imports.jl") - include("kalman_likelihood.jl") - include("linear_likelihood.jl") - include("linear_gradients.jl") - include("linear_simulations.jl") - include("quadratic_likelihood.jl") - include("quadratic_simulations.jl") - include("sciml_interfaces.jl") +include("qa.jl") +include("explicit_imports.jl") +include("linear_direct_iteration.jl") +include("kalman.jl") +include("direct_iteration.jl") +include("quadratic_direct_iteration.jl") +include("static_arrays.jl") +include("cache_reuse.jl") +include("sciml_interfaces.jl") +include("sensitivity_interface.jl") +include("linear_direct_iteration_forwarddiff.jl") +include("kalman_forwarddiff.jl") +include("conditional_likelihood.jl") +include("conditional_likelihood_forwarddiff.jl") +include("save_everystep.jl") + +if get(ENV, "CI", "false") != "true" + include("gradient_comparison.jl") + include("linear_direct_iteration_enzyme.jl") + include("quadratic_direct_iteration_enzyme.jl") + include("kalman_enzyme.jl") + include("conditional_likelihood_enzyme.jl") end -if GROUP == "JET" + +if get(ENV, "GROUP", "") == "JET" activate_jet_env() include("jet/jet_tests.jl") end diff --git a/test/save_everystep.jl b/test/save_everystep.jl new file mode 100644 index 0000000..11dc5ce --- /dev/null +++ b/test/save_everystep.jl @@ -0,0 +1,419 @@ +# Tests for save_everystep=false: endpoints-only solve with correct logpdf. +# Verifies that sol.u[1]=initial, sol.u[2]=final, logpdf matches full solve. + +using DifferenceEquations, Distributions, LinearAlgebra, Test, Random, ForwardDiff +using DifferenceEquations: init, solve! +using StaticArrays + +# ============================================================================= +# Shared test data +# ============================================================================= + +const A_se = [0.8 0.1; -0.1 0.7] +const B_se = [0.1 0.0; 0.0 0.1] +const C_se = [1.0 0.0; 0.0 1.0] +const u0_se = zeros(2) +const T_se = 10 + +Random.seed!(42) +const noise_se = [randn(2) for _ in 1:T_se] +const y_se = [randn(2) for _ in 1:T_se] + +# ============================================================================= +# Primal simulation: DirectIteration +# ============================================================================= + +@testset "save_everystep=false — DI simulation with C and noise" begin + prob = LinearStateSpaceProblem(A_se, B_se, u0_se, (0, T_se); C = C_se, noise = noise_se) + sol_full = solve(prob) + sol_ep = solve(prob; save_everystep = false) + + @test length(sol_ep.u) == 2 + @test length(sol_ep.z) == 2 + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] + @test sol_ep.z[1] ≈ sol_full.z[1] + @test sol_ep.z[2] ≈ sol_full.z[end] + @test sol_ep.logpdf == 0.0 +end + +@testset "save_everystep=false — DI simulation C=nothing" begin + prob = LinearStateSpaceProblem(A_se, B_se, u0_se, (0, T_se); noise = noise_se) + sol_full = solve(prob) + sol_ep = solve(prob; save_everystep = false) + + @test length(sol_ep.u) == 2 + @test sol_ep.z === nothing + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] +end + +@testset "save_everystep=false — DI simulation B=nothing" begin + prob = LinearStateSpaceProblem(A_se, nothing, [1.0, 0.5], (0, T_se); C = C_se) + sol_full = solve(prob) + sol_ep = solve(prob; save_everystep = false) + + @test length(sol_ep.u) == 2 + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] + @test sol_ep.z[1] ≈ sol_full.z[1] + @test sol_ep.z[2] ≈ sol_full.z[end] +end + +@testset "save_everystep=false — DI with obs noise simulation" begin + prob = LinearStateSpaceProblem( + A_se, B_se, u0_se, (0, T_se); + C = C_se, noise = noise_se, observables_noise = Diagonal([0.01, 0.01]) + ) + sol_ep = solve(prob; save_everystep = false) + + @test length(sol_ep.u) == 2 + @test length(sol_ep.z) == 2 + # u endpoints are deterministic (noise is fixed), but z has random obs noise + # so we can only check u matches, not z (different random draws for 2 vs T+1 elements) + sol_no_obs_noise = solve( + LinearStateSpaceProblem(A_se, B_se, u0_se, (0, T_se); C = C_se, noise = noise_se); + save_everystep = false + ) + @test sol_ep.u[1] ≈ sol_no_obs_noise.u[1] + @test sol_ep.u[2] ≈ sol_no_obs_noise.u[2] +end + +# ============================================================================= +# Primal simulation: Generic StateSpaceProblem +# ============================================================================= + +@testset "save_everystep=false — Generic DI simulation" begin + f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next + end + g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y + end + p = (; A = A_se, B = B_se, C = C_se) + + prob = StateSpaceProblem( + f!!, g!!, u0_se, (0, T_se), p; + n_shocks = 2, n_obs = 2, noise = noise_se + ) + sol_full = solve(prob) + sol_ep = solve(prob; save_everystep = false) + + @test length(sol_ep.u) == 2 + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] + @test sol_ep.z[1] ≈ sol_full.z[1] + @test sol_ep.z[2] ≈ sol_full.z[end] +end + +# ============================================================================= +# Likelihood: DirectIteration +# ============================================================================= + +@testset "save_everystep=false — DI likelihood" begin + prob = LinearStateSpaceProblem( + A_se, B_se, u0_se, (0, T_se); C = C_se, + observables_noise = Diagonal([0.01, 0.01]), + observables = y_se, noise = noise_se + ) + sol_full = solve(prob) + sol_ep = solve(prob; save_everystep = false) + + @test sol_ep.logpdf ≈ sol_full.logpdf + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] +end + +# ============================================================================= +# Likelihood: ConditionalLikelihood +# ============================================================================= + +@testset "save_everystep=false — CL AR(1)" begin + rho = 0.8; sigma_e = 0.5; T_cl = 20 + Random.seed!(111) + y_cl = [randn(1) for _ in 1:T_cl] + + prob = LinearStateSpaceProblem( + fill(rho, 1, 1), nothing, [0.0], (0, T_cl); + observables = y_cl, observables_noise = Diagonal([sigma_e^2]) + ) + sol_full = solve(prob, ConditionalLikelihood()) + sol_ep = solve(prob, ConditionalLikelihood(); save_everystep = false) + + @test sol_ep.logpdf ≈ sol_full.logpdf + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] +end + +@testset "save_everystep=false — CL VAR(1)" begin + T_cl = 15 + Random.seed!(222) + y_cl = [randn(2) for _ in 1:T_cl] + + prob = LinearStateSpaceProblem( + A_se, nothing, u0_se, (0, T_cl); + observables = y_cl, observables_noise = Diagonal([0.25, 0.25]) + ) + sol_full = solve(prob, ConditionalLikelihood()) + sol_ep = solve(prob, ConditionalLikelihood(); save_everystep = false) + + @test sol_ep.logpdf ≈ sol_full.logpdf +end + +@testset "save_everystep=false — CL generic nonlinear" begin + rho = 0.8; alpha = 0.05; sigma_e = 0.3; T_cl = 15 + Random.seed!(333) + y_cl = [randn(1) for _ in 1:T_cl] + + f!! = (x_next, x, w, p, t) -> begin + val = p.rho * x[1] + p.alpha * x[1]^2 + if ismutable(x_next) + x_next[1] = val + return x_next + else + return typeof(x)(val) + end + end + + prob = StateSpaceProblem( + f!!, nothing, [0.0], (0, T_cl), (; rho, alpha); + n_shocks = 0, n_obs = 0, + observables = y_cl, observables_noise = Diagonal([sigma_e^2]) + ) + sol_full = solve(prob, ConditionalLikelihood()) + sol_ep = solve(prob, ConditionalLikelihood(); save_everystep = false) + + @test sol_ep.logpdf ≈ sol_full.logpdf +end + +# ============================================================================= +# Likelihood: KalmanFilter +# ============================================================================= + +@testset "save_everystep=false — KF" begin + Random.seed!(444) + y_kf = [randn(2) for _ in 1:T_se] + + prob = LinearStateSpaceProblem( + A_se, B_se, u0_se, (0, T_se); C = C_se, + observables_noise = Diagonal([0.01, 0.01]), observables = y_kf, + u0_prior_mean = zeros(2), u0_prior_var = Matrix(1.0 * I(2)) + ) + sol_full = solve(prob) + sol_ep = solve(prob; save_everystep = false) + + @test length(sol_ep.u) == 2 + @test length(sol_ep.P) == 2 + @test length(sol_ep.z) == 2 + @test sol_ep.logpdf ≈ sol_full.logpdf + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] + @test sol_ep.P[1] ≈ sol_full.P[1] + @test sol_ep.P[2] ≈ sol_full.P[end] + @test sol_ep.z[1] ≈ sol_full.z[1] + @test sol_ep.z[2] ≈ sol_full.z[end] +end + +# ============================================================================= +# Quadratic models +# ============================================================================= + +@testset "save_everystep=false — Unpruned quadratic simulation" begin + A_0 = [0.0, 0.0] + A_1 = A_se + A_2 = zeros(2, 2, 2) + A_2[1, 1, 1] = 0.01 + + prob = QuadraticStateSpaceProblem( + A_0, A_1, A_2, B_se, u0_se, (0, T_se); + C_0 = [0.0, 0.0], C_1 = C_se, C_2 = zeros(2, 2, 2), noise = noise_se + ) + sol_full = solve(prob) + sol_ep = solve(prob; save_everystep = false) + + @test length(sol_ep.u) == 2 + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] + @test sol_ep.z[1] ≈ sol_full.z[1] + @test sol_ep.z[2] ≈ sol_full.z[end] +end + +@testset "save_everystep=false — Pruned quadratic simulation" begin + A_0 = [0.0, 0.0] + A_1 = A_se + A_2 = zeros(2, 2, 2) + A_2[1, 1, 1] = 0.01 + + prob = PrunedQuadraticStateSpaceProblem( + A_0, A_1, A_2, B_se, u0_se, (0, T_se); + C_0 = [0.0, 0.0], C_1 = C_se, C_2 = zeros(2, 2, 2), noise = noise_se + ) + sol_full = solve(prob) + sol_ep = solve(prob; save_everystep = false) + + @test length(sol_ep.u) == 2 + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] + @test sol_ep.z[1] ≈ sol_full.z[1] + @test sol_ep.z[2] ≈ sol_full.z[end] +end + +# ============================================================================= +# StaticArrays +# ============================================================================= + +@testset "save_everystep=false — StaticArrays DI simulation" begin + A_s = SMatrix{2, 2}(A_se) + B_s = SMatrix{2, 2}(B_se) + C_s = SMatrix{2, 2}(C_se) + u0_s = SVector{2}(u0_se) + noise_s = [SVector{2}(n) for n in noise_se] + + prob = LinearStateSpaceProblem(A_s, B_s, u0_s, (0, T_se); C = C_s, noise = noise_s) + sol_full = solve(prob) + sol_ep = solve(prob; save_everystep = false) + + @test length(sol_ep.u) == 2 + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] +end + +@testset "save_everystep=false — StaticArrays CL" begin + T_cl = 10 + y_s = [SVector{2}(randn(2)) for _ in 1:T_cl] + + prob = LinearStateSpaceProblem( + SMatrix{2, 2}(A_se), nothing, SVector{2}(u0_se), (0, T_cl); + observables = y_s, observables_noise = Diagonal(SVector{2}(0.25, 0.25)) + ) + sol_full = solve(prob, ConditionalLikelihood()) + sol_ep = solve(prob, ConditionalLikelihood(); save_everystep = false) + + @test sol_ep.logpdf ≈ sol_full.logpdf +end + +@testset "save_everystep=false — StaticArrays KF" begin + T_kf = 10 + Random.seed!(555) + y_s = [SVector{2}(randn(2)) for _ in 1:T_kf] + + prob = LinearStateSpaceProblem( + SMatrix{2, 2}(A_se), SMatrix{2, 2}(B_se), SVector{2}(u0_se), (0, T_kf); + C = SMatrix{2, 2}(C_se), + observables_noise = Diagonal(SVector{2}(0.01, 0.01)), observables = y_s, + u0_prior_mean = SVector{2}(0.0, 0.0), + u0_prior_var = SMatrix{2, 2}(1.0, 0.0, 0.0, 1.0) + ) + sol_full = solve(prob) + sol_ep = solve(prob; save_everystep = false) + + @test sol_ep.logpdf ≈ sol_full.logpdf + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] +end + +# ============================================================================= +# Workspace reuse +# ============================================================================= + +@testset "save_everystep=false — workspace init/solve! reuse" begin + prob = LinearStateSpaceProblem( + A_se, nothing, u0_se, (0, T_se); + observables = y_se, observables_noise = Diagonal([0.25, 0.25]) + ) + ws = init(prob, ConditionalLikelihood(); save_everystep = false) + sol1 = solve!(ws) + sol2 = solve!(ws) + @test sol1.logpdf ≈ sol2.logpdf + @test sol1.u ≈ sol2.u + @test length(sol1.u) == 2 +end + +# ============================================================================= +# Edge cases +# ============================================================================= + +@testset "save_everystep=false — edge case T=2 (1 step)" begin + y1 = [randn(2)] + prob = LinearStateSpaceProblem( + A_se, nothing, u0_se, (0, 1); + observables = y1, observables_noise = Diagonal([0.25, 0.25]) + ) + sol_full = solve(prob, ConditionalLikelihood()) + sol_ep = solve(prob, ConditionalLikelihood(); save_everystep = false) + @test sol_ep.logpdf ≈ sol_full.logpdf + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] +end + +@testset "save_everystep=false — edge case T=3 (2 steps)" begin + y2 = [randn(2) for _ in 1:2] + prob = LinearStateSpaceProblem( + A_se, nothing, u0_se, (0, 2); + observables = y2, observables_noise = Diagonal([0.25, 0.25]) + ) + sol_full = solve(prob, ConditionalLikelihood()) + sol_ep = solve(prob, ConditionalLikelihood(); save_everystep = false) + @test sol_ep.logpdf ≈ sol_full.logpdf + @test sol_ep.u[1] ≈ sol_full.u[1] + @test sol_ep.u[2] ≈ sol_full.u[end] +end + +# ============================================================================= +# ForwardDiff gradients match +# ============================================================================= + +include("forwarddiff_test_utils.jl") + +@testset "save_everystep=false — ForwardDiff CL gradient matches" begin + function cl_fd(A_vec, y, se) + T_el = eltype(A_vec) + A = reshape(A_vec, 2, 2) + prob = LinearStateSpaceProblem( + A, nothing, zeros(T_el, 2), (0, length(y)); + observables = y, observables_noise = Diagonal([T_el(0.25), T_el(0.25)]) + ) + return solve(prob, ConditionalLikelihood(); save_everystep = se).logpdf + end + + Random.seed!(777) + y_fd = [randn(2) for _ in 1:10] + x0 = vec(copy(A_se)) + + g_true = ForwardDiff.gradient(a -> cl_fd(a, y_fd, true), x0) + g_false = ForwardDiff.gradient(a -> cl_fd(a, y_fd, false), x0) + @test g_true ≈ g_false +end + +@testset "save_everystep=false — ForwardDiff KF gradient matches" begin + function kf_fd(A_vec, B, C, mu0, Sigma0, R, y, se) + T_el = eltype(A_vec) + A = reshape(A_vec, 2, 2) + prob = LinearStateSpaceProblem( + A, promote_array(T_el, B), zeros(T_el, 2), (0, length(y)); + C = promote_array(T_el, C), + observables_noise = promote_array(T_el, R), observables = y, + u0_prior_mean = promote_array(T_el, mu0), + u0_prior_var = promote_array(T_el, Sigma0) + ) + return solve(prob, KalmanFilter(); save_everystep = se).logpdf + end + + Random.seed!(888) + y_fd = [randn(2) for _ in 1:10] + mu0 = zeros(2) + Sigma0 = Matrix(1.0 * I(2)) + R = Diagonal([0.01, 0.01]) + x0 = vec(copy(A_se)) + + g_true = ForwardDiff.gradient( + a -> kf_fd(a, B_se, C_se, mu0, Sigma0, R, y_fd, true), x0 + ) + g_false = ForwardDiff.gradient( + a -> kf_fd(a, B_se, C_se, mu0, Sigma0, R, y_fd, false), x0 + ) + @test g_true ≈ g_false +end diff --git a/test/sciml_interfaces.jl b/test/sciml_interfaces.jl index 4ccf73c..390e6fb 100644 --- a/test/sciml_interfaces.jl +++ b/test/sciml_interfaces.jl @@ -1,10 +1,8 @@ -using ChainRulesTestUtils, DifferenceEquations, Distributions, LinearAlgebra, Test, Zygote -using DelimitedFiles -using DiffEqBase -using FiniteDiff: finite_difference_gradient -using Plots, DataFrames +using DifferenceEquations, Distributions, LinearAlgebra, Test +using DelimitedFiles, DiffEqBase, Plots, DataFrames + +# --- RBC model data (shared by both problem types) --- -# Matrices from RBC A_rbc = [ 0.9568351489231076 6.209371005755285; 3.0153731819288737e-18 0.20000000000000007 @@ -14,41 +12,36 @@ C_rbc = [0.09579643002426148 0.6746869652592109; 1.0 0.0] D_rbc = abs2.([0.1, 0.1]) u0_rbc = zeros(2) -observables_rbc = readdlm( - joinpath( - pkgdir(DifferenceEquations), - "test/data/RBC_observables.csv" - ), - ',' +observables_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_observables.csv"), ',' +)' |> collect +noise_rbc_matrix = readdlm( + joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), ',' )' |> collect -noise_rbc = readdlm( - joinpath(pkgdir(DifferenceEquations), "test/data/RBC_noise.csv"), - ',' -)' |> - collect -# Data and Noise T = 5 -observables_rbc = observables_rbc[:, 1:T] -noise_rbc = noise_rbc[:, 1:T] +observables_rbc = [observables_rbc_matrix[:, t] for t in 1:T] +noise_rbc = [noise_rbc_matrix[:, t] for t in 1:T] -@testset "Plotting given noise" begin +# --- LinearStateSpaceProblem SciML interfaces --- + +@testset "Plotting given noise (Linear)" begin prob = LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); C = C_rbc, - observables_noise = D_rbc, noise = noise_rbc, + observables_noise = Diagonal(D_rbc), noise = noise_rbc, observables = observables_rbc, syms = (:a, :b) ) sol = solve(prob) plot(sol) end -@testset "Ensemble simulation and plotting given noise" begin +@testset "Ensemble simulation and plotting given noise (Linear)" begin # random initial conditions via the u0 prob = LinearStateSpaceProblem( A_rbc, B_rbc, MvNormal(u0_rbc, diagm(ones(length(u0_rbc)))), - (0, size(observables_rbc, 2)); C = C_rbc, - observables_noise = D_rbc, noise = noise_rbc, + (0, length(observables_rbc)); C = C_rbc, + observables_noise = Diagonal(D_rbc), noise = noise_rbc, observables = observables_rbc, syms = (:a, :b) ) sol2 = solve( @@ -60,12 +53,12 @@ end plot(summ) end -@testset "Dataframes" begin +@testset "Dataframes (Linear)" begin prob = LinearStateSpaceProblem( A_rbc, B_rbc, MvNormal(u0_rbc, diagm(ones(length(u0_rbc)))), - (0, size(observables_rbc, 2)); C = C_rbc, - observables_noise = D_rbc, noise = noise_rbc, + (0, length(observables_rbc)); C = C_rbc, + observables_noise = Diagonal(D_rbc), noise = noise_rbc, observables = observables_rbc, syms = (:a, :b) ) sol = solve(prob) @@ -74,23 +67,259 @@ end @test size(df) == (6, 3) end -@testset "Plotting simulating noise" begin +@testset "Symbolic indexing — state and obs (Linear)" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); + C = C_rbc, + observables_noise = Diagonal(D_rbc), noise = noise_rbc, + observables = observables_rbc, + syms = (:capital, :productivity), + obs_syms = (:output, :consumption) + ) + sol = solve(prob) + + # State indexing + @test sol[:capital] ≈ [sol.u[t][1] for t in eachindex(sol.u)] + @test sol[:productivity] ≈ [sol.u[t][2] for t in eachindex(sol.u)] + + # Observation indexing + @test sol[:output] ≈ [sol.z[t][1] for t in eachindex(sol.z)] + @test sol[:consumption] ≈ [sol.z[t][2] for t in eachindex(sol.z)] + + # Unknown symbol errors + @test_throws Exception sol[:nonexistent] + + # Direct u access works + @test length(sol.u) == length(observables_rbc) + 1 + + # DataFrame still works + df = DataFrame(sol) + @test :capital in propertynames(df) +end + +@testset "No syms — backward compat (Linear)" begin + prob = LinearStateSpaceProblem( + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)) + ) + sol = solve(prob) + @test length(sol.u) == length(observables_rbc) + 1 +end + +@testset "Plotting simulating noise (Linear)" begin prob = LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); C = C_rbc, - observables_noise = D_rbc, observables = observables_rbc, + observables_noise = Diagonal(D_rbc), observables = observables_rbc, syms = (:a, :b) ) sol = solve(prob) plot(sol) end -@testset "Ensemble simulation and plotting, simulating noise" begin +@testset "Ensemble simulation and plotting, simulating noise (Linear)" begin # fixed initial condition, random noise prob = LinearStateSpaceProblem( - A_rbc, B_rbc, u0_rbc, (0, size(observables_rbc, 2)); + A_rbc, B_rbc, u0_rbc, (0, length(observables_rbc)); C = C_rbc, - observables_noise = D_rbc, observables = observables_rbc, + observables_noise = Diagonal(D_rbc), observables = observables_rbc, + syms = (:a, :b) + ) + sol2 = solve( + EnsembleProblem(prob), DirectIteration(), EnsembleThreads(); + trajectories = 10 + ) + plot(sol2) + summ = EnsembleSummary(sol2) + plot(summ) +end + +# --- StateSpaceProblem callbacks + data --- + +linear_f!! = (x_next, x, w, p, t) -> begin + mul!(x_next, p.A, x) + mul!(x_next, p.B, w, 1.0, 1.0) + return x_next +end +linear_g!! = (y, x, p, t) -> begin + mul!(y, p.C, x) + return y +end +p_rbc = (; A = A_rbc, B = B_rbc, C = C_rbc) + +# --- StateSpaceProblem SciML interfaces --- + +@testset "remake with u0 and p (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p_rbc; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), noise = noise_rbc, observables = observables_rbc + ) + + # remake with new u0 + new_u0 = [0.1, 0.2] + prob2 = remake(prob; u0 = new_u0) + @test prob2.u0 == new_u0 + @test prob2.p === p_rbc + sol2 = solve(prob2) + @test length(sol2.u) == T + 1 + + # remake with new p + new_p = (; A = A_rbc * 0.99, B = B_rbc, C = C_rbc) + prob3 = remake(prob; p = new_p) + @test prob3.p === new_p + @test prob3.u0 == u0_rbc + sol3 = solve(prob3) + @test length(sol3.u) == T + 1 +end + +@testset "Plotting given noise (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p_rbc; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), noise = noise_rbc, + observables = observables_rbc, syms = (:a, :b) + ) + sol = solve(prob) + plot(sol) +end + +@testset "Ensemble simulation and plotting given noise (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, + MvNormal(u0_rbc, diagm(ones(length(u0_rbc)))), + (0, T), p_rbc; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), noise = noise_rbc, + observables = observables_rbc, syms = (:a, :b) + ) + sol2 = solve( + EnsembleProblem(prob), DirectIteration(), EnsembleThreads(); + trajectories = 10 + ) + plot(sol2) + summ = EnsembleSummary(sol2) + plot(summ) +end + +@testset "Dataframes (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, + MvNormal(u0_rbc, diagm(ones(length(u0_rbc)))), + (0, T), p_rbc; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), noise = noise_rbc, + observables = observables_rbc, syms = (:a, :b) + ) + sol = solve(prob) + df = DataFrame(sol) + @test propertynames(df) == [:timestamp, :a, :b] + @test size(df) == (T + 1, 3) +end + +@testset "Symbolic indexing — state and obs (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p_rbc; + n_shocks = 1, n_obs = 2, + syms = (:capital, :productivity), + obs_syms = (:output, :consumption), + observables_noise = Diagonal(D_rbc), noise = noise_rbc, observables = observables_rbc + ) + sol = solve(prob) + + # State indexing + @test sol[:capital] ≈ [sol.u[t][1] for t in eachindex(sol.u)] + @test sol[:productivity] ≈ [sol.u[t][2] for t in eachindex(sol.u)] + + # Observation indexing + @test sol[:output] ≈ [sol.z[t][1] for t in eachindex(sol.z)] + @test sol[:consumption] ≈ [sol.z[t][2] for t in eachindex(sol.z)] + + # Unknown symbol errors + @test_throws Exception sol[:nonexistent] + + # Direct u access works + @test length(sol.u) == T + 1 + + # DataFrame still works + df = DataFrame(sol) + @test :capital in propertynames(df) +end + +@testset "Symbolic indexing — syms only, no obs_syms (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p_rbc; + n_shocks = 1, n_obs = 2, + syms = (:capital, :productivity), + observables_noise = Diagonal(D_rbc), noise = noise_rbc, observables = observables_rbc + ) + sol = solve(prob) + @test sol[:capital] ≈ [sol.u[t][1] for t in eachindex(sol.u)] + @test_throws ArgumentError sol[:output] # no obs_syms defined +end + +@testset "Symbolic indexing — obs_syms only, no syms (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p_rbc; + n_shocks = 1, n_obs = 2, + obs_syms = (:output, :consumption), + observables_noise = Diagonal(D_rbc), noise = noise_rbc, observables = observables_rbc + ) + sol = solve(prob) + @test sol[:output] ≈ [sol.z[t][1] for t in eachindex(sol.z)] + @test_throws ArgumentError sol[:capital] # no syms defined +end + +@testset "Symbolic indexing — obs_syms but no observations in solution (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, nothing, u0_rbc, (0, T), p_rbc; + n_shocks = 1, n_obs = 0, + obs_syms = (:output, :consumption), + noise = noise_rbc + ) + sol = solve(prob) + @test sol.z === nothing + @test_throws Exception sol[:output] # obs_syms defined but z is nothing +end + +@testset "Symbolic indexing survives remake (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p_rbc; + n_shocks = 1, n_obs = 2, + syms = (:capital, :productivity), + obs_syms = (:output, :consumption), + observables_noise = Diagonal(D_rbc), noise = noise_rbc, observables = observables_rbc + ) + prob2 = remake(prob; u0 = [0.1, 0.2]) + sol2 = solve(prob2) + @test sol2[:capital] ≈ [sol2.u[t][1] for t in eachindex(sol2.u)] + @test sol2[:output] ≈ [sol2.z[t][1] for t in eachindex(sol2.z)] +end + +@testset "No syms — backward compat (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p_rbc; + n_shocks = 1, n_obs = 2 + ) + sol = solve(prob) + @test length(sol.u) == T + 1 +end + +@testset "Plotting simulating noise (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p_rbc; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), observables = observables_rbc, + syms = (:a, :b) + ) + sol = solve(prob) + plot(sol) +end + +@testset "Ensemble simulation and plotting, simulating noise (Generic)" begin + prob = StateSpaceProblem( + linear_f!!, linear_g!!, u0_rbc, (0, T), p_rbc; + n_shocks = 1, n_obs = 2, + observables_noise = Diagonal(D_rbc), observables = observables_rbc, syms = (:a, :b) ) sol2 = solve( diff --git a/test/sensitivity_interface.jl b/test/sensitivity_interface.jl new file mode 100644 index 0000000..2d56677 --- /dev/null +++ b/test/sensitivity_interface.jl @@ -0,0 +1,104 @@ +# Minimal sensitivity interface MWE +# Tests Enzyme AD through SciML-like struct construction + solve! pattern. +# Serves as reproducible MWE for SciML/Enzyme developers. + +using LinearAlgebra, Test, Enzyme, EnzymeTestUtils + +# ============================================================================= +# Minimal problem type (immutable — Enzyme handles struct construction fine +# when all args are Duplicated) +# ============================================================================= + +struct MinimalProblem{AT, UT} + A::AT + u0::UT +end + +struct MinimalCache{UT} + u::Vector{UT} +end + +function alloc_minimal_cache(u0, T) + return MinimalCache([similar(u0) for _ in 1:T]) +end + +function zero_minimal_cache!(cache) + for t in eachindex(cache.u) + cache.u[t] .= 0 + end + return cache +end + +# ============================================================================= +# Solve: u[t+1] = A * u[t], in-place mutation of cache +# ============================================================================= + +function minimal_solve!(prob::MinimalProblem, cache::MinimalCache) + cache.u[1] .= prob.u0 + for t in 1:(length(cache.u) - 1) + mul!(cache.u[t + 1], prob.A, cache.u[t]) + end + return nothing +end + +# ============================================================================= +# Wrapper functions for Enzyme AD +# ============================================================================= + +# Forward: constructs problem, solves, returns mutated cache (not nothing). +# Rule: a function that mutates an argument must return that argument. +function minimal_solve_wrapper!(A, u0, cache) + prob = MinimalProblem(A, u0) + zero_minimal_cache!(cache) + minimal_solve!(prob, cache) + return cache +end + +# Scalar: constructs problem, solves, returns scalar for reverse mode. +function minimal_loss(A, u0, cache)::Float64 + prob = MinimalProblem(A, u0) + zero_minimal_cache!(cache) + minimal_solve!(prob, cache) + return sum(cache.u[end]) +end + +# ============================================================================= +# Tests +# ============================================================================= + +@testset "Minimal sensitivity interface - sanity" begin + A = [0.8 0.1; -0.1 0.7] + u0 = [1.0, 0.5] + cache = alloc_minimal_cache(u0, 4) + + minimal_solve_wrapper!(A, u0, cache) + @test cache.u[1] ≈ u0 + @test cache.u[4] ≈ A^3 * u0 + + loglik = minimal_loss(A, u0, cache) + @test loglik ≈ sum(A^3 * u0) +end + +@testset "Minimal sensitivity - forward (in-place, validates cache tangents)" begin + A = [0.8 0.1; -0.1 0.7] + u0 = [1.0, 0.5] + + test_forward( + minimal_solve_wrapper!, Const, + (copy(A), Duplicated), + (copy(u0), Duplicated), + (alloc_minimal_cache(u0, 4), Duplicated) + ) +end + +@testset "Minimal sensitivity - reverse (scalar loglik)" begin + A = [0.8 0.1; -0.1 0.7] + u0 = [1.0, 0.5] + + test_reverse( + minimal_loss, Active, + (copy(A), Duplicated), + (copy(u0), Duplicated), + (alloc_minimal_cache(u0, 4), Duplicated) + ) +end diff --git a/test/static_arrays.jl b/test/static_arrays.jl new file mode 100644 index 0000000..a69391b --- /dev/null +++ b/test/static_arrays.jl @@ -0,0 +1,421 @@ +using DifferenceEquations, LinearAlgebra, Random, Test +using StaticArrays +using DifferenceEquations: mul!!, muladd!! + +# --- LinearStateSpaceProblem --- + +@testset "StaticArrays linear DirectIteration" begin + A = @SMatrix [0.9 0.1; 0.0 0.8] + B = @SMatrix [0.0; 0.1;;] # 2×1 SMatrix + C = @SMatrix [1.0 0.0; 0.0 1.0] + u0 = @SVector [0.5, 0.3] + + # Create noise as vector of SVector + noise = [SVector{1, Float64}(randn()) for _ in 1:9] + + prob = LinearStateSpaceProblem(A, B, u0, (0, 9); C, noise) + + # Compare SVector result to Vector result + A_v = Matrix(A) + B_v = Matrix(B) + C_v = Matrix(C) + u0_v = Vector(u0) + noise_v = [Vector(n) for n in noise] + + prob_v = LinearStateSpaceProblem(A_v, B_v, u0_v, (0, 9); C = C_v, noise = noise_v) + + sol_s = solve(prob) + sol_v = solve(prob_v) + + # Results should match + for t in eachindex(sol_s.u) + @test Vector(sol_s.u[t]) ≈ sol_v.u[t] + end + for t in eachindex(sol_s.z) + @test Vector(sol_s.z[t]) ≈ sol_v.z[t] + end +end + +@testset "StaticArrays linear no noise" begin + A = @SMatrix [0.9 0.1; 0.0 0.8] + C = @SMatrix [1.0 0.0; 0.0 1.0] + u0 = @SVector [1.0, 0.5] + + prob = LinearStateSpaceProblem(A, nothing, u0, (0, 5); C) + + A_v = Matrix(A) + C_v = Matrix(C) + u0_v = Vector(u0) + prob_v = LinearStateSpaceProblem(A_v, nothing, u0_v, (0, 5); C = C_v) + + sol_s = solve(prob) + sol_v = solve(prob_v) + + for t in eachindex(sol_s.u) + @test Vector(sol_s.u[t]) ≈ sol_v.u[t] + end +end + +@testset "StaticArrays no observation" begin + A = @SMatrix [0.9 0.1; 0.0 0.8] + B = @SMatrix [0.0; 0.1;;] + u0 = @SVector [1.0, 0.5] + + noise = [SVector{1, Float64}(randn()) for _ in 1:4] + + prob = LinearStateSpaceProblem(A, B, u0, (0, 4); C = nothing, noise) + sol = solve(prob) + + @test sol.z === nothing + @test length(sol.u) == 5 + + # Verify against manual computation + A_v = Matrix(A) + B_v = Matrix(B) + u0_v = Vector(u0) + noise_v = [Vector(n) for n in noise] + prob_v = LinearStateSpaceProblem(A_v, B_v, u0_v, (0, 4); C = nothing, noise = noise_v) + sol_v = solve(prob_v) + + for t in eachindex(sol.u) + @test Vector(sol.u[t]) ≈ sol_v.u[t] + end +end + +# --- Generic !! callbacks --- + +@inline function f_lss!!(x_p, x, w, p, t) + x_p = mul!!(x_p, p.A, x) + return muladd!!(x_p, p.B, w) +end + +@inline function g_lss!!(y, x, p, t) + return mul!!(y, p.C, x) +end + +@testset "Generic !! callbacks — mutable vs static consistency" begin + A_m = [0.9 0.1; 0.0 0.8] + B_m = reshape([0.0; 0.1], 2, 1) + C_m = [1.0 0.0; 0.0 1.0] + u0_m = [0.5, 0.3] + noise_vals = [randn(1) for _ in 1:9] + + # Mutable version + p_m = (; A = A_m, B = B_m, C = C_m) + prob_m = StateSpaceProblem( + f_lss!!, g_lss!!, u0_m, (0, 9), p_m; + n_shocks = 1, n_obs = 2, noise = noise_vals + ) + sol_m = solve(prob_m) + + # Static version — same callbacks, same data, just wrapped in SMatrix/SVector + A_s = SMatrix{2, 2}(A_m) + B_s = SMatrix{2, 1}(B_m) + C_s = SMatrix{2, 2}(C_m) + u0_s = SVector{2}(u0_m) + noise_s = [SVector{1}(n) for n in noise_vals] + + p_s = (; A = A_s, B = B_s, C = C_s) + prob_s = StateSpaceProblem( + f_lss!!, g_lss!!, u0_s, (0, 9), p_s; + n_shocks = 1, n_obs = 2, noise = noise_s + ) + sol_s = solve(prob_s) + + # Results must match exactly + for t in eachindex(sol_m.u) + @test Vector(sol_s.u[t]) ≈ sol_m.u[t] + end + for t in eachindex(sol_m.z) + @test Vector(sol_s.z[t]) ≈ sol_m.z[t] + end + + # Verify static types are preserved + @test eltype(sol_s.u) <: SVector{2, Float64} + @test eltype(sol_s.z) <: SVector{2, Float64} +end + +@testset "Generic !! callbacks — static matches LinearStateSpaceProblem" begin + A = @SMatrix [0.9 0.1; 0.0 0.8] + B = @SMatrix [0.0; 0.1;;] + C = @SMatrix [1.0 0.0; 0.0 1.0] + u0 = @SVector [0.5, 0.3] + noise = [SVector{1, Float64}(randn()) for _ in 1:9] + + prob_linear = LinearStateSpaceProblem(A, B, u0, (0, 9); C, noise) + sol_linear = solve(prob_linear) + + p = (; A, B, C) + prob_generic = StateSpaceProblem( + f_lss!!, g_lss!!, u0, (0, 9), p; + n_shocks = 1, n_obs = 2, noise = noise + ) + sol_generic = solve(prob_generic) + + for t in eachindex(sol_linear.u) + @test sol_linear.u[t] ≈ sol_generic.u[t] + end + for t in eachindex(sol_linear.z) + @test sol_linear.z[t] ≈ sol_generic.z[t] + end +end + +@testset "Generic !! callbacks — static no noise" begin + A = @SMatrix [0.9 0.1; 0.0 0.8] + C = @SMatrix [1.0 0.0; 0.0 1.0] + u0 = @SVector [1.0, 0.5] + + prob_linear = LinearStateSpaceProblem(A, nothing, u0, (0, 5); C) + sol_linear = solve(prob_linear) + + # f_lss!! handles w=nothing via muladd!!(x_p, B, nothing) → x_p + p = (; A, B = nothing, C) + prob_generic = StateSpaceProblem( + f_lss!!, g_lss!!, u0, (0, 5), p; + n_shocks = 0, n_obs = 2 + ) + sol_generic = solve(prob_generic) + + for t in eachindex(sol_linear.u) + @test sol_linear.u[t] ≈ sol_generic.u[t] + end + for t in eachindex(sol_linear.z) + @test sol_linear.z[t] ≈ sol_generic.z[t] + end +end + +# --- KalmanFilter with StaticArrays --- + +@testset "StaticArrays Kalman filter" begin + Random.seed!(789) + A_raw = randn(3, 3) + A_m = 0.5 * A_raw / maximum(abs.(eigvals(A_raw))) + B_m = 0.1 * randn(3, 2) + C_m = randn(2, 3) + R_m = 0.01 * I(2) |> Matrix + mu0_m = zeros(3) + Sig0_m = Matrix{Float64}(I, 3, 3) + + # Generate observations + x0 = randn(3) + noise = [randn(2) for _ in 1:10] + sim = solve(LinearStateSpaceProblem(A_m, B_m, x0, (0, 10); C = C_m, noise)) + y_m = [sim.z[t + 1] + 0.1 * randn(2) for t in 1:10] + + prob_m = LinearStateSpaceProblem( + A_m, B_m, zeros(3), (0, 10); C = C_m, + u0_prior_mean = mu0_m, u0_prior_var = Sig0_m, + observables_noise = R_m, observables = y_m + ) + sol_m = solve(prob_m, KalmanFilter()) + + # Static version + A_s = SMatrix{3, 3}(A_m) + B_s = SMatrix{3, 2}(B_m) + C_s = SMatrix{2, 3}(C_m) + R_s = SMatrix{2, 2}(R_m) + mu0_s = SVector{3}(mu0_m) + Sig0_s = SMatrix{3, 3}(Sig0_m) + y_s = [SVector{2}(y) for y in y_m] + + prob_s = LinearStateSpaceProblem( + A_s, B_s, SVector{3}(zeros(3)), (0, 10); C = C_s, + u0_prior_mean = mu0_s, u0_prior_var = Sig0_s, + observables_noise = R_s, observables = y_s + ) + sol_s = solve(prob_s, KalmanFilter()) + + # logpdf must match + @test sol_s.logpdf ≈ sol_m.logpdf + + # Filtered states and covariances must match + for t in eachindex(sol_s.u) + @test Vector(sol_s.u[t]) ≈ sol_m.u[t] + @test Matrix(sol_s.P[t]) ≈ sol_m.P[t] + end + for t in eachindex(sol_s.z) + @test Vector(sol_s.z[t]) ≈ sol_m.z[t] + end + + # Verify static types are preserved + @test eltype(sol_s.u) <: SVector{3, Float64} + @test eltype(sol_s.P) <: SMatrix{3, 3, Float64} + @test eltype(sol_s.z) <: SVector{2, Float64} +end + +# --- PrunedQuadraticStateSpaceProblem with StaticArrays --- + +@testset "StaticArrays pruned quadratic" begin + Random.seed!(42) + A_2 = 0.01 * randn(2, 2, 2) + C_2 = 0.01 * randn(2, 2, 2) + noise_vals = [randn(1) for _ in 1:10] + + # Mutable + A0_m = [0.001, -0.001] + A1_m = [0.3 0.1; -0.1 0.3] + B_m = reshape([0.1, 0.0], 2, 1) + C0_m = [0.001, -0.001] + C1_m = [1.0 0.0; 0.0 1.0] + u0_m = zeros(2) + + prob_m = PrunedQuadraticStateSpaceProblem( + A0_m, A1_m, A_2, B_m, u0_m, (0, 10); + C_0 = C0_m, C_1 = C1_m, C_2 = C_2, noise = noise_vals + ) + sol_m = solve(prob_m) + + # Static + A0_s = @SVector [0.001, -0.001] + A1_s = @SMatrix [0.3 0.1; -0.1 0.3] + B_s = @SMatrix [0.1; 0.0;;] + C0_s = @SVector [0.001, -0.001] + C1_s = @SMatrix [1.0 0.0; 0.0 1.0] + u0_s = @SVector zeros(2) + noise_s = [SVector{1}(n) for n in noise_vals] + + prob_s = PrunedQuadraticStateSpaceProblem( + A0_s, A1_s, A_2, B_s, u0_s, (0, 10); + C_0 = C0_s, C_1 = C1_s, C_2 = C_2, noise = noise_s + ) + sol_s = solve(prob_s) + + for t in eachindex(sol_s.u) + @test Vector(sol_s.u[t]) ≈ sol_m.u[t] + end + for t in eachindex(sol_s.z) + @test Vector(sol_s.z[t]) ≈ sol_m.z[t] + end + + @test eltype(sol_s.u) <: SVector{2, Float64} + @test eltype(sol_s.z) <: SVector{2, Float64} +end + +@testset "StaticArrays unpruned quadratic" begin + Random.seed!(42) + A_2 = 0.01 * randn(2, 2, 2) + C_2 = 0.01 * randn(2, 2, 2) + noise_vals = [randn(1) for _ in 1:10] + + # Mutable + A0_m = [0.001, -0.001] + A1_m = [0.3 0.1; -0.1 0.3] + B_m = reshape([0.1, 0.0], 2, 1) + C0_m = [0.001, -0.001] + C1_m = [1.0 0.0; 0.0 1.0] + u0_m = zeros(2) + + prob_m = QuadraticStateSpaceProblem( + A0_m, A1_m, A_2, B_m, u0_m, (0, 10); + C_0 = C0_m, C_1 = C1_m, C_2 = C_2, noise = noise_vals + ) + sol_m = solve(prob_m) + + # Static + A0_s = @SVector [0.001, -0.001] + A1_s = @SMatrix [0.3 0.1; -0.1 0.3] + B_s = @SMatrix [0.1; 0.0;;] + C0_s = @SVector [0.001, -0.001] + C1_s = @SMatrix [1.0 0.0; 0.0 1.0] + u0_s = @SVector zeros(2) + noise_s = [SVector{1}(n) for n in noise_vals] + + prob_s = QuadraticStateSpaceProblem( + A0_s, A1_s, A_2, B_s, u0_s, (0, 10); + C_0 = C0_s, C_1 = C1_s, C_2 = C_2, noise = noise_s + ) + sol_s = solve(prob_s) + + for t in eachindex(sol_s.u) + @test Vector(sol_s.u[t]) ≈ sol_m.u[t] + end + for t in eachindex(sol_s.z) + @test Vector(sol_s.z[t]) ≈ sol_m.z[t] + end + + @test eltype(sol_s.u) <: SVector{2, Float64} + @test eltype(sol_s.z) <: SVector{2, Float64} +end + +# --- solve!() vs solve() consistency for StaticArrays --- + +@testset "StaticArrays solve!() vs solve() consistency" begin + using DifferenceEquations: init, solve!, StateSpaceWorkspace + + @testset "linear DirectIteration" begin + A = @SMatrix [0.9 0.1; 0.0 0.8] + B = @SMatrix [0.0; 0.1;;] + C = @SMatrix [1.0 0.0; 0.0 1.0] + u0 = @SVector [0.5, 0.3] + noise = [SVector{1}(randn()) for _ in 1:9] + + prob = LinearStateSpaceProblem(A, B, u0, (0, 9); C, noise) + sol_alloc = solve(prob) + + ws = init(prob, DirectIteration()) + sol_inplace = solve!(ws) + + for t in eachindex(sol_alloc.u) + @test sol_alloc.u[t] ≈ sol_inplace.u[t] + end + for t in eachindex(sol_alloc.z) + @test sol_alloc.z[t] ≈ sol_inplace.z[t] + end + end + + @testset "Kalman filter" begin + Random.seed!(789) + A_raw = randn(3, 3) + A = SMatrix{3, 3}(0.5 * A_raw / maximum(abs.(eigvals(A_raw)))) + B = SMatrix{3, 2}(0.1 * randn(3, 2)) + C = SMatrix{2, 3}(randn(2, 3)) + R = SMatrix{2, 2}(0.01 * I(2)) + mu0 = SVector{3}(zeros(3)) + Sig0 = SMatrix{3, 3}(1.0 * I(3)) + + noise = [SVector{2}(randn(2)) for _ in 1:10] + sim = solve(LinearStateSpaceProblem(A, B, mu0, (0, 10); C, noise)) + y = [sim.z[t + 1] + SVector{2}(0.1 * randn(2)) for t in 1:10] + + prob = LinearStateSpaceProblem( + A, B, SVector{3}(zeros(3)), (0, 10); C, + u0_prior_mean = mu0, u0_prior_var = Sig0, + observables_noise = R, observables = y + ) + sol_alloc = solve(prob, KalmanFilter()) + + ws = init(prob, KalmanFilter()) + sol_inplace = solve!(ws) + + @test sol_alloc.logpdf ≈ sol_inplace.logpdf + for t in eachindex(sol_alloc.u) + @test sol_alloc.u[t] ≈ sol_inplace.u[t] + @test sol_alloc.P[t] ≈ sol_inplace.P[t] + end + end + + @testset "pruned quadratic" begin + Random.seed!(42) + A_2 = 0.01 * randn(2, 2, 2) + C_2 = 0.01 * randn(2, 2, 2) + noise = [SVector{1}(randn()) for _ in 1:10] + + prob = PrunedQuadraticStateSpaceProblem( + @SVector([0.001, -0.001]), @SMatrix([0.3 0.1; -0.1 0.3]), + A_2, @SMatrix([0.1; 0.0;;]), @SVector(zeros(2)), (0, 10); + C_0 = @SVector([0.001, -0.001]), C_1 = @SMatrix([1.0 0.0; 0.0 1.0]), + C_2 = C_2, noise = noise + ) + sol_alloc = solve(prob) + + ws = init(prob, DirectIteration()) + sol_inplace = solve!(ws) + + for t in eachindex(sol_alloc.u) + @test sol_alloc.u[t] ≈ sol_inplace.u[t] + end + for t in eachindex(sol_alloc.z) + @test sol_alloc.z[t] ≈ sol_inplace.z[t] + end + end +end diff --git a/test/utilities/kalman_simulations.jl b/test/utilities/kalman_simulations.jl index 6f9f9c4..b46bb44 100644 --- a/test/utilities/kalman_simulations.jl +++ b/test/utilities/kalman_simulations.jl @@ -1,7 +1,5 @@ -using ChainRulesTestUtils, DifferenceEquations, Distributions, LinearAlgebra, Test, Zygote +using DifferenceEquations, Distributions, LinearAlgebra using DelimitedFiles -using DiffEqBase -using FiniteDiff: finite_difference_gradient A_kalman = [ 0.0495388 0.0109918 0.0960529 0.0767147 0.0404643;