diff --git a/Project.toml b/Project.toml index 313084dda..2bc987030 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ReservoirComputing" uuid = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294" -version = "0.12.23" +version = "0.12.24" authors = ["Francesco Martinuzzi"] [deps] @@ -17,9 +17,11 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [weakdeps] CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29" +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [extensions] @@ -27,23 +29,27 @@ RCCellularAutomataExt = "CellularAutomata" RCLIBSVMExt = "LIBSVM" RCLinearSolveExt = "LinearSolve" RCMLJLinearModelsExt = "MLJLinearModels" +RCODEReservoirExt = ["SciMLBase", "DataInterpolations"] RCSparseArraysExt = "SparseArrays" [compat] ArrayInterface = "7.19.0" CellularAutomata = "0.0.6" ConcreteStructs = "0.2.3" +DataInterpolations = "6, 7, 8" DifferentialEquations = "7.16.1, 8" LIBSVM = "0.8" LinearAlgebra = "1.10" LinearSolve = "3.57.0" LuxCore = "1.3.0" MLJLinearModels = "0.9.2, 0.10" -NNlib = "0.9.26" -PrecompileTools = "1" +NNlib = "0.9.30" +OrdinaryDiffEq = "6" +PrecompileTools = "1.2" Random = "1.10" Reexport = "1.2.2" SafeTestsets = "0.1" +SciMLBase = "2.51, 3" SciMLTesting = "1" SparseArrays = "1.10" Static = "1.2.0" @@ -53,14 +59,17 @@ WeightInitializers = "1.0.5" julia = "1.10" [extras] +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" 
SciMLTesting = "09d9d899-5365-40a9-917a-5f67fddea283" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "SafeTestsets", "SciMLTesting", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays"] +test = ["Test", "SafeTestsets", "SciMLTesting", "DataInterpolations", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "OrdinaryDiffEq", "SciMLBase", "SparseArrays", "Statistics"] diff --git a/docs/Project.toml b/docs/Project.toml index 430988608..bdba183c7 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,8 @@ [deps] CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" +DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" @@ -10,24 +12,30 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" +OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReservoirComputing = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] CellularAutomata = "0.0.6" ConcreteStructs = "0.2" +DataInterpolations = "6, 7, 8" +DelayDiffEq = "6" Documenter = "1" DocumenterCitations = "1" DocumenterInterLinks = "1" JLD2 = "0.6" -OrdinaryDiffEqAdamsBashforthMoulton = "2" -LinearSolve = "3" LIBSVM = "0.8" 
+LinearSolve = "3" MLJLinearModels = "0.10" +OrdinaryDiffEqAdamsBashforthMoulton = "2" +OrdinaryDiffEqTsit5 = "1, 2" Plots = "1" ReservoirComputing = "0.12.0" +SciMLBase = "2.51, 3" Static = "1" StatsBase = "0.34.4" diff --git a/docs/pages.jl b/docs/pages.jl index de3714034..a15e17918 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -4,6 +4,7 @@ pages = [ "Tutorials" => Any[ "Building a model from scratch" => "tutorials/scratch.md", "Chaos forecasting with an ESN" => "tutorials/lorenz_basic.md", + "Continuous-time reservoirs from a SciMLProblem" => "tutorials/sciml_reservoir.md", "Fitting a Next Generation Reservoir Computer" => "tutorials/ngrc.md", "Deep Echo State Networks" => "tutorials/deep_esn.md", "Training Reservoir Computing Models" => "tutorials/train.md", diff --git a/docs/src/api/layers.md b/docs/src/api/layers.md index 78deac8de..ddd1fb3f3 100644 --- a/docs/src/api/layers.md +++ b/docs/src/api/layers.md @@ -36,24 +36,21 @@ LIFESNCell ``` -## Continuous-time reservoirs +## Wrappers ```@docs - AbstractSciMLProblemReservoir - SciMLProblemReservoir + LocalInformationFlow ``` +## Continuous-Time Reservoirs + ```@docs + AbstractSciMLProblemReservoir + SciMLProblemReservoir AbstractSampler TerminalStateSampling ``` -## Wrappers - -```@docs - LocalInformationFlow -``` - ## Reservoir computing with cellular automata ```@docs diff --git a/docs/src/tutorials/sciml_reservoir.md b/docs/src/tutorials/sciml_reservoir.md new file mode 100644 index 000000000..a6ee9012e --- /dev/null +++ b/docs/src/tutorials/sciml_reservoir.md @@ -0,0 +1,342 @@ +# Continuous-Time Reservoirs from a `SciMLProblem` + +ReservoirComputing.jl exposes a continuous-time reservoir layer, +[`SciMLProblemReservoir`](@ref), that wraps any +`AbstractSciMLProblem` (`ODEProblem`, `SDEProblem`, `DDEProblem`) and +plugs it into the standard `collectstates` / `predict` pipeline. 
The +implementation lives in the `RCODEReservoirExt` package extension, so +the extension is loaded automatically once `SciMLBase` and +`DataInterpolations` are in scope. The user picks a concrete solver +package (e.g. `OrdinaryDiffEqTsit5`, `OrdinaryDiffEq`) separately — +its solver types are what the reservoir's `args[1]` consumes. + +This page walks through the core type, how time is laid out internally, +and a worked example that checks the continuous reservoir against a +closed-form analytic solution. + +## Loading the extension + +```julia +using ReservoirComputing +using SciMLBase +using DataInterpolations +using OrdinaryDiffEqTsit5 # or `OrdinaryDiffEq`, or whichever solver pkg you need +``` + +`SciMLBase` provides `solve` / `remake`, `DataInterpolations` is used +for the per-window input signal in the autoregressive `predict` path, +and the chosen OrdinaryDiffEq solver package brings the concrete +solver type (`Tsit5()`, `Euler()`, …). + +## Constructing a reservoir from an ODE problem + +The constructor follows the +[DiffEqFlux `NeuralODE` pattern](https://github.com/SciML/DiffEqFlux.jl/blob/master/src/neural_de.jl): + +```julia +SciMLProblemReservoir(prob, sampler, tspan, args...; kwargs...) +``` + +* `prob` — any `AbstractSciMLProblem`. The reservoir's initial state is + taken from `prob.u0`; the ODE right-hand side reads the time-varying + input through `p.input(t)` (injected by the extension at solve time). +* `sampler` — an [`AbstractSampler`](@ref). The bundled + [`TerminalStateSampling`](@ref) records the reservoir state at the + end of each input window. +* `tspan` — overrides `prob.tspan` via `remake` at solve time. The + input column grid is synthesised from `tspan` and the input width. +* `args...` — forwarded to `solve` positionally; the solver algorithm + is the first element (`Tsit5()`, `Euler()`, …). +* `kwargs...` — forwarded to `solve`. 
The three keys `saveat`, + `save_everystep`, and `dense` are owned by the helper and rejected + at construction. + +The user's ODE can carry static parameters as a `NamedTuple` (or +`nothing` / `NullParameters()` if there are none). Anything else is +rejected with an `ArgumentError`. The reserved name `:input` cannot +appear in `prob.p` because the extension uses it to inject the +interpolated input signal. + +## How time is laid out + +`collectstates` receives a discrete `data::AbstractMatrix` of shape +`(channels, n_samples)` and turns it into a continuous-time input: + +1. `tspan` is split into `n_samples` equal-width windows of width + `Δt = (t1 - t0) / n_samples`. +2. Input column `k` is held over window `k` + (`t0 + (k-1)Δt ≤ t < t0 + kΔt`) via zero-order hold. +3. The reservoir state at the **end** of window `k` + (`t = t0 + kΔt`) is sampled and becomes `states[:, k]`. + +This input-at-start / sample-at-end alignment matches the discrete +update semantics: `states[:, k]` is the reservoir state after +processing input `k`, with no off-by-one offset. + +## A worked example: linear ODE with closed-form solution + +The scalar linear ODE `dx/dt = -x + u(t)` with `u(t) = 1` and +`x(0) = 0` has the closed form `x(t) = 1 - exp(-t)`. 
Running it +through the continuous reservoir at a tight solver tolerance recovers +that curve to within ~1e-6: + +```julia +using ReservoirComputing +using SciMLBase +using DataInterpolations +using OrdinaryDiffEqTsit5 +using Random + +function linear_rhs!(dx, x, p, t) + input_t = p.input(t) + dx .= .-x .+ input_t +end + +n_samples = 10 +tspan = (0.0, 1.0) +Δt = (tspan[2] - tspan[1]) / n_samples +sample_ts = collect(range(tspan[1] + Δt, tspan[2]; length = n_samples)) + +u_const = 1.0 +data = fill(u_const, 1, n_samples) +initial_state = [0.0] + +prob = ODEProblem(linear_rhs!, initial_state, tspan, (;)) +res = SciMLProblemReservoir( + prob, TerminalStateSampling(), tspan, Tsit5(); + reltol = 1.0e-10, abstol = 1.0e-12 +) +rc = ReservoirComputer(res, LinearReadout(1 => 1)) +ps, st = setup(MersenneTwister(0), rc) + +states, _ = collectstates(rc, data, ps, st) +analytic = u_const .* (1 .- exp.(.-sample_ts)) + +@assert isapprox(states[1, :], analytic; atol = 1.0e-6) +``` + +## Calling `predict` + +Both `predict` signatures route through the same continuous helper: + +```julia +predict(rc, data, ps, st) # teacher-forced +predict(rc, steps, ps, st; initialdata) # autoregressive rollout +``` + +* The teacher-forced path solves the full `tspan` once and applies the + readout column-by-column to the sampled states. +* The autoregressive path splits `tspan` into `steps` sub-intervals, + feeds the previous readout output back as the constant input on + the next sub-interval, and stitches the per-window readouts into + the returned output matrix. + +In both cases the reservoir's initial state is `prob.u0`. To continue +from a previously computed trajectory, `remake(prob; u0 = …)` before +constructing the reservoir. + +## Eye test: Lorenz chaos forecasting with a continuous ESN + +The README example trains a discrete ESN on the Lorenz attractor and +rolls it forward autoregressively. 
The same pipeline runs verbatim with +`SciMLProblemReservoir` once you wrap the leaky-integrator continuous +ESN equations + +```math +\frac{dx}{dt} = -x + \tanh\!\left(W_r\, x + W_{in}\, u(t) + b\right) +``` + +as an `ODEProblem`. The reservoir matrices are random, the readout is +linear, and the only training step fits the linear readout on the +collected continuous states. + +```@example ctesn-lorenz +using ReservoirComputing +using SciMLBase +using DataInterpolations +using OrdinaryDiffEqTsit5 +using Plots +using Random + +Random.seed!(42) +rng = MersenneTwister(17) + +# 1. Lorenz data +function lorenz!(du, u, p, t) + du[1] = p[1] * (u[2] - u[1]) + du[2] = u[1] * (p[2] - u[3]) - u[2] + du[3] = u[1] * u[2] - p[3] * u[3] +end +data_prob = ODEProblem(lorenz!, [1.0, 0.0, 0.0], (0.0, 40.0), [10.0, 28.0, 8 / 3]) +data = Array(solve(data_prob, Tsit5(); saveat = 0.02)) + +shift, train_len, predict_len = 300, 1000, 250 +input_data = data[:, shift:(shift + train_len - 1)] +target_data = data[:, (shift + 1):(shift + train_len)] +test = data[:, (shift + train_len):(shift + train_len + predict_len - 1)] + +# 2. Continuous ESN reservoir parameters +N_res = 100 +Wr = 0.3 .* randn(rng, N_res, N_res) ./ sqrt(N_res) +Win = 0.5 .* randn(rng, N_res, 3) +bias = 0.05 .* randn(rng, N_res) +initial_state = zeros(N_res) + +# 3. Raw ODE equations — leaky-integrator continuous ESN +function ctesn_rhs!(dx, x, p, t) + input_t = p.input(t) + return dx .= .-x .+ tanh.(p.Wr * x .+ p.Win * input_t .+ p.b) +end + +# 4. 
Wrap as SciMLProblemReservoir with Δt = 1 per input window +function build_rc(tspan_len) + prob = ODEProblem( + ctesn_rhs!, initial_state, + (0.0, Float64(tspan_len)), (Wr = Wr, Win = Win, b = bias) + ) + res = SciMLProblemReservoir( + prob, TerminalStateSampling(), + (0.0, Float64(tspan_len)), Tsit5(); + reltol = 1.0e-6, abstol = 1.0e-8 + ) + return ReservoirComputer(res, (NLAT2(),), LinearReadout(N_res => 3)) +end + +rc_train = build_rc(train_len) +rc_predict = build_rc(predict_len) +ps, st = setup(rng, rc_train) + +# 5. Fit the linear readout on the collected continuous states +ps, st = train!(rc_train, input_data, target_data, ps, st) + +# 6. Autoregressive rollout under the same continuous dynamics +ps_pred, st_pred = setup(rng, rc_predict) +ps_pred = merge(ps_pred, (readout = ps.readout,)) +st_pred = merge(st_pred, (readout = st.readout,)) +output, _ = predict(rc_predict, predict_len, ps_pred, st_pred; initialdata = test[:, 1]) + +plot(transpose(output)[:, 1], transpose(output)[:, 2], transpose(output)[:, 3]; + label = "predicted") +plot!(transpose(test)[:, 1], transpose(test)[:, 2], transpose(test)[:, 3]; + label = "actual") +``` + +The two trajectories should agree on the early portion of the rollout +before chaotic divergence — exactly the behaviour the discrete-ESN +example produces. The point of the eye test is that nothing in the +training loop changes: `train!` and `predict` still drive the +`SciMLProblemReservoir` through the same pipeline they use for any +discrete reservoir. + +## A delay-equation target: Mackey-Glass + +`SciMLProblemReservoir` wraps **any** `AbstractSciMLProblem`, so the +training data — and, if you want, the reservoir itself — can come +from a delay-differential equation, a stochastic equation, or any +other SciML problem type. The smallest non-trivial demonstration: +forecast the Mackey-Glass time series, a 1-D delay equation that has +been a reservoir-computing benchmark for two decades. 
+ +```@example ctesn-mg +using ReservoirComputing +using SciMLBase +using DataInterpolations +using OrdinaryDiffEqTsit5 +using DelayDiffEq +using LinearAlgebra +using Plots +using Random +using Statistics + +Random.seed!(42) +rng = MersenneTwister(17) + +# Mackey-Glass: dx/dt = β x(t-τ) / (1 + x(t-τ)^n) - γ x(t). +# With τ = 17 the trajectory is chaotic. +const β_mg, γ_mg, n_mg, τ_mg = 0.2, 0.1, 10, 17.0 +function mackey_glass!(dx, x, h, p, t) + x_delay = h(p, t - τ_mg)[1] + return dx[1] = β_mg * x_delay / (1 + x_delay^n_mg) - γ_mg * x[1] +end +mg_history(p, t) = [1.2] + +mg_data_prob = DDEProblem(mackey_glass!, [1.2], mg_history, (0.0, 1500.0); + constant_lags = [τ_mg]) +mg_data = reduce(hcat, + solve(mg_data_prob, MethodOfSteps(Tsit5()); saveat = 1.0).u) + +shift, train_len, predict_len = 200, 1000, 200 +input_data = mg_data[:, shift:(shift + train_len - 1)] +target_data = mg_data[:, (shift + 1):(shift + train_len)] +test_data = mg_data[:, (shift + train_len):(shift + train_len + predict_len - 1)] + +# Same continuous ESN reservoir machinery as the Lorenz example, retuned +# for the much smaller Mackey-Glass amplitude (~1, vs Lorenz ~20). 
+N_res = 150 +sparsity = 6 / N_res +Wr_raw = randn(rng, N_res, N_res) .* (rand(rng, N_res, N_res) .< sparsity) +Wr = (0.9 / maximum(abs.(eigvals(Wr_raw)))) .* Wr_raw +Win = 0.05 .* randn(rng, N_res, 1) +bias = 0.0 .* randn(rng, N_res) +initial_state = zeros(N_res) + +function mg_reservoir_rhs!(dx, x, p, t) + input_t = p.input(t) + return dx .= .-x .+ tanh.(p.Wr * x .+ p.Win * input_t .+ p.b) +end + +function build_mg_rc(n_steps) + tspan = (0.0, n_steps * 1.0) + prob = ODEProblem(mg_reservoir_rhs!, initial_state, tspan, + (Wr = Wr, Win = Win, b = bias)) + res = SciMLProblemReservoir(prob, TerminalStateSampling(), tspan, Tsit5(); + reltol = 1.0e-6, abstol = 1.0e-8) + return ReservoirComputer(res, (NLAT2(),), LinearReadout(N_res => 1)) +end + +rc_mg_train = build_mg_rc(train_len) +rc_mg_predict = build_mg_rc(predict_len) +ps_mg, st_mg = setup(rng, rc_mg_train) +ps_mg, st_mg = train!(rc_mg_train, input_data, target_data, ps_mg, st_mg, + StandardRidge(1.0e-6); washout = 0) + +ps_pred, st_pred = setup(rng, rc_mg_predict) +ps_pred = merge(ps_pred, (readout = ps_mg.readout,)) +st_pred = merge(st_pred, (readout = st_mg.readout,)) +mg_output, _ = predict(rc_mg_predict, predict_len, ps_pred, st_pred; + initialdata = test_data[:, 1]) + +plot([test_data[1, :], mg_output[1, :]]; + label = ["actual" "predicted"], linewidth = 2, + xlabel = "step", ylabel = "x(t)", + title = "Mackey-Glass (τ=17) — continuous ESN rollout") +``` + +Two things to notice: + +* The **data path** uses a `DDEProblem` solved with + `MethodOfSteps(Tsit5())` — no special handling on + `SciMLProblemReservoir`'s side; the wrapper only cares about the + shape of the resulting matrix. +* The **reservoir** is kept as an `ODEProblem` for simplicity. Because + `SciMLProblemReservoir`'s `prob` field is untyped, it would equally + accept a `DDEProblem` of the form + `dx/dt = -x(t) + tanh(W_r x(t-τ_r) + W_{in} u(t) + b)` — useful when + the target has long-range temporal correlations. 
Delay-coupled + reservoirs of that form are explored in the CTESN/delay-reservoir + literature; a tuned implementation will land in PR3. + +As with the Lorenz example, this is a demonstration of the new +plumbing rather than an optimised benchmark: hyperparameters were +tuned by hand to land a watchable forecast, not chosen via +cross-validation. + +## Adding your own sampler + +The reservoir state sequence the readout sees is produced by an +[`AbstractSampler`](@ref). To plug in a custom strategy (window mean, +sub-sampling within a window, etc.), define a concrete subtype and a +matching `_sample(::YourSampler, sol)` method inside an extension that +also loads `SciMLBase` and `DataInterpolations`. The method should return +a `(state_dim, n_samples)` matrix; everything downstream (state +modifiers, readout, predict) is sampler-agnostic. diff --git a/ext/RCODEReservoirExt.jl b/ext/RCODEReservoirExt.jl new file mode 100644 index 000000000..661f21cfa --- /dev/null +++ b/ext/RCODEReservoirExt.jl @@ -0,0 +1,392 @@ +module RCODEReservoirExt + +using DataInterpolations: ConstantInterpolation +using LuxCore: apply +# `solve` and `remake` come from `SciMLBase`. The user picks the concrete +# solver type (e.g. `Tsit5()`) and loads its package separately +# (`OrdinaryDiffEqTsit5`, `OrdinaryDiffEq`, …); dispatch at solve time +# selects the right method via the type they passed in `res.args[1]`. We +# deliberately don't list a solver package as a weakdep trigger so users +# aren't forced to pull the full `OrdinaryDiffEq` meta-package in. 
+using SciMLBase: remake, solve, NullParameters + +using ReservoirComputing: ReservoirComputing, + AbstractReservoirComputer, + AbstractSampler, + AbstractSciMLProblemReservoir, + TerminalStateSampling, + collectstates +import ReservoirComputing: _collectstates, _predict + +# --------------------------------------------------------------------------- +# Parameter assembly +# +# At solve time the extension must hand the ODE three things: +# (1) the interpolated input signal `u(t)` exposed as `p.input`, +# (2) the user's static parameters (if any) from `prob.p`, +# (3) any Lux-managed reservoir parameters from `ps.reservoir`. +# +# `_to_namedtuple` normalises `prob.p` so that nothing / `NullParameters` / a +# user `NamedTuple` all collapse into a single `NamedTuple` we can merge into. +# Anything else is rejected with a clear error — wrapping unknown payloads +# silently would hide bugs in user-defined ODEs. +# --------------------------------------------------------------------------- + +_to_namedtuple(prob_p::NamedTuple) = prob_p +_to_namedtuple(::NullParameters) = NamedTuple() +_to_namedtuple(::Nothing) = NamedTuple() +function _to_namedtuple(prob_p) + return throw( + ArgumentError( + "SciMLProblemReservoir requires `prob.p` to be a NamedTuple, " * + "`nothing`, or `SciMLBase.NullParameters()`, got $(typeof(prob_p)). " * + "Wrap your parameters in a NamedTuple — the extension injects " * + "`input` on top before calling `solve`." + ) + ) +end + +function _build_solve_params(prob_p, ps_reservoir, input_interp) + base = _to_namedtuple(prob_p) + # `:input` is the reserved key the extension injects so the user's ODE + # right-hand side can read `p.input(t)`. A pre-existing `:input` field + # in either `prob.p` or `ps.reservoir` would be silently shadowed below, + # which is exactly the silent-failure surface we want to avoid. + if haskey(base, :input) + throw( + ArgumentError( + "`prob.p` already contains an `:input` field. 
The continuous " * + "reservoir extension reserves that name for the interpolated " * + "input signal it injects at solve time. Rename the field in " * + "your ODE problem before constructing the reservoir." + ) + ) + end + if !isempty(ps_reservoir) && haskey(ps_reservoir, :input) + throw( + ArgumentError( + "`ps.reservoir` already contains an `:input` field. That name is " * + "reserved for the extension's interpolated input signal — " * + "rename your reservoir parameter." + ) + ) + end + merged = isempty(ps_reservoir) ? base : merge(base, ps_reservoir) + return merge(merged, (input = input_interp,)) +end + +# --------------------------------------------------------------------------- +# Input signal construction +# +# `collectstates` sees a discrete `data::AbstractMatrix` and reconstructs the +# continuous-time input via zero-order hold (column `k` held over window `k`). +# The input grid marks window starts and the `saveat` grid marks window ends, +# so each state sample is taken one window width after its input column. +# +# `_make_const_input_fn` is the closed-loop counterpart used inside the +# autoregressive `predict`: between two reservoir-output events the input is +# the previous output, held constant. +# --------------------------------------------------------------------------- + +""" + ZeroOrderHoldInterp(data, ts) + +Piecewise-constant input signal for the continuous reservoir. Holds a +`data::AbstractMatrix` of shape `(channels, T)` alongside the matching +time-stamp vector `ts`. For `t` in window `k` (i.e. `ts[k] ≤ t < ts[k+1]`) +the call returns `view(data, :, k)`; out-of-range times clamp to the +nearest endpoint. + +We pick zero-order hold (ZOH) over linear interpolation deliberately: +under linear interpolation the reservoir state at sample time `sample_ts[k]` +depends on both `data[:, k]` and `data[:, k+1]` for any non-Euler solver, +which is a one-step lookahead that contradicts the documented "state +after processing input k" semantics. 
With ZOH, `data[:, k]` is the only +input column that influences `states[:, k]`, regardless of solver — and +the autoregressive `predict` path already uses ZOH for its per-window +input function, so the two paths now use the same scheme. + +Why not `DataInterpolations.ConstantInterpolation`: matrix-valued `u` +has no `_integral` method, so `cache_parameters=true` fails at +construction; the default `cache_parameters=false` leaves unused cache +fields typed as `Vector{Union{}}`, which SciMLBase's dual-eltype probing +crashes on while preparing `solve` (observed on DataInterpolations v8 / +SciMLBase v2, 2026-06). A bespoke struct with concrete fields and a +view-returning call sidesteps both paths and is allocation-free in the +ODE hot path. Revisit if/when DataInterpolations supports matrix-`u` +non-cached construction without the bottom-type fallout. +""" +struct ZeroOrderHoldInterp{D <: AbstractMatrix, T <: AbstractVector} + data::D + ts::T +end + +function (interp::ZeroOrderHoldInterp)(t) + ts = interp.ts + n_samples = length(ts) + t < ts[1] && return view(interp.data, :, 1) + t ≥ ts[end] && return view(interp.data, :, n_samples) + window_idx = searchsortedlast(ts, t) + return view(interp.data, :, clamp(window_idx, 1, n_samples)) +end + +function _make_input_fn(data::AbstractMatrix, ts::AbstractVector) + return ZeroOrderHoldInterp(data, ts) +end + +function _make_const_input_fn(u_vec::AbstractVector, t_lo, t_hi) + # `cache_parameters=true` is fine for vector u (autoregressive predict + # always holds the previous readout output constant over one sub-interval). + return ConstantInterpolation([u_vec, u_vec], [t_lo, t_hi]; cache_parameters = true) +end + +# --------------------------------------------------------------------------- +# Samplers +# +# A sampler maps a continuous trajectory into the discrete state matrix the +# readout sees. 
`TerminalStateSampling` reads the solution exactly at the +# user-visible time grid (the same one we pass through `saveat`), so the +# result is just the columnar view of `sol.u`. +# --------------------------------------------------------------------------- + +function _sample(::TerminalStateSampling, sol) + return reduce(hcat, sol.u) +end + +# --------------------------------------------------------------------------- +# State-modifier composition +# +# The discrete fallback threads `states_modifiers` per reservoir step (see +# `_partial_apply` in `reservoircomputer.jl`). For the continuous path we +# evolve the trajectory first and then apply modifiers column-by-column to +# the sampled matrix. This keeps the per-sample semantics identical to the +# discrete code without contaminating the ODE right-hand side. +# --------------------------------------------------------------------------- + +function _apply_modifiers_continuous( + modifiers::Tuple, states_matrix::AbstractMatrix, ps_mods, st_mods + ) + isempty(modifiers) && return states_matrix, st_mods + n_samples = size(states_matrix, 2) + src_cols = eachcol(states_matrix) + + first_col, new_st = ReservoirComputing._apply_seq( + modifiers, first(src_cols), ps_mods, st_mods + ) + # `similar(first_col, ...)` — not `similar(states_matrix, ...)` — so the + # output matrix takes the modifier output's eltype. If a modifier + # promotes/demotes (e.g. Float32 → Float64), we want that to surface, + # not be silently truncated back to the reservoir state's eltype. + output = similar(first_col, length(first_col), n_samples) + output[:, 1] .= first_col + for (idx, src_col) in Iterators.drop(enumerate(src_cols), 1) + modified_col, new_st = ReservoirComputing._apply_seq( + modifiers, src_col, ps_mods, new_st + ) + output[:, idx] .= modified_col + end + return output, new_st +end + +# --------------------------------------------------------------------------- +# Continuous `_collectstates` +# +# Pipeline: +# 1. 
Split `res.tspan` into `n_samples` equal-width windows. +# 2. Place input column `k` at the *start* of window `k` (time +# `t0 + (k-1)Δt`) and request a sample at the *end* of window `k` +# (time `t0 + kΔt`). This alignment matches the discrete reservoir +# semantics — `states[:, k]` is the state after processing input `k` +# — and is what makes the Euler-equivalence test land without an +# off-by-one shift. +# 3. `remake` the user's problem with the locked `tspan` and the merged +# parameter pack (interpolated input injected as `p.input`). +# 4. `solve(...; saveat = sample_ts, save_everystep=false, dense=false)`. +# `res.kwargs` come last so user kwargs win on collision — the +# constructor already rejects the three protected keys, so they +# cannot collide in practice. +# 5. Push the trajectory through the sampler → raw state matrix. +# 6. Apply state modifiers → final state matrix matching the discrete +# `(state_dims, n_samples)` shape expected by the readout. +# --------------------------------------------------------------------------- + +function _collectstates( + res::AbstractSciMLProblemReservoir, + rc::AbstractReservoirComputer, + data::AbstractMatrix, + ps::NamedTuple, + st::NamedTuple + ) + n_samples = size(data, 2) + n_samples ≥ 2 || throw( + ArgumentError( + "SciMLProblemReservoir collectstates needs at least 2 input " * + "columns to define a time grid; got $n_samples." + ) + ) + + t0, t1 = res.tspan + t1 > t0 || throw( + ArgumentError( + "SciMLProblemReservoir requires `tspan[2] > tspan[1]`, got " * + "tspan = ($t0, $t1). Continuous integration is only defined " * + "over a strictly positive interval." 
+ ) + ) + + Δt = (t1 - t0) / n_samples + input_ts = collect(range(t0, t1 - Δt; length = n_samples)) + sample_ts = collect(range(t0 + Δt, t1; length = n_samples)) + + input_interp = _make_input_fn(data, input_ts) + solve_p = _build_solve_params(res.prob.p, ps.reservoir, input_interp) + + prob_remade = remake(res.prob; tspan = res.tspan, p = solve_p) + + sol = solve( + prob_remade, res.args...; + saveat = sample_ts, + save_everystep = false, + dense = false, + res.kwargs... + ) + + raw_states = _sample(res.sampler, sol) + modified_states, st_mods = _apply_modifiers_continuous( + rc.states_modifiers, raw_states, ps.states_modifiers, st.states_modifiers + ) + + newst = ( + reservoir = st.reservoir, + states_modifiers = st_mods, + readout = st.readout, + ) + return modified_states, newst +end + +# --------------------------------------------------------------------------- +# Teacher-forced `predict` +# +# Solve once over the whole tspan, then apply the readout column-by-column. +# Cheaper than the autoregressive path because the ODE never has to be +# restarted between samples. 
+# --------------------------------------------------------------------------- + +function _predict( + ::AbstractSciMLProblemReservoir, + rc::AbstractReservoirComputer, + data::AbstractMatrix, + ps::NamedTuple, + st::NamedTuple + ) + states, new_st = collectstates(rc, data, ps, st) + n_samples = size(states, 2) + st_ro = new_st.readout + state_cols = eachcol(states) + first_output, st_ro = apply(rc.readout, first(state_cols), ps.readout, st_ro) + outputs = similar(first_output, size(first_output, 1), n_samples) + outputs[:, 1] .= first_output + for (idx, state_col) in Iterators.drop(enumerate(state_cols), 1) + current_output, st_ro = apply(rc.readout, state_col, ps.readout, st_ro) + outputs[:, idx] .= current_output + end + return outputs, merge(new_st, (readout = st_ro,)) +end + +# --------------------------------------------------------------------------- +# Autoregressive `predict` +# +# Split `tspan` into `steps` equal sub-intervals. For each sub-interval the +# input is the previous readout output, held constant via a +# `ConstantInterpolation`. After each sub-solve we: +# - sample the terminal state, +# - apply state modifiers (per-sample, consistent with the discrete loop), +# - apply the readout, +# - feed the output back as the next input. +# +# The initial reservoir state is `res.prob.u0`; users who want to continue +# from a previously computed trajectory should `remake(prob; u0 = …)` before +# constructing the reservoir. +# --------------------------------------------------------------------------- + +function _predict( + res::AbstractSciMLProblemReservoir, + rc::AbstractReservoirComputer, + steps::Integer, + ps::NamedTuple, + st::NamedTuple; + initialdata::AbstractVector + ) + steps ≥ 1 || throw(ArgumentError("steps must be ≥ 1, got $steps")) + + t0, t1 = res.tspan + t1 > t0 || throw( + ArgumentError( + "Autoregressive predict requires `tspan[2] > tspan[1]`, got " * + "tspan = ($t0, $t1)." 
+ ) + ) + ts = collect(range(t0, t1; length = steps + 1)) + window_starts = @view ts[1:(end - 1)] + window_ends = @view ts[2:end] + + # Preserve `u0`'s original type — `collect` would degrade `SVector` / + # `ComponentArray` / scalar states into a plain `Vector` and either + # error (no `collect(::Number)` method) or silently flatten the + # user's chosen representation. We only ever read `current_state`, + # never mutate it in place, so a direct reference is safe. + current_state = res.prob.u0 + current_input = initialdata + + st_mods = st.states_modifiers + st_ro = st.readout + + # `outputs` is allocated *after* the first readout call so its element + # type and row count come from `apply(rc.readout, …)` rather than + # `initialdata`. Otherwise a readout returning a different eltype + # (e.g. Float64 vs the Float32 input) would force a silent + # conversion at the column assignment. + local outputs + for (step_idx, (t_lo, t_hi)) in enumerate(zip(window_starts, window_ends)) + input_fn = _make_const_input_fn(current_input, t_lo, t_hi) + solve_p = _build_solve_params(res.prob.p, ps.reservoir, input_fn) + sub_prob = remake( + res.prob; + tspan = (t_lo, t_hi), + p = solve_p, + u0 = current_state + ) + sol = solve( + sub_prob, res.args...; + saveat = [t_hi], + save_everystep = false, + dense = false, + res.kwargs... 
+ ) + current_state = sol.u[end] + + if !isempty(rc.states_modifiers) + state_after_mods, st_mods = ReservoirComputing._apply_seq( + rc.states_modifiers, current_state, ps.states_modifiers, st_mods + ) + else + state_after_mods = current_state + end + + current_output, st_ro = apply(rc.readout, state_after_mods, ps.readout, st_ro) + if step_idx == 1 + outputs = similar(current_output, length(current_output), steps) + end + outputs[:, step_idx] .= current_output + current_input = current_output + end + + newst = ( + reservoir = st.reservoir, + states_modifiers = st_mods, + readout = st_ro, + ) + return outputs, newst +end + +end # module diff --git a/src/layers/sciml_reservoir.jl b/src/layers/sciml_reservoir.jl index 212f87d11..1558588e2 100644 --- a/src/layers/sciml_reservoir.jl +++ b/src/layers/sciml_reservoir.jl @@ -11,9 +11,14 @@ abstract type AbstractSampler end """ TerminalStateSampling() -Sample the continuous-time reservoir's state at the terminal time of each -input window. This is the continuous analogue of the standard discrete update: -one reservoir state per input column. +Sampler that records the reservoir state at the *end* of each input window. +With `T` input columns and `tspan = (t0, t1)`, `collectstates` splits +`tspan` into `T` equal-width windows; input column `k` is applied at the +start of window `k` (time `t0 + (k-1)Δt`) and the state at the end of that +window (time `t0 + kΔt`) becomes the `k`-th column of the returned state +matrix. This is the continuous analogue of the discrete update: one state +per input column, with `states[:, k]` representing the reservoir's state +after having processed input `k`. """ struct TerminalStateSampling <: AbstractSampler end @@ -26,8 +31,10 @@ Concrete subtypes provide `_collectstates` methods that run the solver and hand back a state matrix to the readout. 
The continuous-time `_collectstates` implementation lives in the -`RCODEReservoirExt` package extension and requires `OrdinaryDiffEq`, -`SciMLBase`, and `DataInterpolations` to be loaded. +`RCODEReservoirExt` package extension and requires `SciMLBase` and +`DataInterpolations` to be loaded. Pick any concrete solver package +separately (e.g. `OrdinaryDiffEqTsit5`, `OrdinaryDiffEq`) — its solver +types are what `SciMLProblemReservoir`'s `args[1]` consumes. """ abstract type AbstractSciMLProblemReservoir <: AbstractLuxLayer end @@ -51,13 +58,16 @@ construction time and forwarded to `solve` when `collectstates` runs. - `args...`: positional arguments forwarded to `solve`. The solver algorithm (e.g. `Tsit5()`) is the first element by convention. - `kwargs...`: keyword arguments forwarded to `solve`. The continuous helper - may apply its own protected keys (`saveat`, `save_everystep`, `dense`) and - errors at construction if these collide with user-provided kwargs. + owns three protected keys — `saveat`, `save_everystep`, and `dense` — because + `collectstates` needs to synthesise a sample grid from `tspan` and the input + width. Passing any of them at construction errors immediately. The real `_collectstates` implementation lives in the `RCODEReservoirExt` package extension. Without it loaded, calling `collectstates` on a reservoir computer holding a `SciMLProblemReservoir` will error with a -message instructing the user to `using OrdinaryDiffEq`. +message instructing the user to load `SciMLBase` and `DataInterpolations` +(plus a concrete solver package — `OrdinaryDiffEqTsit5`, `OrdinaryDiffEq`, +…). """ @concrete struct SciMLProblemReservoir <: AbstractSciMLProblemReservoir prob @@ -67,12 +77,36 @@ message instructing the user to `using OrdinaryDiffEq`. kwargs end +# Keyword arguments owned by the continuous `_collectstates` helper: +# - `saveat` is derived from `tspan` and the input width, so a user value +# would silently desync the sample grid from the input grid. 
+# - `save_everystep` and `dense` are hardcoded to `false` because the +# sampler only ever reads `sol.u` at the `saveat` points; allocating +# the full trajectory would waste memory without changing the result. +# All three are rejected at construction so the user finds out immediately +# rather than getting a wrong-shape state matrix at solve time. Internal — +# not a docstring so Documenter's `:missing_docs` check leaves it alone. +const _PROTECTED_SOLVE_KWARGS = (:saveat, :save_everystep, :dense) + +function _check_protected_kwargs(kwargs) + collisions = filter(key -> key in _PROTECTED_SOLVE_KWARGS, keys(kwargs)) + isempty(collisions) && return nothing + return throw( + ArgumentError( + "SciMLProblemReservoir rejects $(collect(collisions)) in `kwargs`: " * + "these keys are set by `collectstates` from `tspan` and the input " * + "data width. Drop them from the constructor call." + ) + ) +end + function SciMLProblemReservoir(prob, sampler, tspan, args...; kwargs...) # No type constraint on `sampler` here: a constrained outer constructor # makes this method strictly more specific than the inner constructor # generated by `@concrete`, causing infinite recursion at the 5-arg # call below. The DiffEqFlux NeuralDE pattern keeps these arguments # untyped for the same reason. + _check_protected_kwargs(kwargs) return SciMLProblemReservoir(prob, sampler, tspan, args, kwargs) end diff --git a/src/predict.jl b/src/predict.jl index 97e97bd1d..acc24fbce 100644 --- a/src/predict.jl +++ b/src/predict.jl @@ -46,7 +46,7 @@ sequence. ### Returns - `output`: Outputs for each input column, shape `(out_dims, T)`. -- `st`: Updated minal model states. +- `st`: Updated final model states. """ function predict( rc::AbstractLuxLayer, @@ -74,3 +74,86 @@ function predict(rc::AbstractLuxLayer, data::AbstractMatrix, ps, st) end return Y, st end + +# Two-level dispatch on the reservoir field, mirroring `collectstates` / `_collectstates`. 
+# Continuous reservoirs (`AbstractSciMLProblemReservoir`) plug in their own `_predict` +# methods from `RCODEReservoirExt`; everything else hits the fallbacks below, which +# replicate the discrete `predict(::AbstractLuxLayer, …)` bodies above. +# +# Not every `AbstractReservoirComputer` subtype carries a `:reservoir` field — +# `DeepESN`, for instance, owns a tuple of cells under `:cells`. For those +# subtypes we cannot extract a "reservoir layer" to dispatch on, so we pass +# `nothing` and let the `::Any` fallback take the discrete loop. (Concrete +# types like `DeepESN` already provide their own specialised `collectstates`, +# and `predict` itself only depends on `apply(rc, …)`, which works through +# their own `(rc::DeepESN)(…)` call.) + +function predict( + rc::AbstractReservoirComputer, steps::Integer, ps, st; + initialdata::AbstractVector + ) + res = hasfield(typeof(rc), :reservoir) ? rc.reservoir : nothing + return _predict(res, rc, steps, ps, st; initialdata = initialdata) +end + +function predict(rc::AbstractReservoirComputer, data::AbstractMatrix, ps, st) + res = hasfield(typeof(rc), :reservoir) ? rc.reservoir : nothing + return _predict(res, rc, data, ps, st) +end + +function _predict( + ::AbstractSciMLProblemReservoir, + ::AbstractReservoirComputer, ::Integer, ::Any, ::Any; + initialdata::AbstractVector + ) + return error( + "Autoregressive `predict(rc, steps, ps, st; initialdata)` for a " * + "`SciMLProblemReservoir` requires the `RCODEReservoirExt` extension. " * + "Load `SciMLBase` and `DataInterpolations` (plus an OrdinaryDiffEq " * + "solver package — `OrdinaryDiffEqTsit5`, `OrdinaryDiffEq`, …) to enable it." + ) +end + +function _predict( + ::AbstractSciMLProblemReservoir, + ::AbstractReservoirComputer, ::AbstractMatrix, ::Any, ::Any + ) + return error( + "Teacher-forced `predict(rc, data, ps, st)` for a " * + "`SciMLProblemReservoir` requires the `RCODEReservoirExt` extension. 
" * + "Load `SciMLBase` and `DataInterpolations` (plus an OrdinaryDiffEq " * + "solver package — `OrdinaryDiffEqTsit5`, `OrdinaryDiffEq`, …) to enable it." + ) +end + +function _predict( + ::Any, rc::AbstractReservoirComputer, steps::Integer, ps, st; + initialdata::AbstractVector + ) + output = zeros(eltype(initialdata), length(initialdata), steps) + for step in 1:steps + initialdata, st = apply(rc, initialdata, ps, st) + output[:, step] = initialdata + end + return output, st +end + +function _predict(::Any, rc::AbstractReservoirComputer, data::AbstractMatrix, ps, st) + n_samples = size(data, 2) + n_samples ≥ 1 || throw( + ArgumentError( + "predict input data must have at least one column, got $n_samples." + ) + ) + + input_cols = eachcol(data) + first_output, st = apply(rc, first(input_cols), ps, st) + outputs = similar(first_output, size(first_output, 1), n_samples) + outputs[:, 1] .= first_output + + for (idx, input_col) in Iterators.drop(enumerate(input_cols), 1) + current_output, st = apply(rc, input_col, ps, st) + outputs[:, idx] .= current_output + end + return outputs, st +end diff --git a/test/test_ode_reservoir_ext.jl b/test/test_ode_reservoir_ext.jl new file mode 100644 index 000000000..924664b01 --- /dev/null +++ b/test/test_ode_reservoir_ext.jl @@ -0,0 +1,397 @@ +using Test +using Random +using LinearAlgebra +using ReservoirComputing +using OrdinaryDiffEq +using SciMLBase +using DataInterpolations + +# --------------------------------------------------------------------------- +# Helpers — small ESN-shaped ODE used across several tests. 
+# --------------------------------------------------------------------------- + +function esn_rhs!(dx, x, p, t) + input_t = p.input(t) + return dx .= .-x .+ tanh.(p.Wr * x .+ p.Win * input_t .+ p.b) +end + +function build_esn_problem(rng, in_dim, res_dim, tspan) + Wr = 0.2 .* randn(rng, res_dim, res_dim) + Win = 0.5 .* randn(rng, res_dim, in_dim) + bias = 0.1 .* randn(rng, res_dim) + initial_state = zeros(res_dim) + params = (Wr = Wr, Win = Win, b = bias) + return ODEProblem(esn_rhs!, initial_state, tspan, params), + Wr, Win, bias, initial_state +end + +# --------------------------------------------------------------------------- +# 1. Linear ODE — analytic match +# +# Trivial ODE dx/dt = -x + u, u constant, x(0) = 0 has the closed form +# x(t) = u (1 - exp(-t)). Verify that the continuous helper recovers the +# analytic curve to within a tight solver tolerance. With the corrected +# saveat alignment, the first sample is at `t = Δt`, not `t = 0`. +# --------------------------------------------------------------------------- + +@testset "Linear ODE analytic match" begin + function lin_rhs!(dx, x, p, t) + input_t = p.input(t) + return dx .= .-x .+ input_t + end + + T_steps = 10 + tspan = (0.0, 1.0) + Δt = (tspan[2] - tspan[1]) / T_steps + sample_ts = collect(range(tspan[1] + Δt, tspan[2]; length = T_steps)) + u_const = 1.0 + data = fill(u_const, 1, T_steps) + initial_state = [0.0] + + prob = ODEProblem(lin_rhs!, initial_state, tspan, (;)) + res = SciMLProblemReservoir( + prob, TerminalStateSampling(), tspan, Tsit5(); + reltol = 1.0e-10, abstol = 1.0e-12 + ) + rc = ReservoirComputer(res, LinearReadout(1 => 1)) + ps, st = setup(MersenneTwister(0), rc) + + states, _ = collectstates(rc, data, ps, st) + analytic = u_const .* (1 .- exp.(.-sample_ts)) + + @test size(states) == (1, T_steps) + @test states[1, :] ≈ analytic atol = 1.0e-6 +end + +# --------------------------------------------------------------------------- +# 2. 
Euler equivalence +# +# Solving dx/dt = -x + tanh(Wr x + Win u(t) + b) with explicit Euler at +# step size dt = 1 collapses algebraically to the discrete reservoir update +# x_{k+1} = tanh(Wr x_k + Win u_{k+1} + b). With the corrected alignment +# (inputs at window starts, samples at window ends), there is no off-by-one +# between continuous and discrete trajectories. +# --------------------------------------------------------------------------- + +@testset "Euler equivalence with discrete reservoir update" begin + rng = MersenneTwister(7) + in_dim, res_dim, T_steps = 2, 6, 12 + + prob, Wr, Win, bias, initial_state = build_esn_problem( + rng, in_dim, res_dim, + (0.0, Float64(T_steps)) + ) + data = randn(rng, in_dim, T_steps) + + res = SciMLProblemReservoir( + prob, TerminalStateSampling(), + (0.0, Float64(T_steps)), Euler(); + dt = 1.0 + ) + rc = ReservoirComputer(res, LinearReadout(res_dim => 1)) + ps, st = setup(MersenneTwister(0), rc) + + cont_states, _ = collectstates(rc, data, ps, st) + + disc_states = zeros(res_dim, T_steps) + state = copy(initial_state) + for (step_idx, input_col) in enumerate(eachcol(data)) + state = tanh.(Wr * state + Win * input_col + bias) + disc_states[:, step_idx] = state + end + + @test size(cont_states) == (res_dim, T_steps) + @test cont_states ≈ disc_states atol = 1.0e-10 +end + +# --------------------------------------------------------------------------- +# 3. Sampler shape contract +# +# `TerminalStateSampling` must produce a `(state_dim, T_input)` matrix +# regardless of solver. Guards against accidental transposition or extra +# rows/columns sneaking in via `reduce(hcat, sol.u)`. 
+# --------------------------------------------------------------------------- + +@testset "TerminalStateSampling output shape" begin + rng = MersenneTwister(11) + in_dim, res_dim, T_steps = 3, 8, 20 + tspan = (0.0, 2.0) + prob, _, _, _, _ = build_esn_problem(rng, in_dim, res_dim, tspan) + data = randn(rng, in_dim, T_steps) + res = SciMLProblemReservoir( + prob, TerminalStateSampling(), tspan, Tsit5(); + reltol = 1.0e-8, abstol = 1.0e-10 + ) + rc = ReservoirComputer(res, LinearReadout(res_dim => 1)) + ps, st = setup(MersenneTwister(0), rc) + + states, _ = collectstates(rc, data, ps, st) + @test size(states) == (res_dim, T_steps) + @test all(isfinite, states) +end + +# --------------------------------------------------------------------------- +# 4. Teacher-forced predict +# +# `predict(rc, data, ps, st)` runs one bulk ODE solve and applies the +# readout column-by-column. Output dims must match the readout's +# `out_dims`, and the result must be deterministic. +# --------------------------------------------------------------------------- + +@testset "Teacher-forced predict" begin + rng = MersenneTwister(23) + in_dim, res_dim, out_dim, T_steps = 2, 6, 4, 15 + tspan = (0.0, 3.0) + prob, _, _, _, _ = build_esn_problem(rng, in_dim, res_dim, tspan) + data = randn(rng, in_dim, T_steps) + res = SciMLProblemReservoir( + prob, TerminalStateSampling(), tspan, Tsit5(); + reltol = 1.0e-8, abstol = 1.0e-10 + ) + rc = ReservoirComputer(res, LinearReadout(res_dim => out_dim)) + ps, st = setup(MersenneTwister(0), rc) + + preds1, _ = predict(rc, data, ps, st) + preds2, _ = predict(rc, data, ps, st) + @test size(preds1) == (out_dim, T_steps) + @test all(isfinite, preds1) + @test preds1 ≈ preds2 +end + +# --------------------------------------------------------------------------- +# 5. Autoregressive predict +# +# `predict(rc, steps, ps, st; initialdata)` runs `steps` sub-solves, +# feeding the previous readout output back as the constant input on the +# next sub-interval. 
Shape and determinism both checked here. +# --------------------------------------------------------------------------- + +@testset "Autoregressive predict" begin + rng = MersenneTwister(31) + res_dim, dim, steps = 6, 3, 5 + tspan = (0.0, 1.0) + prob, _, _, _, _ = build_esn_problem(rng, dim, res_dim, tspan) + res = SciMLProblemReservoir( + prob, TerminalStateSampling(), tspan, Tsit5(); + reltol = 1.0e-8, abstol = 1.0e-10 + ) + rc = ReservoirComputer(res, LinearReadout(res_dim => dim)) + ps, st = setup(MersenneTwister(0), rc) + + initialdata = randn(dim) + preds1, _ = predict(rc, steps, ps, st; initialdata = initialdata) + preds2, _ = predict(rc, steps, ps, st; initialdata = initialdata) + @test size(preds1) == (dim, steps) + @test all(isfinite, preds1) + @test preds1 ≈ preds2 +end + +# --------------------------------------------------------------------------- +# 6. State modifiers compose with the continuous path +# +# `states_modifiers` must compose with the continuous reservoir the same +# way they do with the discrete one — apply per saved sample, threading +# the modifier state across columns. NLAT2 changes some entries of the +# state it is given, so the modified state matrix must differ from +# the raw one.
+# --------------------------------------------------------------------------- + +@testset "State modifiers on continuous path" begin + rng = MersenneTwister(41) + in_dim, res_dim, T_steps = 2, 6, 8 + tspan = (0.0, 1.0) + prob, _, _, _, _ = build_esn_problem(rng, in_dim, res_dim, tspan) + data = randn(rng, in_dim, T_steps) + res = SciMLProblemReservoir( + prob, TerminalStateSampling(), tspan, Tsit5(); + reltol = 1.0e-8, abstol = 1.0e-10 + ) + + rc_plain = ReservoirComputer(res, LinearReadout(res_dim => 1)) + rc_mod = ReservoirComputer(res, (NLAT2(),), LinearReadout(res_dim => 1)) + + ps_plain, st_plain = setup(MersenneTwister(0), rc_plain) + ps_mod, st_mod = setup(MersenneTwister(0), rc_mod) + + states_plain, _ = collectstates(rc_plain, data, ps_plain, st_plain) + states_mod, _ = collectstates(rc_mod, data, ps_mod, st_mod) + + @test size(states_mod) == size(states_plain) + @test all(isfinite, states_mod) + # NLAT2 changes part of the state it is given — at least some entries must + # differ from the unmodified state matrix. + @test states_mod != states_plain +end + +# --------------------------------------------------------------------------- +# 7. Boundary inputs +# +# Smallest valid sizes: `T_steps = 2` for collectstates, `steps = 1` for +# autoregressive predict. The size guards (≥2 and ≥1) reject anything +# smaller with an ArgumentError. Make sure both paths land cleanly at +# their lower bounds and that the guards actually fire one step lower.
+# --------------------------------------------------------------------------- + +@testset "Boundary sizes" begin + rng = MersenneTwister(53) + in_dim, res_dim = 2, 4 + tspan = (0.0, 1.0) + prob, _, _, _, _ = build_esn_problem(rng, in_dim, res_dim, tspan) + res = SciMLProblemReservoir( + prob, TerminalStateSampling(), tspan, Tsit5(); + reltol = 1.0e-8, abstol = 1.0e-10 + ) + rc = ReservoirComputer(res, LinearReadout(res_dim => in_dim)) + ps, st = setup(MersenneTwister(0), rc) + + # collectstates: T_steps = 2 works + data2 = randn(rng, in_dim, 2) + states2, _ = collectstates(rc, data2, ps, st) + @test size(states2) == (res_dim, 2) + + # collectstates: T_steps < 2 errors + data1 = randn(rng, in_dim, 1) + @test_throws ArgumentError collectstates(rc, data1, ps, st) + data0 = Matrix{Float64}(undef, in_dim, 0) + @test_throws ArgumentError collectstates(rc, data0, ps, st) + + # autoregressive predict: steps = 1 works + preds1, _ = predict(rc, 1, ps, st; initialdata = randn(in_dim)) + @test size(preds1) == (in_dim, 1) + + # autoregressive predict: steps < 1 errors + @test_throws ArgumentError predict(rc, 0, ps, st; initialdata = randn(in_dim)) +end + +# --------------------------------------------------------------------------- +# 8. Construction-time validation of protected kwargs +# +# `SciMLProblemReservoir` rejects `saveat`, `save_everystep`, and `dense` +# in `kwargs` at construction. +# --------------------------------------------------------------------------- + +@testset "Protected solve kwargs rejected at construction" begin + placeholder = (placeholder = true,) + sampler = TerminalStateSampling() + tspan = (0.0, 1.0) + for badkw in (:saveat, :save_everystep, :dense) + @test_throws ArgumentError SciMLProblemReservoir( + placeholder, sampler, tspan, Tsit5(); (badkw => true,)... + ) + end +end + +# --------------------------------------------------------------------------- +# 9. 
`tspan` must be strictly increasing (`tspan[2] > tspan[1]`) +# +# Degenerate `tspan = (c, c)` (or backward) would divide by zero when the +# extension synthesises the input grid step. Validate at solve time. +# --------------------------------------------------------------------------- + +@testset "tspan strictly positive" begin + rng = MersenneTwister(67) + in_dim, res_dim = 2, 4 + prob, _, _, _, _ = build_esn_problem(rng, in_dim, res_dim, (0.0, 1.0)) + data = randn(rng, in_dim, 4) + + # Equal endpoints + res_eq = SciMLProblemReservoir( + prob, TerminalStateSampling(), (1.0, 1.0), Tsit5() + ) + rc_eq = ReservoirComputer(res_eq, LinearReadout(res_dim => 1)) + ps, st = setup(MersenneTwister(0), rc_eq) + @test_throws ArgumentError collectstates(rc_eq, data, ps, st) + @test_throws ArgumentError predict(rc_eq, 3, ps, st; initialdata = randn(in_dim)) + + # Backward interval + res_back = SciMLProblemReservoir( + prob, TerminalStateSampling(), (1.0, 0.0), Tsit5() + ) + rc_back = ReservoirComputer(res_back, LinearReadout(res_dim => 1)) + @test_throws ArgumentError collectstates(rc_back, data, ps, st) +end + +# --------------------------------------------------------------------------- +# 10. Reserved `:input` key collision +# +# `:input` is the name the extension injects into the solve params so the +# user's RHS can read `p.input(t)`. A `prob.p` already carrying `:input` +# would be silently shadowed — error loudly instead.
+# --------------------------------------------------------------------------- + +@testset "Reserved `:input` key collision errors" begin + rng = MersenneTwister(79) + res_dim = 4 + + function rhs_bad!(dx, x, p, t) + return dx .= .-x + end + + prob_bad = ODEProblem( + rhs_bad!, zeros(res_dim), (0.0, 1.0), + (input = "already taken",) + ) + res = SciMLProblemReservoir( + prob_bad, TerminalStateSampling(), (0.0, 1.0), Tsit5() + ) + rc = ReservoirComputer(res, LinearReadout(res_dim => 1)) + ps, st = setup(MersenneTwister(0), rc) + data = randn(rng, 1, 4) + @test_throws ArgumentError collectstates(rc, data, ps, st) +end + +# --------------------------------------------------------------------------- +# 11. `prob.p` accepted forms: NamedTuple / nothing / NullParameters +# +# `_to_namedtuple` advertises three valid inputs and rejects anything +# else. Exercise all three success paths end-to-end so a future refactor +# can't quietly break two of them. +# --------------------------------------------------------------------------- + +@testset "prob.p accepts NamedTuple / nothing / NullParameters" begin + rng = MersenneTwister(83) + in_dim, res_dim, T_steps = 1, 4, 6 + tspan = (0.0, 1.0) + data = randn(rng, in_dim, T_steps) + + # `prob.p` is read inside the RHS, so to truly exercise all three we use + # an RHS that does not touch `p` apart from `p.input(t)`. A `let` block + # captures `Win` as a local binding inside the closure — avoids the + # type-instability hazard of a `global` and keeps the symbol out of the + # surrounding module scope. + rhs_noparams! 
= let Win = 0.5 .* randn(rng, res_dim, in_dim) + (dx, x, p, t) -> (dx .= .-x .+ Win * p.input(t)) + end + + for p_value in ((;), nothing, SciMLBase.NullParameters()) + prob = ODEProblem(rhs_noparams!, zeros(res_dim), tspan, p_value) + res = SciMLProblemReservoir( + prob, TerminalStateSampling(), tspan, Tsit5(); + reltol = 1.0e-8, abstol = 1.0e-10 + ) + rc = ReservoirComputer(res, LinearReadout(res_dim => 1)) + ps, st = setup(MersenneTwister(0), rc) + states, _ = collectstates(rc, data, ps, st) + @test size(states) == (res_dim, T_steps) + @test all(isfinite, states) + end +end + +# --------------------------------------------------------------------------- +# 12. Non-NamedTuple `prob.p` rejected with a clear error +# --------------------------------------------------------------------------- + +@testset "Non-NamedTuple prob.p errors clearly" begin + function rhs!(dx, x, p, t) + return dx .= .-x + end + prob = ODEProblem(rhs!, [0.0], (0.0, 1.0), [1.0, 2.0]) # Vector params + res = SciMLProblemReservoir( + prob, TerminalStateSampling(), (0.0, 1.0), Tsit5(); + reltol = 1.0e-6 + ) + rc = ReservoirComputer(res, LinearReadout(1 => 1)) + ps, st = setup(MersenneTwister(0), rc) + data = ones(1, 4) + @test_throws ArgumentError collectstates(rc, data, ps, st) +end diff --git a/test/test_sciml_reservoir.jl b/test/test_sciml_reservoir.jl index a5a545d2f..0d5735831 100644 --- a/test/test_sciml_reservoir.jl +++ b/test/test_sciml_reservoir.jl @@ -46,3 +46,48 @@ end data = randn(Float32, 1, 5) @test_throws ErrorException collectstates(rc, data, ps, st) end + +@testset "SciMLProblemReservoir rejects protected solve kwargs" begin + prob = (placeholder = true,) + sampler = TerminalStateSampling() + tspan = (0.0, 1.0) + for badkw in (:saveat, :save_everystep, :dense) + @test_throws ArgumentError SciMLProblemReservoir( + prob, sampler, tspan; (badkw => true,)... + ) + end + # User kwargs that do not collide should still go through. 
+ res_ok = SciMLProblemReservoir(prob, sampler, tspan; reltol = 1.0e-6) + @test res_ok.kwargs[:reltol] == 1.0e-6 +end + +@testset "Continuous _predict errors without extension" begin + prob = (placeholder = true,) + res = SciMLProblemReservoir(prob, TerminalStateSampling(), (0.0, 1.0)) + rc = ReservoirComputer(res, LinearReadout(1 => 1)) + rng = MersenneTwister(0) + ps, st = setup(rng, rc) + data = randn(Float32, 1, 5) + @test_throws ErrorException predict(rc, data, ps, st) + @test_throws ErrorException predict(rc, 3, ps, st; initialdata = randn(Float32, 1)) +end + +# `DeepESN` is an `AbstractReservoirComputer` subtype whose leading field is +# `:cells`, not `:reservoir`. The new two-level `predict` dispatch must not +# unconditionally reach for `rc.reservoir`, or DeepESN crashes with a +# `FieldError` (originally surfaced by the docs `@example` block in +# `tutorials/deep_esn.md`). This testset locks in the `hasfield` guard. +@testset "predict works on reservoir computers without a :reservoir field" begin + rng = MersenneTwister(0) + # `rand_sparse`'s sparsity defaults need a wide-enough reservoir for the + # spectral-radius rescaling to avoid degenerate NaNs. + desn = DeepESN(3, [16, 16], 3) + ps, st = setup(rng, desn) + data = randn(3, 5) + out, _ = predict(desn, data, ps, st) + @test size(out) == (3, 5) + @test all(isfinite, out) + out_ar, _ = predict(desn, 3, ps, st; initialdata = randn(3)) + @test size(out_ar) == (3, 3) + @test all(isfinite, out_ar) +end