Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v2
with:
version: '1'
- uses: fredrikekre/runic-action@v1
with:
version: '1'
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ ModelingToolkitStandardLibrary = "2.24"
OptimizationBase = "4.0.2, 5"
OptimizationOptimJL = "0.4.8"
OptimizationOptimisers = "0.3"
OrdinaryDiffEqVerner = "1.2"
OrdinaryDiffEqVerner = "1, 2"
Random = "1.10"
SafeTestsets = "0.1"
SciCompDSL = "1"
Expand Down
6 changes: 3 additions & 3 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ ModelingToolkitStandardLibrary = "2.7"
Optimization = "4.0, 5"
OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
OrdinaryDiffEqTsit5 = "1"
OrdinaryDiffEqVerner = "1"
OrdinaryDiffEqTsit5 = "1, 2"
OrdinaryDiffEqVerner = "1, 2"
Plots = "1"
SciMLBase = "2"
SciMLBase = "2, 3"
SciMLSensitivity = "7"
SciMLStructures = "1.1.0"
StableRNGs = "1"
Expand Down
12 changes: 8 additions & 4 deletions src/ModelingToolkitNeuralNets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ function make_symbolic_nn_declaration(expr::Expr)
end

# Extracts individual component symbols.
nn,p = expr.args[1].args
nn, p = expr.args[1].args
chain, rng = if Meta.isexpr(expr.args[2], :tuple)
if (length(expr.args[2].args) > 2)
error("@SymbolicNeuralNetwork accepts no more than 2 inputs on the right-hand side.")
Expand All @@ -219,9 +219,13 @@ function make_symbolic_nn_declaration(expr::Expr)
end

# Constructs the output expression.
snn_dec = :(($nn, $p) = SymbolicNeuralNetwork(; chain = $chain, nn_name = $(QuoteNode(nn)),
nn_p_name = $(QuoteNode(p)), n_input = ModelingToolkitNeuralNets._num_chain_inputs($chain),
n_output = ModelingToolkitNeuralNets._num_chain_outputs($chain)))
snn_dec = :(
($nn, $p) = SymbolicNeuralNetwork(;
chain = $chain, nn_name = $(QuoteNode(nn)),
nn_p_name = $(QuoteNode(p)), n_input = ModelingToolkitNeuralNets._num_chain_inputs($chain),
n_output = ModelingToolkitNeuralNets._num_chain_outputs($chain)
)
)
if !isnothing(rng)
push!(snn_dec.args[2].args[2].args, Expr(:kw, :rng, rng))
end
Expand Down
4 changes: 2 additions & 2 deletions src/nn_par_accessors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ ModelingToolkitNeuralNets.isneuralnetwork(θ) # false
"""
isneuralnetwork(p::Union{Symbolics.Num, Symbolics.Arr, Symbolics.CallAndWrap}) = isneuralnetwork(Symbolics.unwrap(p))
function isneuralnetwork(p::Symbolics.SymbolicT)
getmetadata(p, NeuralNetworkParameter, false)
return getmetadata(p, NeuralNetworkParameter, false)
end

"""
Expand All @@ -40,7 +40,7 @@ ModelingToolkitNeuralNets.isneuralnetworkps(θ) # true
"""
isneuralnetworkps(p::Union{Symbolics.Num, Symbolics.Arr, Symbolics.CallAndWrap}) = isneuralnetworkps(Symbolics.unwrap(p))
function isneuralnetworkps(p::Symbolics.SymbolicT)
getmetadata(p, NeuralNetworkParametrisation, false)
return getmetadata(p, NeuralNetworkParametrisation, false)
end


Expand Down
8 changes: 4 additions & 4 deletions test/nn_ps_accessors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ let
@variables X(t) Y(t)
@parameters d
eqs = [
D(X) ~ NN([X], θ)[1] - d*X
D(Y) ~ X - d*Y
D(X) ~ NN([X], θ)[1] - d * X
D(Y) ~ X - d * Y
]
@mtkcompile sys = System(eqs, t)

Expand Down Expand Up @@ -104,8 +104,8 @@ let
@variables X(t) Y(t)
@parameters d
eqs = [
D(X) ~ NN([X], θ)[1] - d*X
D(Y) ~ X - d*Y
D(X) ~ NN([X], θ)[1] - d * X
D(Y) ~ X - d * Y
]
@mtkcompile sys = System(eqs, t)

Expand Down
15 changes: 7 additions & 8 deletions test/symbolicnn_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ let
end



# Checks that symbolic networks declared with/without the macro are identical (1).
let
# Declares the neural networks.
Expand Down Expand Up @@ -86,20 +85,20 @@ let
@SymbolicNeuralNetwork NN, p = chain
NN_func, p_func = SymbolicNeuralNetwork(; chain, n_input = 1, n_output = 1, nn_name = :NN, nn_p_name = :p)

# Checks that they are identical.
# Checks that they are identical.
@test isequal(NN, NN_func)
@test isequal(p, p_func)

# Creates corresponding MTK models.
@variables X(t) Y(t)
@parameters d
eqs_macro = [
D(X) ~ NN([X], p)[1] - d*X
D(Y) ~ X - d*Y
D(X) ~ NN([X], p)[1] - d * X
D(Y) ~ X - d * Y
]
eqs_func = [
D(X) ~ NN_func([X], p_func)[1] - d*X
D(Y) ~ X - d*Y
D(X) ~ NN_func([X], p_func)[1] - d * X
D(Y) ~ X - d * Y
]
@mtkcompile sys_macro = System(eqs_macro, t)
@mtkcompile sys_func = System(eqs_func, t)
Expand Down Expand Up @@ -149,8 +148,8 @@ let

# Checks that non-supported neural network architectures throw errors.
@test_throws Exception @eval @SymbolicNeuralNetwork NN, p = Lux.Chain(
Lux.Conv((3, 3), 1 => 8, Lux.relu; pad=1),
Lux.Conv((3, 3), 8 => 16, Lux.relu; pad=1),
Lux.Conv((3, 3), 1 => 8, Lux.relu; pad = 1),
Lux.Conv((3, 3), 8 => 16, Lux.relu; pad = 1),
Lux.Dense(28 * 28 * 16 => 10)
)
end
Expand Down
Loading