Skip to content

Add metadata and getter functions to check what type of NN parameter the package generates#112

Merged
SebastianM-C merged 8 commits into
SciML:mainfrom
TorkelE:add_nn_metadata
Mar 26, 2026
Merged

Add metadata and getter functions to check what type of NN parameter the package generates#112
SebastianM-C merged 8 commits into
SciML:mainfrom
TorkelE:add_nn_metadata

Conversation

@TorkelE
Copy link
Copy Markdown
Member

@TorkelE TorkelE commented Mar 24, 2026

Would close #109.

Basically, add functions

ModelingToolkitNeuralNets.isneuralnetwork
ModelingToolkitNeuralNets.isneuralnetworkps

that can be used like

@SymbolicNeuralNetwork NN, θ = chain
@variables X(t)
@parameters d
ModelingToolkitNeuralNets.isneuralnetwork(NN) # true
ModelingToolkitNeuralNets.isneuralnetwork(θ) # false
ModelingToolkitNeuralNets.isneuralnetwork(d) # false
ModelingToolkitNeuralNets.isneuralnetworkps(NN) # false
ModelingToolkitNeuralNets.isneuralnetworkps(θ) # true
ModelingToolkitNeuralNets.isneuralnetworkps(d) # false

These can then be used by downstream packages like PEtab and similar to adapt workflows to handle UDEs better.

If we want to go with something like this, @AayushSabharwal might want to do a double check that I have implemented the metadata correctly.

Comment thread src/ModelingToolkitNeuralNets.jl
Comment thread src/nn_par_metadata.jl Outdated

Returns `true` if the parameter corresponds to the a neural network parametrisation.
"""
isneuralnetworkps(p::Union{Symbolics.Num, AbstractVector{Symbolics.Num}}) = isneuralnetworkps(Symbolics.value(p))
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not including AbstractVector{Symbolics.Num} caused errors when applied to a vector-valued parameter (like theta).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should be Symbolics.Arr, the θ should be a symbolic array, not a (julia) array of symbolic vars

Copy link
Copy Markdown
Member

@SebastianM-C SebastianM-C left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this. I do have a few comments on the implementation

Comment thread src/ModelingToolkitNeuralNets.jl
Comment thread src/nn_par_metadata.jl Outdated
Comment thread src/nn_par_metadata.jl Outdated
Comment thread src/nn_par_metadata.jl Outdated

Returns `true` if the parameter corresponds to the a neural network parametrisation.
"""
isneuralnetworkps(p::Union{Symbolics.Num, AbstractVector{Symbolics.Num}}) = isneuralnetworkps(Symbolics.value(p))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should be Symbolics.Arr, the θ should be a symbolic array, not a (julia) array of symbolic vars

@SebastianM-C
Copy link
Copy Markdown
Member

Also, it might make sense to also include a hasneuralnetwork and the ps variant too to match metadata implementations in MTK

TorkelE and others added 2 commits March 24, 2026 22:00
Co-authored-by: Sebastian Micluța-Câmpeanu <31181429+SebastianM-C@users.noreply.github.com>
Co-authored-by: Sebastian Micluța-Câmpeanu <31181429+SebastianM-C@users.noreply.github.com>
@TorkelE
Copy link
Copy Markdown
Member Author

TorkelE commented Mar 24, 2026

Thanks a lot for the help fixing those!

I could add the hasneuralnetwork version (i.e. just returning true if the metadata is there, whatever its value, right?).

Should I also add an accessor that takes a neural network parameter (NN) and returns the chain stored in it? I think PEtab will want that, but that function could just live ine PEtab isntead, so depends on whenever you'd want it here or not.

@SebastianM-C
Copy link
Copy Markdown
Member

I could add the hasneuralnetwork version (i.e. just returning true if the metadata is there, whatever its value, right?).

Yeah, that could be useful to find the neural network block in a hierarchical system where you walk through the system

Should I also add an accessor that takes a neural network parameter (NN) and returns the chain stored in it? I think PEtab will want that, but that function could just live ine PEtab isntead, so depends on whenever you'd want it here or not.

I think that a more useful helper would be something that gets you the component array out, since that has more steps, but you can add both here. I would say that we should still keep showing in the docs how to access the values from the solution with the usual symbolic indexing, since that emphasizes that the NN and its parameters are just like any other regular MTK parameters, but that does not mean we can't have some helpers.

Copy link
Copy Markdown
Member

@SebastianM-C SebastianM-C left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me now, we can continue with the other features mentioned or we can do that separately, however you prefer.

@TorkelE
Copy link
Copy Markdown
Member Author

TorkelE commented Mar 25, 2026

Sounds good, will update with additional function in the afternoon.

Yeah, that could be useful to find the neural network block in a hierarchical system where you walk through the system

Should I have an accessor function that checks if a System has a neural network, i.e. is an UDE, for this?

@SebastianM-C
Copy link
Copy Markdown
Member

SebastianM-C commented Mar 25, 2026

I think that the hasneuralnetwork helper is enough for this pry, we can discuss the more advanced needs like "is this system an ude?" separately of this PR to not slow it down unnecessarily / add scope creep.

I'd mostly say that if you need to check if something is an ude programmatically, you probably have some more specific needs, so I'm not sure if we need anything extra at the MTKNN level.

@TorkelE
Copy link
Copy Markdown
Member Author

TorkelE commented Mar 25, 2026

I have now added so that we have the following functions:

ModelingToolkitNeuralNets.isneuralnetwork
ModelingToolkitNeuralNets.hasneuralnetwork
ModelingToolkitNeuralNets.isneuralnetworkps
ModelingToolkitNeuralNets.hasneuralnetworkps
ModelingToolkitNeuralNets.get_nn_chain

Was going to do the ComponentArray fetcher as well, but got a little bit uncertain exactly what you ijntended with this one. I.e. do you mean the default values stored here in θ?

chain = Lux.Chain(
    Lux.Dense(1 => 3, Lux.softplus; use_bias = false),
    Lux.Dense(3 => 1, Lux.softplus; use_bias = false),
)
@SymbolicNeuralNetwork NN, θ = chain

I.e.

getmetadata(θ, Symbolics.VariableDefaultValue)

(but I think maybe you meant something else as that is just a a normal vector)

Happy to update further with your suggestions.

@TorkelE
Copy link
Copy Markdown
Member Author

TorkelE commented Mar 25, 2026

Also, I agree on your point that the docs probably should just show the current methods for accessing stuff within models, those are already really neat. I added comments in the docstrings that these functions are mostly for internal use in package development.

@SebastianM-C
Copy link
Copy Markdown
Member

Thanks for adding that!
For the ComponentArray fetcher I was referring to the same convert process that we do internally to switch from the plain vector to the ComponentArray, but that can be something separate.

@TorkelE
Copy link
Copy Markdown
Member Author

TorkelE commented Mar 25, 2026

Sounds good, let's go with this (which should be enough for the PEtab integration) in this PR and then we can look at a follow-up.

It seems like these tests

@parameters p1 [neuralnetwork = true] p2 [neuralnetwork = false] p3 [neuralnetworkps = true] p4 [neuralnetworkps = false] p5 p6
@test ModelingToolkitNeuralNets.hasneuralnetwork(p1)
@test ModelingToolkitNeuralNets.hasneuralnetwork(p2)
@test ModelingToolkitNeuralNets.hasneuralnetworkps(p3)
@test ModelingToolkitNeuralNets.hasneuralnetworkps(p4)

fails on LTS specifically. The function is rather simple

hasneuralnetwork(p::Union{Symbolics.Num, Symbolics.Arr, Symbolics.CallAndWrap}) = hasneuralnetwork(Symbolics.unwrap(p))
function hasneuralnetwork(p::Symbolics.SymbolicT)
    hasmetadata(p, NeuralNetworkParameter)
end

so not sure why. @AayushSabharwal, do you have any idea? I tried having Codex have a look at it, and it claimed that the metadata might not get added properly, but I also think the syntax for accessing metadata should be same on v1.10.

@SebastianM-C SebastianM-C merged commit f33fec7 into SciML:main Mar 26, 2026
9 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add metadata to neural network parameters, designating them as such

3 participants