Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/ReservoirComputing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,10 @@ include("extensions/reca.jl")
export ReservoirComputer
export AbstractSciMLProblemReservoir, SciMLProblemReservoir
export AbstractSampler, TerminalStateSampling
export ESNCell, ES2NCell, EuSNCell, MemoryESNCell, MemoryResESNCell, RMNCell
export AdditiveEIESNCell, EIESNCell, ES2NCell, ESNCell, EuSNCell, LIFESNCell,
MemoryESNCell, MemoryResESNCell, ResESNCell, RMNCell
export StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates,
DelayLayer, NonlinearFeaturesLayer
export AdditiveEIESNCell, EIESNCell, ES2NCell, ESNCell, EuSNCell, LIFESNCell, ResESNCell
export Collect, collectstates, DelayLayer, LinearReadout, NonlinearFeaturesLayer,
ReservoirChain, StatefulLayer
export SVMReadout
export LocalInformationFlow
export Extend, ExtendedSquare, NLAT1, NLAT2, NLAT3, Pad, PartialSquare
Expand Down
8 changes: 6 additions & 2 deletions src/models/esn_deep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ Per-layer reservoir options (passed to each [`ESNCell`](@ref)):
- `init_state`: Initializer(s) used when an external state is not provided.
Scalar or length-`L`. Default: `randn32`.
- `use_bias`: Whether each reservoir uses a bias term. Boolean scalar or length-`L`. Default: `false`.
- `depth`: Depth of the DeepESN. If the reservoir size is given as a number instead of a vector, this
parameter controls the depth of the model. Default is 2.

Depth:

- `depth`: Number of reservoir layers. Only used when `res_dims` is given as a
single integer (the depth is then `depth` layers of that width); it is ignored
when `res_dims` is a vector, whose length already sets the depth `L`. Default: `2`.

Composition:

Expand Down
39 changes: 39 additions & 0 deletions src/models/esn_delay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -517,3 +517,42 @@ function Base.show(io::IO, esn::DelayESN)

return
end

# `InputDelayESN`/`DelayESN` carry an extra `input_delay` field that must run
# *before* the reservoir. The generic `AbstractReservoirComputer` machinery in
# `reservoircomputer.jl` only knows about `(:reservoir, :states_modifiers,
# :readout)`, so these models need their own parameter/state setup and their
# own `_partial_apply`. The generic forward call and `collectstates` both route
# through `_partial_apply`, so overriding it here is enough to make them work.
#
# Dispatch alias covering both models; the `initialparameters`, `initialstates`
# and `_partial_apply` methods below are defined once on this union.
const _InputDelayedESN = Union{InputDelayESN, DelayESN}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious about one thing: since AbstractReservoirComputer{Fields} already encodes the field-name tuple, would it make sense down the line to have the generic initialparameters / _partial_apply introspect Fields directly, so any future composite model with extra pre- or post-fields just works? Could be a follow-up — wondering whether you'd considered that direction, or whether there's a reason a per-model override is preferable here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That could be an elegant solution, yeah. If there's a way to make it generic enough, it should solve similar gotchas to this one.


# Parameter setup for models with an `input_delay` field. Returns a NamedTuple
# with one entry per component: `input_delay`, `reservoir`, `states_modifiers`
# (converted to a `Tuple`, one element per modifier) and `readout`.
# The components are initialized in exactly that order, all drawing from `rng`.
function initialparameters(rng::AbstractRNG, esn::_InputDelayedESN)
    delay_ps = initialparameters(rng, esn.input_delay)
    reservoir_ps = initialparameters(rng, esn.reservoir)
    modifier_ps = map(esn.states_modifiers) do layer
        return initialparameters(rng, layer)
    end
    readout_ps = initialparameters(rng, esn.readout)
    return (
        input_delay = delay_ps,
        reservoir = reservoir_ps,
        states_modifiers = Tuple(modifier_ps),
        readout = readout_ps,
    )
end

# State setup for models with an `input_delay` field. Mirrors the layout of the
# parameters: a NamedTuple with `input_delay`, `reservoir`, `states_modifiers`
# (converted to a `Tuple`, one element per modifier) and `readout` entries,
# initialized in that order with the shared `rng`.
function initialstates(rng::AbstractRNG, esn::_InputDelayedESN)
    delay_st = initialstates(rng, esn.input_delay)
    reservoir_st = initialstates(rng, esn.reservoir)
    modifier_st = map(esn.states_modifiers) do layer
        return initialstates(rng, layer)
    end
    readout_st = initialstates(rng, esn.readout)
    return (
        input_delay = delay_st,
        reservoir = reservoir_st,
        states_modifiers = Tuple(modifier_st),
        readout = readout_st,
    )
end

# Forward pass up to (but not including) the readout:
#   input -> input_delay -> reservoir -> states_modifiers.
# Returns the modified reservoir features together with the updated states of
# the three stages that ran; the returned state has no `readout` entry.
function _partial_apply(esn::_InputDelayedESN, inp, ps, st)
    # The delay layer runs first so the reservoir sees the augmented input.
    delayed, delay_st = apply(esn.input_delay, inp, ps.input_delay, st.input_delay)
    hidden, reservoir_st = apply(esn.reservoir, delayed, ps.reservoir, st.reservoir)
    features, modifier_st = _apply_seq(
        esn.states_modifiers, hidden, ps.states_modifiers, st.states_modifiers
    )
    new_st = (
        input_delay = delay_st,
        reservoir = reservoir_st,
        states_modifiers = modifier_st,
    )
    return features, new_st
end
5 changes: 2 additions & 3 deletions src/predict.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@doc raw"""
predict(rc, steps::Integer, ps, st; initialdata=nothing)
predict(rc, steps::Integer, ps, st; initialdata)
predict(rc, data::AbstractMatrix, ps, st)

Run the model either in (1) closed-loop (auto-regressive) mode for a fixed number
Expand All @@ -22,8 +22,7 @@ sequence.

### Keyword Arguments

- `initialdata=nothing`: Column vector used as the first input.
Has to be provided.
- `initialdata`: Column vector used as the first input. Required keyword argument.

### Returns

Expand Down
30 changes: 15 additions & 15 deletions src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
Function,
}
cols = axes(states, 2)
states_1 = states_mod(states[:, first(cols)])
states_1 = states_mod(@view states[:, first(cols)])
new_states = similar(states_1, length(states_1), length(cols))
new_states[:, 1] .= states_1
for (k, j) in enumerate(cols)
for (k, j) in Iterators.drop(enumerate(cols), 1)
new_states[:, k] .= states_mod(@view states[:, j])
end
return new_states
Expand Down Expand Up @@ -158,7 +158,7 @@ esn = ReservoirChain(
)
```

In this esample the input to `Extend` is the initial value fed to
In this example the input to `Extend` is the initial value fed to
[`ReservoirChain`](@ref). After `Extend`, the value in the chain will
be the state returned by the [`StatefulLayer`](@ref), `vcat`ed with
the input.
Expand Down Expand Up @@ -375,21 +375,21 @@ julia> mat_old = [1 2 3;

julia> mat_new = nlat2(mat_old)
7×3 Matrix{Int64}:
1 2 3
4 5 6
4 10 18
10 11 12
70 88 108
16 17 18
19 20 21
1 2 3
4 5 6
4 10 18
10 11 12
70 88 108
16 17 18
208 238 270

```
"""
function NLAT2(x_old::AbstractVector)
x_new = copy(x_old)
for idx in eachindex(x_old)
if firstindex(x_old) < idx < lastindex(x_old) && isodd(idx)
x_new[idx, :] .= x_old[idx - 1, :] .* x_old[idx - 2, :]
if firstindex(x_old) < idx && isodd(idx)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When firstindex(x_old) > 1 (e.g. an OffsetArray), this would start applying at idx = firstindex + 1 if odd. Is that the intended semantics for offset inputs, or is NLAT2 effectively assumed to receive standard 1-based vectors? Probably never comes up in practice but I noticed it.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't the only algorithm that would have issues with nonstandard arrays. A broader check for these kinds of issues is planned.

x_new[idx] = x_old[idx - 1] * x_old[idx - 2]
end
end
return x_new
Expand Down Expand Up @@ -435,7 +435,7 @@ None
## Example

```jldoctest nlat3
julia> nlat2 = NLAT3()
julia> nlat3 = NLAT3()
NLAT3 (generic function with 3 methods)

julia> x_old = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Expand All @@ -451,7 +451,7 @@ julia> x_old = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
8
9

julia> n_new = nlat2(x_old)
julia> n_new = nlat3(x_old)
10-element Vector{Int64}:
0
1
Expand Down Expand Up @@ -483,7 +483,7 @@ julia> mat_old = [1 2 3;
16 17 18
19 20 21

julia> mat_new = nlat2(mat_old)
julia> mat_new = nlat3(mat_old)
7×3 Matrix{Int64}:
1 2 3
4 5 6
Expand Down
4 changes: 2 additions & 2 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ additional changes.
## Returns

- `output_weights`: Trained readout. Should be a forward method to be hooked into a
layer. For instance, in case of linear regression `output_weights` is a mtrix
layer. For instance, in case of linear regression `output_weights` is a matrix
consumable by [`LinearReadout`](@ref).

## Notes
Expand Down Expand Up @@ -104,7 +104,7 @@ end
washout=0, return_states=false)

Trains a given reservoir computing by creating the reservoir states from `train_data`,
and then fiting the readout layer using `target_data` as target.
and then fitting the readout layer using `target_data` as target.
The learned weights/layer are written into `ps`. Use `return_states=true` to also
obtain the feature matrix used for the fit, or call [`collectstates`](@ref) directly.

Expand Down
116 changes: 116 additions & 0 deletions test/test_esn_delay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,119 @@ end
@test Int(fesn2.readout.in_dims) == res_dims * (2 + 1)
end
end

# These models carry an `input_delay` field that runs before the reservoir and
# is *not* handled by the generic `AbstractReservoirComputer` machinery, so they
# need bespoke `initialparameters`/`initialstates`/`_partial_apply`. The
# constructor tests above never exercise a forward pass; the round-trips below
# guard against the `input_delay` field being silently dropped from params/state
# or skipped during application (which used to crash with a `DimensionMismatch`).
@testset "InputDelayESN forward pass & training" begin
rng = MersenneTwister(123)
in_dims, res_dims, out_dims, num_delays = 3, 40, 3, 2
n_steps = 60
data = rand(rng, Float32, in_dims, n_steps)
target = rand(rng, Float32, out_dims, n_steps)

idesn = InputDelayESN(in_dims, res_dims, out_dims; num_delays = num_delays)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of the new round-trip tests touch stride > 1. The fix itself isn't stride-sensitive (stride only changes what the buffer keeps), but I wondered if a quick stride=2 smoke test might be worth adding here just to lock the end-to-end behavior down. Or do you think the current coverage is enough?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that would be a good idea. I'll add it.

ps, st = setup(rng, idesn)

@testset "params/state expose input_delay" begin
    # Both the parameter and the state containers must carry the delay layer.
    @test haskey(ps, :input_delay)
    @test haskey(st, :input_delay)
    # The reservoir consumes the delayed (augmented) input, so its input matrix
    # needs one column block per delay plus one for the current input.
    augmented_dims = in_dims * (num_delays + 1)
    @test size(ps.reservoir.input_matrix, 2) == augmented_dims
end

@testset "collectstates" begin
    feats, st_after = collectstates(idesn, data, ps, st)
    expected_shape = (res_dims, n_steps)
    @test size(feats) == expected_shape
    @test all(isfinite, feats)
    # One clock tick per time step shows that the input_delay layer was
    # actually applied during the pass.
    @test st_after.input_delay.clock == n_steps
end

@testset "train! + predict" begin
    ridge = StandardRidge(1.0e-6)
    (ps_fit, st_fit), feats = train!(
        idesn, data, target, ps, st, ridge; return_states = true
    )
    @test size(feats) == (res_dims, n_steps)

    # Teacher-forced prediction: one output column per input column.
    forced, _ = predict(idesn, data, ps_fit, st_fit)
    @test size(forced) == (out_dims, n_steps)
    @test all(isfinite, forced)

    # Closed-loop prediction seeded with the last observed input.
    steps = 8
    seed_input = data[:, end]
    rollout, _ = predict(idesn, steps, ps_fit, st_fit; initialdata = seed_input)
    @test size(rollout) == (out_dims, steps)
    @test all(isfinite, rollout)
end

@testset "deterministic for a fixed seed" begin
    # Two independent setups from identically seeded RNGs must yield
    # bit-identical state trajectories.
    seeded = [setup(MersenneTwister(99), idesn) for _ in 1:2]
    trajectories = [first(collectstates(idesn, data, p, s)) for (p, s) in seeded]
    @test trajectories[1] == trajectories[2]
end
end

@testset "DelayESN forward pass & training" begin
    rng = MersenneTwister(123)
    in_dims = 3
    res_dims = 40
    out_dims = 2
    num_input_delays = 2
    num_state_delays = 1
    n_steps = 60
    data = rand(rng, Float32, in_dims, n_steps)
    target = rand(rng, Float32, out_dims, n_steps)

    fesn = DelayESN(
        in_dims, res_dims, out_dims;
        num_input_delays = num_input_delays,
        num_state_delays = num_state_delays
    )
    ps, st = setup(rng, fesn)

    @testset "params/state expose input_delay" begin
        @test haskey(ps, :input_delay)
        @test haskey(st, :input_delay)
        # The reservoir input matrix is sized for the delayed input.
        augmented_dims = in_dims * (num_input_delays + 1)
        @test size(ps.reservoir.input_matrix, 2) == augmented_dims
    end

    @testset "collectstates" begin
        feats, st_after = collectstates(fesn, data, ps, st)
        # With state delays the collected features stack
        # (num_state_delays + 1) copies of the reservoir width.
        feature_dims = res_dims * (num_state_delays + 1)
        @test size(feats) == (feature_dims, n_steps)
        @test all(isfinite, feats)
        @test st_after.input_delay.clock == n_steps
    end

    @testset "train! + predict" begin
        ridge = StandardRidge(1.0e-6)
        ps_fit, st_fit = train!(fesn, data, target, ps, st, ridge)

        forced, _ = predict(fesn, data, ps_fit, st_fit)
        @test size(forced) == (out_dims, n_steps)
        @test all(isfinite, forced)
    end
end

# StateDelayESN routes its delay through `states_modifiers`, so it already used
# the generic machinery — keep a round-trip here so the three delay models stay
# in lockstep.
@testset "StateDelayESN forward pass & training" begin
    rng = MersenneTwister(123)
    in_dims = 3
    res_dims = 40
    out_dims = 2
    num_delays = 2
    n_steps = 60
    data = rand(rng, Float32, in_dims, n_steps)
    target = rand(rng, Float32, out_dims, n_steps)

    sdesn = StateDelayESN(in_dims, res_dims, out_dims; num_delays = num_delays)
    ps, st = setup(rng, sdesn)

    # Collected features stack the current state with its delayed copies.
    feats, _ = collectstates(sdesn, data, ps, st)
    feature_dims = res_dims * (num_delays + 1)
    @test size(feats) == (feature_dims, n_steps)

    # Train the readout, then run a teacher-forced prediction over the data.
    ridge = StandardRidge(1.0e-6)
    ps_fit, st_fit = train!(sdesn, data, target, ps, st, ridge)
    forced, _ = predict(sdesn, data, ps_fit, st_fit)
    @test size(forced) == (out_dims, n_steps)
    @test all(isfinite, forced)
end
2 changes: 1 addition & 1 deletion test/test_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ test_types = [Float64, Float32, Float16]

nlas = [
(NLAT1(), [1, 2, 9, 4, 25, 6, 49, 8, 81]),
(NLAT2(), [1, 2, 2, 4, 12, 6, 30, 8, 9]),
(NLAT2(), [1, 2, 2, 4, 12, 6, 30, 8, 56]),
(NLAT3(), [1, 2, 8, 4, 24, 6, 48, 8, 9]),
(PartialSquare(0.6), [1, 4, 9, 16, 25, 6, 7, 8, 9]),
(ExtendedSquare(), [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 4, 9, 16, 25, 36, 49, 64, 81]),
Expand Down
Loading