Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ NeuralNetworkBlock
SymbolicNeuralNetwork
@SymbolicNeuralNetwork
multi_layer_feed_forward
ModelingToolkitNeuralNets.isneuralnetwork
ModelingToolkitNeuralNets.isneuralnetworkps
ModelingToolkitNeuralNets.get_nn_chain
```
15 changes: 9 additions & 6 deletions src/ModelingToolkitNeuralNets.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module ModelingToolkitNeuralNets

using ModelingToolkitBase: @parameters, @named, @variables, System, t_nounits
using ModelingToolkitBase: @parameters, @named, @variables, System, t_nounits, getmetadata, hasmetadata
using IntervalSets: var".."
using Symbolics: Symbolics, @register_array_symbolic, @wrapped, unwrap, wrap, shape
using LuxCore: stateless_apply, outputsize
Expand All @@ -12,6 +12,9 @@ export NeuralNetworkBlock, SymbolicNeuralNetwork, @SymbolicNeuralNetwork, multi_

include("utils.jl")

# Functionality for accessing various neural network-related parameter properties.
include("nn_par_accessors.jl")

"""
NeuralNetworkBlock(; n_input = 1, n_output = 1,
chain = multi_layer_feed_forward(n_input, n_output),
Expand All @@ -32,10 +35,10 @@ function NeuralNetworkBlock(;
)
ca = ComponentArray{eltype}(init_params)

@parameters p[1:length(ca)] = Vector(ca) [tunable = true]
@parameters p[1:length(ca)] = Vector(ca) [tunable = true, neuralnetworkps = true]
@parameters T::typeof(typeof(ca)) = typeof(ca) [tunable = false]
@parameters lux_model::typeof(chain) = chain [tunable = false]
@parameters (lux_apply::typeof(stateless_apply))(..)[1:n_output] = stateless_apply [tunable = false]
@parameters (lux_apply::typeof(stateless_apply))(..)[1:n_output] = stateless_apply [tunable = false, neuralnetwork = true]

@variables inputs(t_nounits)[1:n_input] [input = true]
@variables outputs(t_nounits)[1:n_output] [output = true]
Expand Down Expand Up @@ -112,8 +115,8 @@ function SymbolicNeuralNetwork(;
ca = ComponentArray{eltype}(init_params)
wrapper = StatelessApplyWrapper(chain, typeof(ca))

p = @parameters $(nn_p_name)[1:length(ca)] = Vector(ca)
NN = @parameters ($(nn_name)::typeof(wrapper))(..)[1:n_output] = wrapper
p = @parameters $(nn_p_name)[1:length(ca)] = Vector(ca) [tunable = true, neuralnetworkps = true]
NN = @parameters ($(nn_name)::typeof(wrapper))(..)[1:n_output] = wrapper [tunable = false, neuralnetwork = true]
Comment thread
SebastianM-C marked this conversation as resolved.

return only(NN), only(p)
end
Expand Down Expand Up @@ -179,7 +182,7 @@ rng = Xoshiro(0)
@SymbolicNeuralNetwork NN, p = chain rng
```
Notes:
- The first and last layers of the chain must be one of the following types: `Lux.Dense`. For other first
- The first and last layers of the chain must be one of the following types: `Lux.Dense`. For other first
layer types, use the `SymbolicNeuralNetwork`
- Types that are intended to be supported in the first layer in future updates include `Lux.Bilinear`,
`Lux.RNNCell`, `Lux.LSTMCell`, `Lux.GRUCell`.
Expand Down
70 changes: 70 additions & 0 deletions src/nn_par_accessors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
### Defines Metadata Type ###
struct NeuralNetworkParameter end
struct NeuralNetworkParametrisation end
Symbolics.option_to_metadata_type(::Val{:neuralnetwork}) = NeuralNetworkParameter
Symbolics.option_to_metadata_type(::Val{:neuralnetworkps}) = NeuralNetworkParametrisation

### Defines Metadata Getters ###
"""
ModelingToolkitNeuralNets.isneuralnetwork(p)

Returns `true` if the parameter corresponds to the neural network chain that is saved as a MTK parameter. This function is primarily intended for internal use within dependent packages.

Example:
```julia
@parameters d
@SymbolicNeuralNetwork NN, θ = chain
ModelingToolkitNeuralNets.isneuralnetwork(d) # false
ModelingToolkitNeuralNets.isneuralnetwork(NN) # true
ModelingToolkitNeuralNets.isneuralnetwork(θ) # false
````
"""
isneuralnetwork(p::Union{Symbolics.Num, Symbolics.Arr, Symbolics.CallAndWrap}) = isneuralnetwork(Symbolics.unwrap(p))
function isneuralnetwork(p::Symbolics.SymbolicT)
getmetadata(p, NeuralNetworkParameter, false)
end

"""
ModelingToolkitNeuralNets.isneuralnetworkps(p)

Returns `true` if the parameter corresponds to the a neural network parametrisation. This function is primarily intended for internal use within dependent packages.

Example:
```julia
@parameters d
@SymbolicNeuralNetwork NN, θ = chain
ModelingToolkitNeuralNets.isneuralnetworkps(d) # false
ModelingToolkitNeuralNets.isneuralnetworkps(NN) # false
ModelingToolkitNeuralNets.isneuralnetworkps(θ) # true
````
"""
isneuralnetworkps(p::Union{Symbolics.Num, Symbolics.Arr, Symbolics.CallAndWrap}) = isneuralnetworkps(Symbolics.unwrap(p))
function isneuralnetworkps(p::Symbolics.SymbolicT)
getmetadata(p, NeuralNetworkParametrisation, false)
end


### Defines Other Accessors ###

"""
ModelingToolkitNeuralNets.get_nn_chain(p)

For a neural network parameter `p` (i.e. such that `isneuralnetwork(p) == true`), return the associated neural network chain. This function is primarily intended for internal use within dependent packages.

Example:
```julia
chain = Lux.Chain(
Lux.Dense(1 => 3, Lux.softplus; use_bias = false),
Lux.Dense(3 => 1, Lux.softplus; use_bias = false),
)
@SymbolicNeuralNetwork NN, θ = chain

ModelingToolkitNeuralNets.get_nn_chain(NN) # Returns `chain`.
ModelingToolkitNeuralNets.get_nn_chain(θ) # Throws an error.
````
"""
get_nn_chain(p::Union{Symbolics.Num, Symbolics.Arr, Symbolics.CallAndWrap}) = get_nn_chain(Symbolics.unwrap(p))
function get_nn_chain(p::Symbolics.SymbolicT)
isneuralnetwork(p) || error("Parameter $p does not have a neural network chain associated with it.")
return getmetadata(p, Symbolics.VariableDefaultValue).lux_model
end
121 changes: 121 additions & 0 deletions test/nn_ps_accessors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Fetch packages.
using ModelingToolkitBase, ModelingToolkitNeuralNets, Lux, Random
using ModelingToolkitBase: t_nounits as t, D_nounits as D

# Check that `isneuralnetwork` and `isneuralnetworkps` give correct input on various inputs.
let
# Tests on normally declared parameters.
@variables X(t) Y(t)[1:2]
@parameters p q[1:3]
for s in [X, Y, Y[1], p, q, q[1]]
@test !ModelingToolkitNeuralNets.isneuralnetwork(s)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(s)
end

# Tests on MTKNeuralNets parameters
chain = Lux.Chain(
Lux.Dense(1 => 3, Lux.softplus; use_bias = false),
Lux.Dense(3 => 1, Lux.softplus; use_bias = false),
)
@SymbolicNeuralNetwork NN, θ = chain
U, p = SymbolicNeuralNetwork(; chain, n_input = 1, n_output = 1, nn_name = :U, nn_p_name = :p)
@test ModelingToolkitNeuralNets.isneuralnetwork(NN)
@test !ModelingToolkitNeuralNets.isneuralnetwork(θ)
@test ModelingToolkitNeuralNets.isneuralnetwork(U)
@test !ModelingToolkitNeuralNets.isneuralnetwork(p)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(NN)
@test ModelingToolkitNeuralNets.isneuralnetworkps(θ)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(U)
end

# Check that `isneuralnetwork` and `isneuralnetworkps` give correct input on parameters stored in a model created using symbolic approach.
let
# Model created via symbolic neural network representation.
chain = Lux.Chain(
Lux.Dense(1 => 3, Lux.softplus; use_bias = false),
Lux.Dense(3 => 1, Lux.softplus; use_bias = false),
)
@SymbolicNeuralNetwork NN, θ = chain
@variables X(t) Y(t)
@parameters d
eqs = [
D(X) ~ NN([X], θ)[1] - d*X
D(Y) ~ X - d*Y
]
@mtkcompile sys = System(eqs, t)

# Check that content have the correct metadata tags.
@test !ModelingToolkitNeuralNets.isneuralnetwork(sys.X)
@test !ModelingToolkitNeuralNets.isneuralnetwork(sys.d)
@test ModelingToolkitNeuralNets.isneuralnetwork(sys.NN)
@test !ModelingToolkitNeuralNets.isneuralnetwork(sys.θ)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(sys.X)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(sys.d)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(sys.NN)
@test ModelingToolkitNeuralNets.isneuralnetworkps(sys.θ)
end

# Check that `isneuralnetwork` and `isneuralnetworkps` give correct input on parameters stored in a model created using NNBlock approach.
let
# Model created via NeuralNetwork block.
chain = Lux.Chain(
Lux.Dense(2 => 3, Lux.softplus; use_bias = false),
Lux.Dense(3 => 2, Lux.softplus; use_bias = false),
)
@variables x(t) = 3.1 y(t) = 1.5
@parameters α = 1.3 [tunable = false] δ = 1.8 [tunable = false]
@named nn = NeuralNetworkBlock(2, 2; chain)
eqs = [
D(x) ~ α * x + nn.outputs[1],
D(y) ~ -δ * y + nn.outputs[2],
nn.inputs[1] ~ x,
nn.inputs[2] ~ y,
]
@mtkcompile sys_nnblock = System(eqs, t, systems = [nn])

# Check that content have the correct metadata tags.
@test ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.nn.lux_apply)
@test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.nn.lux_model)
@test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.nn.p)
@test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.nn.T)
@test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.α)
@test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.δ)
@test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.x)
@test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.y)

@test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.nn.lux_apply)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.nn.lux_model)
@test ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.nn.p)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.nn.T)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.α)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.δ)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.x)
@test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.y)
end

# Checks the `get_nn_chain` accessor function.
let
# Model created via symbolic neural network representation.
chain = Lux.Chain(
Lux.Dense(1 => 3, Lux.softplus; use_bias = false),
Lux.Dense(3 => 1, Lux.softplus; use_bias = false),
)
@SymbolicNeuralNetwork NN, θ = chain
@variables X(t) Y(t)
@parameters d
eqs = [
D(X) ~ NN([X], θ)[1] - d*X
D(Y) ~ X - d*Y
]
@mtkcompile sys = System(eqs, t)

# Checks accessor function.
@test ModelingToolkitNeuralNets.get_nn_chain(NN) == chain
@test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(θ)
@test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(X)
@test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(d)
@test ModelingToolkitNeuralNets.get_nn_chain(sys.NN) == chain
@test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(sys.θ)
@test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(sys.X)
@test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(sys.d)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ using SafeTestsets
@safetestset "Basic" include("lotka_volterra.jl")
@safetestset "MTK model macro compatibility" include("macro.jl")
@safetestset "Symbolic Neural Network Macro" include("symbolicnn_macro.jl")
@safetestset "Neural Network Parameter Metadata" include("nn_ps_accessors.jl")
@safetestset "Reported issues" include("reported_issues.jl")
end
Loading