diff --git a/docs/src/api.md b/docs/src/api.md index 3ba0f52..fc2b4e7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -5,4 +5,7 @@ NeuralNetworkBlock SymbolicNeuralNetwork @SymbolicNeuralNetwork multi_layer_feed_forward +ModelingToolkitNeuralNets.isneuralnetwork +ModelingToolkitNeuralNets.isneuralnetworkps +ModelingToolkitNeuralNets.get_nn_chain ``` diff --git a/src/ModelingToolkitNeuralNets.jl b/src/ModelingToolkitNeuralNets.jl index 7487927..54ab790 100644 --- a/src/ModelingToolkitNeuralNets.jl +++ b/src/ModelingToolkitNeuralNets.jl @@ -1,6 +1,6 @@ module ModelingToolkitNeuralNets -using ModelingToolkitBase: @parameters, @named, @variables, System, t_nounits +using ModelingToolkitBase: @parameters, @named, @variables, System, t_nounits, getmetadata, hasmetadata using IntervalSets: var".." using Symbolics: Symbolics, @register_array_symbolic, @wrapped, unwrap, wrap, shape using LuxCore: stateless_apply, outputsize @@ -12,6 +12,9 @@ export NeuralNetworkBlock, SymbolicNeuralNetwork, @SymbolicNeuralNetwork, multi_ include("utils.jl") +# Functionality for accessing various neural network-related parameter properties. +include("nn_par_accessors.jl") + """ NeuralNetworkBlock(; n_input = 1, n_output = 1, chain = multi_layer_feed_forward(n_input, n_output), @@ -32,10 +35,10 @@ function NeuralNetworkBlock(; ) ca = ComponentArray{eltype}(init_params) - @parameters p[1:length(ca)] = Vector(ca) [tunable = true] + @parameters p[1:length(ca)] = Vector(ca) [tunable = true, neuralnetworkps = true] @parameters T::typeof(typeof(ca)) = typeof(ca) [tunable = false] @parameters lux_model::typeof(chain) = chain [tunable = false] - @parameters (lux_apply::typeof(stateless_apply))(..)[1:n_output] = stateless_apply [tunable = false] + @parameters (lux_apply::typeof(stateless_apply))(..)[1:n_output] = stateless_apply [tunable = false, neuralnetwork = true] @variables inputs(t_nounits)[1:n_input] [input = true] @variables outputs(t_nounits)[1:n_output] [output = true] @@ -112,8 +115,8 @@ function SymbolicNeuralNetwork(; ca = ComponentArray{eltype}(init_params) wrapper = StatelessApplyWrapper(chain, typeof(ca)) - p = @parameters $(nn_p_name)[1:length(ca)] = Vector(ca) - NN = @parameters ($(nn_name)::typeof(wrapper))(..)[1:n_output] = wrapper + p = @parameters $(nn_p_name)[1:length(ca)] = Vector(ca) [tunable = true, neuralnetworkps = true] + NN = @parameters ($(nn_name)::typeof(wrapper))(..)[1:n_output] = wrapper [tunable = false, neuralnetwork = true] return only(NN), only(p) end @@ -179,7 +182,7 @@ rng = Xoshiro(0) @SymbolicNeuralNetwork NN, p = chain rng ``` Notes: -- The first and last layers of the chain must be one of the following types: `Lux.Dense`. For other first +- The first and last layers of the chain must be one of the following types: `Lux.Dense`. For other first layer types, use the `SymbolicNeuralNetwork` - Types that are intended to be supported in the first layer in future updates include `Lux.Bilinear`, `Lux.RNNCell`, `Lux.LSTMCell`, `Lux.GRUCell`. diff --git a/src/nn_par_accessors.jl b/src/nn_par_accessors.jl new file mode 100644 index 0000000..9515bed --- /dev/null +++ b/src/nn_par_accessors.jl @@ -0,0 +1,70 @@ +### Defines Metadata Type ### +struct NeuralNetworkParameter end +struct NeuralNetworkParametrisation end +Symbolics.option_to_metadata_type(::Val{:neuralnetwork}) = NeuralNetworkParameter +Symbolics.option_to_metadata_type(::Val{:neuralnetworkps}) = NeuralNetworkParametrisation + +### Defines Metadata Getters ### +""" + ModelingToolkitNeuralNets.isneuralnetwork(p) + +Returns `true` if the parameter corresponds to the neural network chain that is saved as a MTK parameter. This function is primarily intended for internal use within dependent packages. + +Example: +```julia +@parameters d +@SymbolicNeuralNetwork NN, θ = chain +ModelingToolkitNeuralNets.isneuralnetwork(d) # false +ModelingToolkitNeuralNets.isneuralnetwork(NN) # true +ModelingToolkitNeuralNets.isneuralnetwork(θ) # false +```` +""" +isneuralnetwork(p::Union{Symbolics.Num, Symbolics.Arr, Symbolics.CallAndWrap}) = isneuralnetwork(Symbolics.unwrap(p)) +function isneuralnetwork(p::Symbolics.SymbolicT) + getmetadata(p, NeuralNetworkParameter, false) +end + +""" + ModelingToolkitNeuralNets.isneuralnetworkps(p) + +Returns `true` if the parameter corresponds to the a neural network parametrisation. This function is primarily intended for internal use within dependent packages. + +Example: +```julia +@parameters d +@SymbolicNeuralNetwork NN, θ = chain +ModelingToolkitNeuralNets.isneuralnetworkps(d) # false +ModelingToolkitNeuralNets.isneuralnetworkps(NN) # false +ModelingToolkitNeuralNets.isneuralnetworkps(θ) # true +```` +""" +isneuralnetworkps(p::Union{Symbolics.Num, Symbolics.Arr, Symbolics.CallAndWrap}) = isneuralnetworkps(Symbolics.unwrap(p)) +function isneuralnetworkps(p::Symbolics.SymbolicT) + getmetadata(p, NeuralNetworkParametrisation, false) +end + + +### Defines Other Accessors ### + +""" + ModelingToolkitNeuralNets.get_nn_chain(p) + +For a neural network parameter `p` (i.e. such that `isneuralnetwork(p) == true`), return the associated neural network chain. This function is primarily intended for internal use within dependent packages. + +Example: +```julia +chain = Lux.Chain( + Lux.Dense(1 => 3, Lux.softplus; use_bias = false), + Lux.Dense(3 => 1, Lux.softplus; use_bias = false), +) +@SymbolicNeuralNetwork NN, θ = chain + +ModelingToolkitNeuralNets.get_nn_chain(NN) # Returns `chain`. +ModelingToolkitNeuralNets.get_nn_chain(θ) # Throws an error. +```` +""" +get_nn_chain(p::Union{Symbolics.Num, Symbolics.Arr, Symbolics.CallAndWrap}) = get_nn_chain(Symbolics.unwrap(p)) +function get_nn_chain(p::Symbolics.SymbolicT) + isneuralnetwork(p) || error("Parameter $p does not have a neural network chain associated with it.") + return getmetadata(p, Symbolics.VariableDefaultValue).lux_model +end diff --git a/test/nn_ps_accessors.jl b/test/nn_ps_accessors.jl new file mode 100644 index 0000000..269feab --- /dev/null +++ b/test/nn_ps_accessors.jl @@ -0,0 +1,121 @@ +# Fetch packages. +using ModelingToolkitBase, ModelingToolkitNeuralNets, Lux, Random +using ModelingToolkitBase: t_nounits as t, D_nounits as D + +# Check that `isneuralnetwork` and `isneuralnetworkps` give correct input on various inputs. +let + # Tests on normally declared parameters. + @variables X(t) Y(t)[1:2] + @parameters p q[1:3] + for s in [X, Y, Y[1], p, q, q[1]] + @test !ModelingToolkitNeuralNets.isneuralnetwork(s) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(s) + end + + # Tests on MTKNeuralNets parameters + chain = Lux.Chain( + Lux.Dense(1 => 3, Lux.softplus; use_bias = false), + Lux.Dense(3 => 1, Lux.softplus; use_bias = false), + ) + @SymbolicNeuralNetwork NN, θ = chain + U, p = SymbolicNeuralNetwork(; chain, n_input = 1, n_output = 1, nn_name = :U, nn_p_name = :p) + @test ModelingToolkitNeuralNets.isneuralnetwork(NN) + @test !ModelingToolkitNeuralNets.isneuralnetwork(θ) + @test ModelingToolkitNeuralNets.isneuralnetwork(U) + @test !ModelingToolkitNeuralNets.isneuralnetwork(p) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(NN) + @test ModelingToolkitNeuralNets.isneuralnetworkps(θ) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(U) +end + +# Check that `isneuralnetwork` and `isneuralnetworkps` give correct input on parameters stored in a model created using symbolic approach. +let + # Model created via symbolic neural network representation. + chain = Lux.Chain( + Lux.Dense(1 => 3, Lux.softplus; use_bias = false), + Lux.Dense(3 => 1, Lux.softplus; use_bias = false), + ) + @SymbolicNeuralNetwork NN, θ = chain + @variables X(t) Y(t) + @parameters d + eqs = [ + D(X) ~ NN([X], θ)[1] - d*X + D(Y) ~ X - d*Y + ] + @mtkcompile sys = System(eqs, t) + + # Check that content have the correct metadata tags. + @test !ModelingToolkitNeuralNets.isneuralnetwork(sys.X) + @test !ModelingToolkitNeuralNets.isneuralnetwork(sys.d) + @test ModelingToolkitNeuralNets.isneuralnetwork(sys.NN) + @test !ModelingToolkitNeuralNets.isneuralnetwork(sys.θ) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(sys.X) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(sys.d) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(sys.NN) + @test ModelingToolkitNeuralNets.isneuralnetworkps(sys.θ) +end + +# Check that `isneuralnetwork` and `isneuralnetworkps` give correct input on parameters stored in a model created using NNBlock approach. +let + # Model created via NeuralNetwork block. + chain = Lux.Chain( + Lux.Dense(2 => 3, Lux.softplus; use_bias = false), + Lux.Dense(3 => 2, Lux.softplus; use_bias = false), + ) + @variables x(t) = 3.1 y(t) = 1.5 + @parameters α = 1.3 [tunable = false] δ = 1.8 [tunable = false] + @named nn = NeuralNetworkBlock(2, 2; chain) + eqs = [ + D(x) ~ α * x + nn.outputs[1], + D(y) ~ -δ * y + nn.outputs[2], + nn.inputs[1] ~ x, + nn.inputs[2] ~ y, + ] + @mtkcompile sys_nnblock = System(eqs, t, systems = [nn]) + + # Check that content have the correct metadata tags. + @test ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.nn.lux_apply) + @test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.nn.lux_model) + @test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.nn.p) + @test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.nn.T) + @test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.α) + @test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.δ) + @test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.x) + @test !ModelingToolkitNeuralNets.isneuralnetwork(sys_nnblock.y) + + @test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.nn.lux_apply) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.nn.lux_model) + @test ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.nn.p) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.nn.T) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.α) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.δ) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.x) + @test !ModelingToolkitNeuralNets.isneuralnetworkps(sys_nnblock.y) +end + +# Checks the `get_nn_chain` accessor function. +let + # Model created via symbolic neural network representation. + chain = Lux.Chain( + Lux.Dense(1 => 3, Lux.softplus; use_bias = false), + Lux.Dense(3 => 1, Lux.softplus; use_bias = false), + ) + @SymbolicNeuralNetwork NN, θ = chain + @variables X(t) Y(t) + @parameters d + eqs = [ + D(X) ~ NN([X], θ)[1] - d*X + D(Y) ~ X - d*Y + ] + @mtkcompile sys = System(eqs, t) + + # Checks accessor function. + @test ModelingToolkitNeuralNets.get_nn_chain(NN) == chain + @test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(θ) + @test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(X) + @test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(d) + @test ModelingToolkitNeuralNets.get_nn_chain(sys.NN) == chain + @test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(sys.θ) + @test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(sys.X) + @test_throws ErrorException ModelingToolkitNeuralNets.get_nn_chain(sys.d) +end diff --git a/test/runtests.jl b/test/runtests.jl index 2fd97b0..c442c60 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,5 +7,6 @@ using SafeTestsets @safetestset "Basic" include("lotka_volterra.jl") @safetestset "MTK model macro compatibility" include("macro.jl") @safetestset "Symbolic Neural Network Macro" include("symbolicnn_macro.jl") + @safetestset "Neural Network Parameter Metadata" include("nn_ps_accessors.jl") @safetestset "Reported issues" include("reported_issues.jl") end