-
-
Notifications
You must be signed in to change notification settings - Fork 49
fix: removing delay and nla bugs, small all around fixes #459
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,10 +5,10 @@ | |
| Function, | ||
| } | ||
| cols = axes(states, 2) | ||
| states_1 = states_mod(states[:, first(cols)]) | ||
| states_1 = states_mod(@view states[:, first(cols)]) | ||
| new_states = similar(states_1, length(states_1), length(cols)) | ||
| new_states[:, 1] .= states_1 | ||
| for (k, j) in enumerate(cols) | ||
| for (k, j) in Iterators.drop(enumerate(cols), 1) | ||
| new_states[:, k] .= states_mod(@view states[:, j]) | ||
| end | ||
| return new_states | ||
|
|
@@ -158,7 +158,7 @@ esn = ReservoirChain( | |
| ) | ||
| ``` | ||
|
|
||
| In this esample the input to `Extend` is the initial value fed to | ||
| In this example the input to `Extend` is the initial value fed to | ||
| [`ReservoirChain`](@ref). After `Extend`, the value in the chain will | ||
| be the state returned by the [`StatefulLayer`](@ref), `vcat`ed with | ||
| the input. | ||
|
|
@@ -375,21 +375,21 @@ julia> mat_old = [1 2 3; | |
|
|
||
| julia> mat_new = nlat2(mat_old) | ||
| 7×3 Matrix{Int64}: | ||
| 1 2 3 | ||
| 4 5 6 | ||
| 4 10 18 | ||
| 10 11 12 | ||
| 70 88 108 | ||
| 16 17 18 | ||
| 19 20 21 | ||
| 1 2 3 | ||
| 4 5 6 | ||
| 4 10 18 | ||
| 10 11 12 | ||
| 70 88 108 | ||
| 16 17 18 | ||
| 208 238 270 | ||
|
|
||
| ``` | ||
| """ | ||
| function NLAT2(x_old::AbstractVector) | ||
| x_new = copy(x_old) | ||
| for idx in eachindex(x_old) | ||
| if firstindex(x_old) < idx < lastindex(x_old) && isodd(idx) | ||
| x_new[idx, :] .= x_old[idx - 1, :] .* x_old[idx - 2, :] | ||
| if firstindex(x_old) < idx && isodd(idx) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This isn't the only algo that would have issues with nonstandard arrays. It's in the plans to add a larger check for these kinds of issues |
||
| x_new[idx] = x_old[idx - 1] * x_old[idx - 2] | ||
| end | ||
| end | ||
| return x_new | ||
|
|
@@ -435,7 +435,7 @@ None | |
| ## Example | ||
|
|
||
| ```jldoctest nlat3 | ||
| julia> nlat2 = NLAT3() | ||
| julia> nlat3 = NLAT3() | ||
| NLAT3 (generic function with 3 methods) | ||
|
|
||
| julia> x_old = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | ||
|
|
@@ -451,7 +451,7 @@ julia> x_old = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | |
| 8 | ||
| 9 | ||
|
|
||
| julia> n_new = nlat2(x_old) | ||
| julia> n_new = nlat3(x_old) | ||
| 10-element Vector{Int64}: | ||
| 0 | ||
| 1 | ||
|
|
@@ -483,7 +483,7 @@ julia> mat_old = [1 2 3; | |
| 16 17 18 | ||
| 19 20 21 | ||
|
|
||
| julia> mat_new = nlat2(mat_old) | ||
| julia> mat_new = nlat3(mat_old) | ||
| 7×3 Matrix{Int64}: | ||
| 1 2 3 | ||
| 4 5 6 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -188,3 +188,119 @@ end | |
| @test Int(fesn2.readout.in_dims) == res_dims * (2 + 1) | ||
| end | ||
| end | ||
|
|
||
| # These models carry an `input_delay` field that runs before the reservoir and | ||
| # is *not* handled by the generic `AbstractReservoirComputer` machinery, so they | ||
| # need bespoke `initialparameters`/`initialstates`/`_partial_apply`. The | ||
| # constructor tests above never exercise a forward pass; the round-trips below | ||
| # guard against the `input_delay` field being silently dropped from params/state | ||
| # or skipped during application (which used to crash with a `DimensionMismatch`). | ||
| @testset "InputDelayESN forward pass & training" begin | ||
| rng = MersenneTwister(123) | ||
| in_dims, res_dims, out_dims, num_delays = 3, 40, 3, 2 | ||
| n_steps = 60 | ||
| data = rand(rng, Float32, in_dims, n_steps) | ||
| target = rand(rng, Float32, out_dims, n_steps) | ||
|
|
||
| idesn = InputDelayESN(in_dims, res_dims, out_dims; num_delays = num_delays) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. None of the new round-trip tests touch
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, that would be a good idea, I'll add it |
||
| ps, st = setup(rng, idesn) | ||
|
|
||
| @testset "params/state expose input_delay" begin | ||
| @test haskey(ps, :input_delay) | ||
| @test haskey(st, :input_delay) | ||
| # reservoir input matrix is sized for the augmented (delayed) input | ||
| @test size(ps.reservoir.input_matrix, 2) == in_dims * (num_delays + 1) | ||
| end | ||
|
|
||
| @testset "collectstates" begin | ||
| states, newst = collectstates(idesn, data, ps, st) | ||
| @test size(states) == (res_dims, n_steps) | ||
| @test all(isfinite, states) | ||
| # the internal delay buffer must have advanced once per time step, | ||
| # proving the input_delay layer actually ran | ||
| @test newst.input_delay.clock == n_steps | ||
| end | ||
|
|
||
| @testset "train! + predict" begin | ||
| (ps_t, st_t), states = train!( | ||
| idesn, data, target, ps, st, StandardRidge(1.0e-6); return_states = true | ||
| ) | ||
| @test size(states) == (res_dims, n_steps) | ||
|
|
||
| out_tf, _ = predict(idesn, data, ps_t, st_t) | ||
| @test size(out_tf) == (out_dims, n_steps) | ||
| @test all(isfinite, out_tf) | ||
|
|
||
| steps = 8 | ||
| out_ar, _ = predict(idesn, steps, ps_t, st_t; initialdata = data[:, end]) | ||
| @test size(out_ar) == (out_dims, steps) | ||
| @test all(isfinite, out_ar) | ||
| end | ||
|
|
||
| @testset "deterministic for a fixed seed" begin | ||
| ps1, st1 = setup(MersenneTwister(99), idesn) | ||
| ps2, st2 = setup(MersenneTwister(99), idesn) | ||
| s1, _ = collectstates(idesn, data, ps1, st1) | ||
| s2, _ = collectstates(idesn, data, ps2, st2) | ||
| @test s1 == s2 | ||
| end | ||
| end | ||
|
|
||
| @testset "DelayESN forward pass & training" begin | ||
| rng = MersenneTwister(123) | ||
| in_dims, res_dims, out_dims = 3, 40, 2 | ||
| num_input_delays, num_state_delays = 2, 1 | ||
| n_steps = 60 | ||
| data = rand(rng, Float32, in_dims, n_steps) | ||
| target = rand(rng, Float32, out_dims, n_steps) | ||
|
|
||
| fesn = DelayESN( | ||
| in_dims, res_dims, out_dims; | ||
| num_input_delays = num_input_delays, num_state_delays = num_state_delays | ||
| ) | ||
| ps, st = setup(rng, fesn) | ||
|
|
||
| @testset "params/state expose input_delay" begin | ||
| @test haskey(ps, :input_delay) | ||
| @test haskey(st, :input_delay) | ||
| @test size(ps.reservoir.input_matrix, 2) == in_dims * (num_input_delays + 1) | ||
| end | ||
|
|
||
| @testset "collectstates" begin | ||
| states, newst = collectstates(fesn, data, ps, st) | ||
| # input + state delays: readout sees (num_state_delays + 1) * res_dims | ||
| @test size(states) == (res_dims * (num_state_delays + 1), n_steps) | ||
| @test all(isfinite, states) | ||
| @test newst.input_delay.clock == n_steps | ||
| end | ||
|
|
||
| @testset "train! + predict" begin | ||
| ps_t, st_t = train!(fesn, data, target, ps, st, StandardRidge(1.0e-6)) | ||
|
|
||
| out_tf, _ = predict(fesn, data, ps_t, st_t) | ||
| @test size(out_tf) == (out_dims, n_steps) | ||
| @test all(isfinite, out_tf) | ||
| end | ||
| end | ||
|
|
||
| # StateDelayESN routes its delay through `states_modifiers`, so it already used | ||
| # the generic machinery — keep a round-trip here so the three delay models stay | ||
| # in lockstep. | ||
| @testset "StateDelayESN forward pass & training" begin | ||
| rng = MersenneTwister(123) | ||
| in_dims, res_dims, out_dims, num_delays = 3, 40, 2, 2 | ||
| n_steps = 60 | ||
| data = rand(rng, Float32, in_dims, n_steps) | ||
| target = rand(rng, Float32, out_dims, n_steps) | ||
|
|
||
| sdesn = StateDelayESN(in_dims, res_dims, out_dims; num_delays = num_delays) | ||
| ps, st = setup(rng, sdesn) | ||
|
|
||
| states, _ = collectstates(sdesn, data, ps, st) | ||
| @test size(states) == (res_dims * (num_delays + 1), n_steps) | ||
|
|
||
| ps_t, st_t = train!(sdesn, data, target, ps, st, StandardRidge(1.0e-6)) | ||
| out_tf, _ = predict(sdesn, data, ps_t, st_t) | ||
| @test size(out_tf) == (out_dims, n_steps) | ||
| @test all(isfinite, out_tf) | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious about one thing: since
`AbstractReservoirComputer{Fields}` already encodes the field-name tuple, would it make sense down the line to have the generic `initialparameters`/`_partial_apply` introspect `Fields` directly, so any future composite model with extra pre- or post-fields just works? Could be a follow-up — wondering whether you'd considered that direction, or whether there's a reason a per-model override is preferable here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That could be an elegant solution, yeah. If there's a way to make it generic enough, it should solve similar gotchas to this one.