From c53cd482fbdbe3cec09b894e3b9e7065cf9f59dc Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 22 Jun 2026 18:44:55 +0200 Subject: [PATCH] fix: removing delay and nla bugs, small all around fixes --- src/ReservoirComputing.jl | 6 +- src/models/esn_deep.jl | 8 ++- src/models/esn_delay.jl | 39 +++++++++++++ src/predict.jl | 5 +- src/states.jl | 30 +++++----- src/train.jl | 4 +- test/test_esn_delay.jl | 116 ++++++++++++++++++++++++++++++++++++++ test/test_states.jl | 2 +- 8 files changed, 183 insertions(+), 27 deletions(-) diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 5d4e37318..3bec529aa 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -65,12 +65,10 @@ include("extensions/reca.jl") export ReservoirComputer export AbstractSciMLProblemReservoir, SciMLProblemReservoir export AbstractSampler, TerminalStateSampling -export ESNCell, ES2NCell, EuSNCell, MemoryESNCell, MemoryResESNCell, RMNCell +export AdditiveEIESNCell, EIESNCell, ES2NCell, ESNCell, EuSNCell, LIFESNCell, + MemoryESNCell, MemoryResESNCell, ResESNCell, RMNCell export StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates, DelayLayer, NonlinearFeaturesLayer -export AdditiveEIESNCell, EIESNCell, ES2NCell, ESNCell, EuSNCell, LIFESNCell, ResESNCell -export Collect, collectstates, DelayLayer, LinearReadout, NonlinearFeaturesLayer, - ReservoirChain, StatefulLayer export SVMReadout export LocalInformationFlow export Extend, ExtendedSquare, NLAT1, NLAT2, NLAT3, Pad, PartialSquare diff --git a/src/models/esn_deep.jl b/src/models/esn_deep.jl index fd1d98a75..1c9f54d36 100644 --- a/src/models/esn_deep.jl +++ b/src/models/esn_deep.jl @@ -50,8 +50,12 @@ Per-layer reservoir options (passed to each [`ESNCell`](@ref)): - `init_state`: Initializer(s) used when an external state is not provided. Scalar or length-`L`. Default: `randn32`. - `use_bias`: Whether each reservoir uses a bias term. Boolean scalar or length-`L`. Default: `false`. - - `depth`: Depth of the DeepESN. If the reservoir size is given as a number instead of a vector, this - parameter controls the depth of the model. Default is 2. + +Depth: + + - `depth`: Number of reservoir layers. Only used when `res_dims` is given as a + single integer (the depth is then `depth` layers of that width); it is ignored + when `res_dims` is a vector, whose length already sets the depth `L`. Default: `2`. Composition: diff --git a/src/models/esn_delay.jl b/src/models/esn_delay.jl index 92b11e2dc..fcc5394fb 100644 --- a/src/models/esn_delay.jl +++ b/src/models/esn_delay.jl @@ -517,3 +517,42 @@ function Base.show(io::IO, esn::DelayESN) return end + +# `InputDelayESN`/`DelayESN` carry an extra `input_delay` field that must run +# *before* the reservoir. The generic `AbstractReservoirComputer` machinery in +# `reservoircomputer.jl` only knows about `(:reservoir, :states_modifiers, +# :readout)`, so these models need their own parameter/state setup and their +# own `_partial_apply`. The generic forward call and `collectstates` both route +# through `_partial_apply`, so overriding it here is enough to make them work. +const _InputDelayedESN = Union{InputDelayESN, DelayESN} + +function initialparameters(rng::AbstractRNG, esn::_InputDelayedESN) + return ( + input_delay = initialparameters(rng, esn.input_delay), + reservoir = initialparameters(rng, esn.reservoir), + states_modifiers = map(l -> initialparameters(rng, l), esn.states_modifiers) |> + Tuple, + readout = initialparameters(rng, esn.readout), + ) +end + +function initialstates(rng::AbstractRNG, esn::_InputDelayedESN) + return ( + input_delay = initialstates(rng, esn.input_delay), + reservoir = initialstates(rng, esn.reservoir), + states_modifiers = map(l -> initialstates(rng, l), esn.states_modifiers) |> Tuple, + readout = initialstates(rng, esn.readout), + ) +end + +function _partial_apply(esn::_InputDelayedESN, inp, ps, st) + inp_delayed, st_input_delay = apply( + esn.input_delay, inp, ps.input_delay, st.input_delay + ) + res_state, st_res = apply(esn.reservoir, inp_delayed, ps.reservoir, st.reservoir) + out, st_mods = _apply_seq( + esn.states_modifiers, res_state, ps.states_modifiers, st.states_modifiers + ) + return out, + (input_delay = st_input_delay, reservoir = st_res, states_modifiers = st_mods) +end diff --git a/src/predict.jl b/src/predict.jl index acc24fbce..70fde7fc4 100644 --- a/src/predict.jl +++ b/src/predict.jl @@ -1,5 +1,5 @@ @doc raw""" - predict(rc, steps::Integer, ps, st; initialdata=nothing) + predict(rc, steps::Integer, ps, st; initialdata) predict(rc, data::AbstractMatrix, ps, st) Run the model either in (1) closed-loop (auto-regressive) mode for a fixed number @@ -22,8 +22,7 @@ sequence. ### Keyword Arguments -- `initialdata=nothing`: Column vector used as the first input. - Has to be provided. +- `initialdata`: Column vector used as the first input. Required keyword argument. ### Returns diff --git a/src/states.jl b/src/states.jl index 7a685e747..36fa19d2f 100644 --- a/src/states.jl +++ b/src/states.jl @@ -5,10 +5,10 @@ Function, } cols = axes(states, 2) - states_1 = states_mod(states[:, first(cols)]) + states_1 = states_mod(@view states[:, first(cols)]) new_states = similar(states_1, length(states_1), length(cols)) new_states[:, 1] .= states_1 - for (k, j) in enumerate(cols) + for (k, j) in Iterators.drop(enumerate(cols), 1) new_states[:, k] .= states_mod(@view states[:, j]) end return new_states @@ -158,7 +158,7 @@ esn = ReservoirChain( ) ``` -In this esample the input to `Extend` is the initial value fed to +In this example the input to `Extend` is the initial value fed to [`ReservoirChain`](@ref). After `Extend`, the value in the chain will be the state returned by the [`StatefulLayer`](@ref), `vcat`ed with the input. @@ -375,21 +375,21 @@ julia> mat_old = [1 2 3; julia> mat_new = nlat2(mat_old) 7×3 Matrix{Int64}: - 1 2 3 - 4 5 6 - 4 10 18 - 10 11 12 - 70 88 108 - 16 17 18 - 19 20 21 + 1 2 3 + 4 5 6 + 4 10 18 + 10 11 12 + 70 88 108 + 16 17 18 + 208 238 270 ``` """ function NLAT2(x_old::AbstractVector) x_new = copy(x_old) for idx in eachindex(x_old) - if firstindex(x_old) < idx < lastindex(x_old) && isodd(idx) - x_new[idx, :] .= x_old[idx - 1, :] .* x_old[idx - 2, :] + if firstindex(x_old) < idx && isodd(idx) + x_new[idx] = x_old[idx - 1] * x_old[idx - 2] end end return x_new @@ -435,7 +435,7 @@ None ## Example ```jldoctest nlat3 -julia> nlat2 = NLAT3() +julia> nlat3 = NLAT3() NLAT3 (generic function with 3 methods) julia> x_old = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] @@ -451,7 +451,7 @@ julia> x_old = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 8 9 -julia> n_new = nlat2(x_old) +julia> n_new = nlat3(x_old) 10-element Vector{Int64}: 0 1 @@ -483,7 +483,7 @@ julia> mat_old = [1 2 3; 16 17 18 19 20 21 -julia> mat_new = nlat2(mat_old) +julia> mat_new = nlat3(mat_old) 7×3 Matrix{Int64}: 1 2 3 4 5 6 diff --git a/src/train.jl b/src/train.jl index bcb38069b..3b4a5de26 100644 --- a/src/train.jl +++ b/src/train.jl @@ -66,7 +66,7 @@ additional changes. ## Returns - `output_weights`: Trained readout. Should be a forward method to be hooked into a - layer. For instance, in case of linear regression `output_weights` is a mtrix + layer. For instance, in case of linear regression `output_weights` is a matrix consumable by [`LinearReadout`](@ref). ## Notes @@ -104,7 +104,7 @@ end washout=0, return_states=false) Trains a given reservoir computing by creating the reservoir states from `train_data`, -and then fiting the readout layer using `target_data` as target. +and then fitting the readout layer using `target_data` as target. The learned weights/layer are written into `ps`. Use `return_states=true` to also obtain the feature matrix used for the fit, or call [`collectstates`](@ref) directly. diff --git a/test/test_esn_delay.jl b/test/test_esn_delay.jl index 5cbf76f20..94040f1c1 100644 --- a/test/test_esn_delay.jl +++ b/test/test_esn_delay.jl @@ -188,3 +188,119 @@ end @test Int(fesn2.readout.in_dims) == res_dims * (2 + 1) end end + +# These models carry an `input_delay` field that runs before the reservoir and +# is *not* handled by the generic `AbstractReservoirComputer` machinery, so they +# need bespoke `initialparameters`/`initialstates`/`_partial_apply`. The +# constructor tests above never exercise a forward pass; the round-trips below +# guard against the `input_delay` field being silently dropped from params/state +# or skipped during application (which used to crash with a `DimensionMismatch`). +@testset "InputDelayESN forward pass & training" begin + rng = MersenneTwister(123) + in_dims, res_dims, out_dims, num_delays = 3, 40, 3, 2 + n_steps = 60 + data = rand(rng, Float32, in_dims, n_steps) + target = rand(rng, Float32, out_dims, n_steps) + + idesn = InputDelayESN(in_dims, res_dims, out_dims; num_delays = num_delays) + ps, st = setup(rng, idesn) + + @testset "params/state expose input_delay" begin + @test haskey(ps, :input_delay) + @test haskey(st, :input_delay) + # reservoir input matrix is sized for the augmented (delayed) input + @test size(ps.reservoir.input_matrix, 2) == in_dims * (num_delays + 1) + end + + @testset "collectstates" begin + states, newst = collectstates(idesn, data, ps, st) + @test size(states) == (res_dims, n_steps) + @test all(isfinite, states) + # the internal delay buffer must have advanced once per time step, + # proving the input_delay layer actually ran + @test newst.input_delay.clock == n_steps + end + + @testset "train! + predict" begin + (ps_t, st_t), states = train!( + idesn, data, target, ps, st, StandardRidge(1.0e-6); return_states = true + ) + @test size(states) == (res_dims, n_steps) + + out_tf, _ = predict(idesn, data, ps_t, st_t) + @test size(out_tf) == (out_dims, n_steps) + @test all(isfinite, out_tf) + + steps = 8 + out_ar, _ = predict(idesn, steps, ps_t, st_t; initialdata = data[:, end]) + @test size(out_ar) == (out_dims, steps) + @test all(isfinite, out_ar) + end + + @testset "deterministic for a fixed seed" begin + ps1, st1 = setup(MersenneTwister(99), idesn) + ps2, st2 = setup(MersenneTwister(99), idesn) + s1, _ = collectstates(idesn, data, ps1, st1) + s2, _ = collectstates(idesn, data, ps2, st2) + @test s1 == s2 + end +end + +@testset "DelayESN forward pass & training" begin + rng = MersenneTwister(123) + in_dims, res_dims, out_dims = 3, 40, 2 + num_input_delays, num_state_delays = 2, 1 + n_steps = 60 + data = rand(rng, Float32, in_dims, n_steps) + target = rand(rng, Float32, out_dims, n_steps) + + fesn = DelayESN( + in_dims, res_dims, out_dims; + num_input_delays = num_input_delays, num_state_delays = num_state_delays + ) + ps, st = setup(rng, fesn) + + @testset "params/state expose input_delay" begin + @test haskey(ps, :input_delay) + @test haskey(st, :input_delay) + @test size(ps.reservoir.input_matrix, 2) == in_dims * (num_input_delays + 1) + end + + @testset "collectstates" begin + states, newst = collectstates(fesn, data, ps, st) + # input + state delays: readout sees (num_state_delays + 1) * res_dims + @test size(states) == (res_dims * (num_state_delays + 1), n_steps) + @test all(isfinite, states) + @test newst.input_delay.clock == n_steps + end + + @testset "train! + predict" begin + ps_t, st_t = train!(fesn, data, target, ps, st, StandardRidge(1.0e-6)) + + out_tf, _ = predict(fesn, data, ps_t, st_t) + @test size(out_tf) == (out_dims, n_steps) + @test all(isfinite, out_tf) + end +end + +# StateDelayESN routes its delay through `states_modifiers`, so it already used +# the generic machinery — keep a round-trip here so the three delay models stay +# in lockstep. +@testset "StateDelayESN forward pass & training" begin + rng = MersenneTwister(123) + in_dims, res_dims, out_dims, num_delays = 3, 40, 2, 2 + n_steps = 60 + data = rand(rng, Float32, in_dims, n_steps) + target = rand(rng, Float32, out_dims, n_steps) + + sdesn = StateDelayESN(in_dims, res_dims, out_dims; num_delays = num_delays) + ps, st = setup(rng, sdesn) + + states, _ = collectstates(sdesn, data, ps, st) + @test size(states) == (res_dims * (num_delays + 1), n_steps) + + ps_t, st_t = train!(sdesn, data, target, ps, st, StandardRidge(1.0e-6)) + out_tf, _ = predict(sdesn, data, ps_t, st_t) + @test size(out_tf) == (out_dims, n_steps) + @test all(isfinite, out_tf) +end diff --git a/test/test_states.jl b/test/test_states.jl index d534ef237..9c4271991 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -7,7 +7,7 @@ test_types = [Float64, Float32, Float16] nlas = [ (NLAT1(), [1, 2, 9, 4, 25, 6, 49, 8, 81]), - (NLAT2(), [1, 2, 2, 4, 12, 6, 30, 8, 9]), + (NLAT2(), [1, 2, 2, 4, 12, 6, 30, 8, 56]), (NLAT3(), [1, 2, 8, 4, 24, 6, 48, 8, 9]), (PartialSquare(0.6), [1, 4, 9, 16, 25, 6, 7, 8, 9]), (ExtendedSquare(), [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 4, 9, 16, 25, 36, 49, 64, 81]),