Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
0c791b0
refactor: comment out ChainRulesCore AD code, remove Zygote test deps
jlperla Mar 18, 2026
2d1c4be
feat: add GenericStateSpaceProblem, modernize SciML integration
jlperla Mar 19, 2026
985c847
feat: add !! pattern static tests, function barrier for type stability
jlperla Mar 19, 2026
016375d
refactor: benchmark solve! with pre-allocated caches only
jlperla Mar 19, 2026
c9c5609
feat: add cross-package benchmarks vs differentiable_economics
jlperla Mar 19, 2026
5bac54e
feat: add symbolic indexing for state and observation channels, upgra…
jlperla Mar 19, 2026
824228c
feat: add Enzyme AD tests and benchmarks for Kalman and DirectIteration
jlperla Mar 19, 2026
4c1eb15
refactor: simplify from code review
jlperla Mar 19, 2026
e28ede7
refactor: remove matrix format and Distributions dependency
jlperla Mar 19, 2026
feefa19
Debugging enzyme
jlperla Mar 20, 2026
b84a174
fix: add mul_aat!! workaround for Enzyme syrk adjoint bug
jlperla Mar 22, 2026
1cee4d9
chore: add PkgBenchmark to benchmark environment
jlperla Mar 22, 2026
2c7364e
fix: disable GC during benchmarks to avoid Enzyme reverse-mode segfault
jlperla Mar 22, 2026
9a3da65
refactor: use Julia shorthand keyword arguments where name matches value
jlperla Mar 22, 2026
1510442
refactor: inline loglik into solve(), remove standalone functions
jlperla Mar 23, 2026
ed81e52
feat: rewrite Enzyme tests through solve!() with correct forward/reve…
jlperla Mar 23, 2026
1edf6ca
feat: rewrite AD benchmarks through solve!() with all-Duplicated pattern
jlperla Mar 23, 2026
27cf782
chore: save new benchmark baseline through solve!() path
jlperla Mar 23, 2026
5928873
refactor: revert to immutable problem, use SciML remake() pattern
jlperla Mar 23, 2026
f06acd8
refactor: split workspace into solution output + scratch cache (SciML…
jlperla Mar 23, 2026
d307745
refactor: use @concrete workspace, always-Float64 logpdf, SciML solve…
jlperla Mar 23, 2026
a48403f
refactor: apply enzyme-jl skill rules to AD tests
jlperla Mar 23, 2026
3830e85
refactor: reorganize tests by algorithm, replace Zygote with Enzyme, …
jlperla Mar 24, 2026
353ee0e
test: add solve!() workspace tests across all primal algorithm files
jlperla Mar 24, 2026
6eb4aa8
feat: add edge-case benchmarks + quadratic Enzyme AD (standard formul…
jlperla Mar 24, 2026
302acce
feat: add static quadratic benchmark (bang-bang callbacks, Ref-based …
jlperla Mar 24, 2026
51c1ef6
refactor: fix Enzyme benchmark pattern — construct prob inside wrapper
jlperla Mar 24, 2026
030854e
feat: add QuadraticStateSpaceProblem and PrunedQuadraticStateSpacePro…
jlperla Mar 25, 2026
f71f481
docs: complete documentation revamp following SciML standards
jlperla Mar 25, 2026
add2022
feat: add ForwardDiff AD support with tests, benchmarks, and docs
jlperla Mar 25, 2026
63435b7
refactor: remove unnecessary copy() from Enzyme benchmarks, comment o…
jlperla Mar 25, 2026
521ed5c
fix: add GC teardown to Enzyme benchmarks to prevent OOM
jlperla Mar 27, 2026
57426b6
refactor: require observables_noise as AbstractMatrix, fix docs, add …
jlperla Mar 27, 2026
0719695
fix: replace @test_broken reverse tests with manual gradient checks
jlperla Mar 27, 2026
bbd11e0
chore: remove unused RecursiveArrayTools dep, delete TODO.md
jlperla Mar 27, 2026
2a3814c
style: add Runic pre-commit hook and format all Julia files
jlperla Mar 27, 2026
2fcc477
style: apply Runic formatting to test/ and benchmark/
jlperla Mar 27, 2026
3af5a7b
feat: add Enzyme forward/reverse Kalman benchmarks for StaticArrays
jlperla Mar 28, 2026
8c3903e
feat: full StaticArrays support with bang-bang fixes, tests, and docs
jlperla Mar 28, 2026
7440528
ci: remove Julia nightly from CI matrix
jlperla Mar 28, 2026
604522c
feat: add ConditionalLikelihood algorithm and save_everystep=false
jlperla Mar 30, 2026
98bd85b
refactor: migrate all Enzyme tests to test_forward/test_reverse
jlperla Mar 30, 2026
a0cb5dc
refactor: pass prob as Duplicated in Enzyme tests, observables get ze…
jlperla Mar 30, 2026
76f46d8
fix: use package solve for data generation in CL tutorial
jlperla Mar 30, 2026
863b6bf
fix: resolve all doc build errors
jlperla Mar 30, 2026
6c3bc1c
fix: move gradient_comparison.jl inside CI guard
jlperla Mar 30, 2026
7a0895f
fix: restore DEAlgorithm import in solve.jl after rebase reconciliation
ChrisRackauckas Jun 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .githooks/pre-commit
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env bash
# Pre-commit hook: format staged Julia files with Runic

JULIA_FILES=$(git diff --cached --name-only --diff-filter=ACM -- '*.jl')

if [ -z "$JULIA_FILES" ]; then
exit 0
fi

julia --project=@runic --startup-file=no -e 'using Runic; exit(Runic.main(ARGS))' -- --inplace $JULIA_FILES

# Re-stage the formatted files
echo "$JULIA_FILES" | xargs git add

exit 0
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ jobs:
version:
- '1'
- 'lts'
- 'pre'
uses: "SciML/.github/.github/workflows/tests.yml@v1"
with:
julia-version: "${{ matrix.version }}"
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ Manifest.toml
/.benchmarkci
/benchmark/*.json
LocalPreferences.toml
docs/build
docs/build
benchmark/results.json
23 changes: 10 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,26 @@ uuid = "e0ca9c66-1f9e-11ec-127a-1304ce62169c"
version = "1.1.0"
authors = ["various contributors"]

[workspace]
projects = ["test", "benchmark"]

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
ChainRulesCore = "1.24"
CommonSolve = "0.2"
DiffEqBase = "6.145, 7"
Distributions = "0.25"
ConcreteStructs = "0.2.3"
DiffEqBase = "6"
LinearAlgebra = "1"
PDMats = "0.11"
PrecompileTools = "1"
RecursiveArrayTools = "2.34, 4"
SciMLBase = "2, 3"
SciMLBase = "2"
StaticArrays = "1"
SymbolicIndexingInterface = "0.3"
UnPack = "1"
julia = "1.8"
julia = "1.10"
11 changes: 10 additions & 1 deletion benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
DifferenceEquations = "e0ca9c66-1f9e-11ec-127a-1304ce62169c"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[sources]
DifferenceEquations = {path = ".."}
39 changes: 29 additions & 10 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DifferenceEquations, BenchmarkTools
using Test, LinearAlgebra, Random
using DifferenceEquations, BenchmarkTools, Enzyme
using LinearAlgebra, Random, StaticArrays

# Check if MKL or not
julia_mkl = @static if VERSION < v"1.7"
Expand All @@ -13,14 +13,33 @@ if !julia_mkl
BLAS.set_num_threads(openblas_threads)
end

println("Running Testsuite with Threads.nthreads = $(Threads.nthreads()) MKL = $julia_mkl, and BLAS.num_threads = $(BLAS.get_num_threads()) \n")
println(
"Threads.nthreads = $(Threads.nthreads()), MKL = $julia_mkl, " *
"BLAS.num_threads = $(BLAS.get_num_threads())\n"
)

# Benchmark groups
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 15.0 # 10 seconds per benchmark by default.
BenchmarkTools.DEFAULT_PARAMETERS.evals = 3
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 5.0
BenchmarkTools.DEFAULT_PARAMETERS.evals = 1

const SUITE = BenchmarkGroup()
SUITE["linear"] = include(pkgdir(DifferenceEquations) * "/benchmark/linear.jl")
SUITE["quadratic"] = include(pkgdir(DifferenceEquations) * "/benchmark/quadratic.jl")
# Enzyme reverse-mode AD corrupts GC metadata under repeated invocation, causing segfaults.
# GC disabled globally to prevent GC from running during Enzyme AD.
# Between benchmark samples, Enzyme @benchmarkable calls use a `teardown` to briefly
# re-enable GC, collect, and disable again — safe because Enzyme is not running at that point.
# This prevents OOM from leaked memory accumulating across samples.
# Upstream: https://github.com/EnzymeAD/Enzyme.jl/issues/2355
GC.enable(false)

# results = run(SUITE; verbose = true)
const SUITE = BenchmarkGroup()
const _bdir = joinpath(pkgdir(DifferenceEquations), "benchmark")
SUITE["kalman"] = include(joinpath(_bdir, "enzyme_kalman.jl"))
SUITE["linear_likelihood"] = include(joinpath(_bdir, "enzyme_linear_likelihood.jl"))
SUITE["linear_simulation"] = include(joinpath(_bdir, "enzyme_linear_simulation.jl"))
SUITE["quadratic"] = include(joinpath(_bdir, "enzyme_quadratic.jl"))
SUITE["static_arrays"] = include(joinpath(_bdir, "static_arrays.jl"))
SUITE["ensemble"] = include(joinpath(_bdir, "ensemble.jl"))
SUITE["forwarddiff_kalman"] = include(joinpath(_bdir, "forwarddiff_kalman.jl"))
SUITE["forwarddiff_linear_likelihood"] = include(joinpath(_bdir, "forwarddiff_linear_likelihood.jl"))
SUITE["forwarddiff_linear_simulation"] = include(joinpath(_bdir, "forwarddiff_linear_simulation.jl"))
SUITE["conditional_likelihood"] = include(joinpath(_bdir, "enzyme_conditional_likelihood.jl"))
SUITE["forwarddiff_conditional_likelihood"] = include(joinpath(_bdir, "forwarddiff_conditional_likelihood.jl"))
SUITE["gradient_comparison"] = include(joinpath(_bdir, "gradient_comparison.jl"))
249 changes: 249 additions & 0 deletions benchmark/ensemble.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
# Manual ensemble loop with Enzyme AD
# NOT using EnsembleProblem — Enzyme cannot differentiate through DiffEqBase dispatch.
# Uses construct-inside + solve! in a tight loop over trajectories.
#
# Returns ENS_BENCH BenchmarkGroup

using Enzyme: make_zero, make_zero!
using DifferenceEquations: init, solve!, StateSpaceWorkspace, fill_zero!!

const ENS_BENCH = BenchmarkGroup()
ENS_BENCH["raw"] = BenchmarkGroup()
ENS_BENCH["forward"] = BenchmarkGroup()
ENS_BENCH["reverse"] = BenchmarkGroup()

# =============================================================================
# Problem sizes
# =============================================================================

const p_ens_small = (; N = 2, K = 1, M = 2, T = 10, N_traj = 20)
const p_ens_large = (; N = 5, K = 2, M = 3, T = 50, N_traj = 50)

# =============================================================================
# Problem setup
# =============================================================================

function make_ensemble_benchmark(; N, K, M, T, N_traj, seed = 42)
Random.seed!(seed)
A_raw = randn(N, N)
A = 0.5 * A_raw / maximum(abs.(eigvals(A_raw)))
B = 0.1 * randn(N, K)
C = randn(M, N)
u0 = zeros(N)

# Pre-generate noise for each trajectory
all_noise = [[randn(K) for _ in 1:T] for _ in 1:N_traj]

# Pre-allocate sol/cache for each trajectory (simulation only, no obs)
prob_template = LinearStateSpaceProblem(A, B, u0, (0, T); C, noise = all_noise[1])
all_sol = [deepcopy(init(prob_template, DirectIteration()).output) for _ in 1:N_traj]
all_cache = [deepcopy(init(prob_template, DirectIteration()).cache) for _ in 1:N_traj]

# Shadows for AD
dA = make_zero(A)
dB = make_zero(B)
dC = make_zero(C)
du0 = make_zero(u0)
dall_noise = [[make_zero(all_noise[1][1]) for _ in 1:T] for _ in 1:N_traj]
dall_sol = [make_zero(s) for s in all_sol]
dall_cache = [make_zero(c) for c in all_cache]

return (;
A, B, C, u0, all_noise, all_sol, all_cache,
dA, dB, dC, du0, dall_noise, dall_sol, dall_cache,
)
end

# =============================================================================
# Wrapper functions
# =============================================================================

function ensemble_raw!(A, B, C, u0, all_noise, all_sol, all_cache)
total = 0.0
for i in eachindex(all_noise)
prob = LinearStateSpaceProblem(A, B, u0, (0, length(all_noise[i])); C, noise = all_noise[i])
ws = StateSpaceWorkspace(prob, DirectIteration(), all_sol[i], all_cache[i])
solve!(ws)
total += sum(all_sol[i].u[end])
end
return total / length(all_noise)
end

function ensemble_forward_bench!(A, B, C, u0, all_noise, all_sol, all_cache)
# Same as raw — Enzyme differentiates through this
return ensemble_raw!(A, B, C, u0, all_noise, all_sol, all_cache)
end

function ensemble_scalar!(A, B, C, u0, all_noise, all_sol, all_cache)::Float64
return ensemble_raw!(A, B, C, u0, all_noise, all_sol, all_cache)
end

# =============================================================================
# Instantiate problems
# =============================================================================

const ens_s = make_ensemble_benchmark(; p_ens_small...)
const ens_l = make_ensemble_benchmark(; p_ens_large...)

# =============================================================================
# Raw benchmarks (primal solve through public API)
# =============================================================================

function raw_ens!(A, B, C, u0, all_noise, all_sol, all_cache)
return ensemble_raw!(A, B, C, u0, all_noise, all_sol, all_cache)
end

# Warmup
raw_ens!(
ens_s.A, ens_s.B, ens_s.C, ens_s.u0, ens_s.all_noise,
ens_s.all_sol, ens_s.all_cache
)
raw_ens!(
ens_l.A, ens_l.B, ens_l.C, ens_l.u0, ens_l.all_noise,
ens_l.all_sol, ens_l.all_cache
)

ENS_BENCH["raw"]["small"] = @benchmarkable raw_ens!(
$(ens_s.A), $(ens_s.B), $(ens_s.C), $(ens_s.u0), $(ens_s.all_noise),
$(ens_s.all_sol), $(ens_s.all_cache)
)
ENS_BENCH["raw"]["large"] = @benchmarkable raw_ens!(
$(ens_l.A), $(ens_l.B), $(ens_l.C), $(ens_l.u0), $(ens_l.all_noise),
$(ens_l.all_sol), $(ens_l.all_cache)
)

# =============================================================================
# Forward mode AD — perturb A[1,1], return computed arrays
# =============================================================================

function forward_ensemble_bench!(
A, B, C, u0, all_noise, all_sol, all_cache,
dA, dB, dC, du0, dall_noise, dall_sol, dall_cache
)
# Zero all shadows
dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC); du0 = fill_zero!!(du0)
@inbounds for i in eachindex(dall_noise)
for j in eachindex(dall_noise[i])
dall_noise[i][j] = fill_zero!!(dall_noise[i][j])
end
end
@inbounds for i in eachindex(dall_sol)
make_zero!(dall_sol[i])
end
@inbounds for i in eachindex(dall_cache)
make_zero!(dall_cache[i])
end
# Set perturbation direction
dA[1, 1] = 1.0

autodiff(
Forward, ensemble_forward_bench!,
Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC),
Duplicated(u0, du0), Duplicated(all_noise, dall_noise),
Duplicated(all_sol, dall_sol),
Duplicated(all_cache, dall_cache)
)
return nothing
end

# Warmup
forward_ensemble_bench!(
copy(ens_s.A), copy(ens_s.B), copy(ens_s.C), copy(ens_s.u0),
[[copy(n) for n in traj] for traj in ens_s.all_noise],
ens_s.all_sol, ens_s.all_cache,
ens_s.dA, ens_s.dB, ens_s.dC, ens_s.du0,
ens_s.dall_noise, ens_s.dall_sol, ens_s.dall_cache
)

ENS_BENCH["forward"]["small"] = @benchmarkable forward_ensemble_bench!(
$(copy(ens_s.A)), $(copy(ens_s.B)), $(copy(ens_s.C)), $(copy(ens_s.u0)),
$([[copy(n) for n in traj] for traj in ens_s.all_noise]),
$(ens_s.all_sol), $(ens_s.all_cache),
$(ens_s.dA), $(ens_s.dB), $(ens_s.dC), $(ens_s.du0),
$(ens_s.dall_noise), $(ens_s.dall_sol), $(ens_s.dall_cache)
)

# Warmup large
forward_ensemble_bench!(
copy(ens_l.A), copy(ens_l.B), copy(ens_l.C), copy(ens_l.u0),
[[copy(n) for n in traj] for traj in ens_l.all_noise],
ens_l.all_sol, ens_l.all_cache,
ens_l.dA, ens_l.dB, ens_l.dC, ens_l.du0,
ens_l.dall_noise, ens_l.dall_sol, ens_l.dall_cache
)

ENS_BENCH["forward"]["large"] = @benchmarkable forward_ensemble_bench!(
$(copy(ens_l.A)), $(copy(ens_l.B)), $(copy(ens_l.C)), $(copy(ens_l.u0)),
$([[copy(n) for n in traj] for traj in ens_l.all_noise]),
$(ens_l.all_sol), $(ens_l.all_cache),
$(ens_l.dA), $(ens_l.dB), $(ens_l.dC), $(ens_l.du0),
$(ens_l.dall_noise), $(ens_l.dall_sol), $(ens_l.dall_cache)
)

# =============================================================================
# Reverse mode AD — all Duplicated, scalar return with Active
# =============================================================================

function reverse_ensemble_bench!(
A, B, C, u0, all_noise, all_sol, all_cache,
dA, dB, dC, du0, dall_noise, dall_sol, dall_cache
)
# Zero all shadows
dA = fill_zero!!(dA); dB = fill_zero!!(dB); dC = fill_zero!!(dC); du0 = fill_zero!!(du0)
@inbounds for i in eachindex(dall_noise)
for j in eachindex(dall_noise[i])
dall_noise[i][j] = fill_zero!!(dall_noise[i][j])
end
end
@inbounds for i in eachindex(dall_sol)
make_zero!(dall_sol[i])
end
@inbounds for i in eachindex(dall_cache)
make_zero!(dall_cache[i])
end

autodiff(
Reverse, ensemble_scalar!, Active,
Duplicated(A, dA), Duplicated(B, dB), Duplicated(C, dC),
Duplicated(u0, du0), Duplicated(all_noise, dall_noise),
Duplicated(all_sol, dall_sol),
Duplicated(all_cache, dall_cache)
)
return nothing
end

# Warmup
reverse_ensemble_bench!(
copy(ens_s.A), copy(ens_s.B), copy(ens_s.C), copy(ens_s.u0),
[[copy(n) for n in traj] for traj in ens_s.all_noise],
ens_s.all_sol, ens_s.all_cache,
ens_s.dA, ens_s.dB, ens_s.dC, ens_s.du0,
ens_s.dall_noise, ens_s.dall_sol, ens_s.dall_cache
)

ENS_BENCH["reverse"]["small"] = @benchmarkable reverse_ensemble_bench!(
$(copy(ens_s.A)), $(copy(ens_s.B)), $(copy(ens_s.C)), $(copy(ens_s.u0)),
$([[copy(n) for n in traj] for traj in ens_s.all_noise]),
$(ens_s.all_sol), $(ens_s.all_cache),
$(ens_s.dA), $(ens_s.dB), $(ens_s.dC), $(ens_s.du0),
$(ens_s.dall_noise), $(ens_s.dall_sol), $(ens_s.dall_cache)
)

# Warmup large
reverse_ensemble_bench!(
copy(ens_l.A), copy(ens_l.B), copy(ens_l.C), copy(ens_l.u0),
[[copy(n) for n in traj] for traj in ens_l.all_noise],
ens_l.all_sol, ens_l.all_cache,
ens_l.dA, ens_l.dB, ens_l.dC, ens_l.du0,
ens_l.dall_noise, ens_l.dall_sol, ens_l.dall_cache
)

ENS_BENCH["reverse"]["large"] = @benchmarkable reverse_ensemble_bench!(
$(copy(ens_l.A)), $(copy(ens_l.B)), $(copy(ens_l.C)), $(copy(ens_l.u0)),
$([[copy(n) for n in traj] for traj in ens_l.all_noise]),
$(ens_l.all_sol), $(ens_l.all_cache),
$(ens_l.dA), $(ens_l.dB), $(ens_l.dC), $(ens_l.du0),
$(ens_l.dall_noise), $(ens_l.dall_sol), $(ens_l.dall_cache)
)

ENS_BENCH
Loading
Loading