Add metadata and getter functions to check what type of NN parameter the package generates#112
Conversation
|
|
||
| Returns `true` if the parameter corresponds to the a neural network parametrisation. | ||
| """ | ||
| isneuralnetworkps(p::Union{Symbolics.Num, AbstractVector{Symbolics.Num}}) = isneuralnetworkps(Symbolics.value(p)) |
There was a problem hiding this comment.
Not including AbstractVector{Symbolics.Num} caused errors when applied to a vector-valued parameter (like theta).
There was a problem hiding this comment.
That should be Symbolics.Arr, the θ should be a symbolic array, not a (julia) array of symbolic vars
SebastianM-C
left a comment
There was a problem hiding this comment.
Thanks for adding this. I do have a few comments on the implementation
|
|
||
| Returns `true` if the parameter corresponds to the a neural network parametrisation. | ||
| """ | ||
| isneuralnetworkps(p::Union{Symbolics.Num, AbstractVector{Symbolics.Num}}) = isneuralnetworkps(Symbolics.value(p)) |
There was a problem hiding this comment.
That should be Symbolics.Arr, the θ should be a symbolic array, not a (julia) array of symbolic vars
|
Also, it might make sense to also include a |
Co-authored-by: Sebastian Micluța-Câmpeanu <31181429+SebastianM-C@users.noreply.github.com>
Co-authored-by: Sebastian Micluța-Câmpeanu <31181429+SebastianM-C@users.noreply.github.com>
|
Thanks a lot for the help fixing those! I could add the Should I also add an accessor that takes a neural network parameter ( |
Yeah, that could be useful to find the neural network block in a hierarchical system where you walk through the system
I think that a more useful helper would be something that gets you the component array out, since that has more steps, but you can add both here. I would say that we should still keep showing in the docs how to access the values from the solution with the usual symbolic indexing, since that emphasizes that the NN and its parameters are just like any other regular MTK parameters, but that does not mean we can't have some helpers. |
SebastianM-C
left a comment
There was a problem hiding this comment.
Looks good to me now, we can continue with the other features mentioned or we can do that separately, however you prefer.
|
Sounds good, will update with additional function in the afternoon.
Should I have an accessor function that checks if a |
|
I think that the hasneuralnetwork helper is enough for this pry, we can discuss the more advanced needs like "is this system an ude?" separately of this PR to not slow it down unnecessarily / add scope creep. I'd mostly say that if you need to check if something is an ude programmatically, you probably have some more specific needs, so I'm not sure if we need anything extra at the MTKNN level. |
|
I have now added so that we have the following functions: ModelingToolkitNeuralNets.isneuralnetwork
ModelingToolkitNeuralNets.hasneuralnetwork
ModelingToolkitNeuralNets.isneuralnetworkps
ModelingToolkitNeuralNets.hasneuralnetworkps
ModelingToolkitNeuralNets.get_nn_chainWas going to do the chain = Lux.Chain(
Lux.Dense(1 => 3, Lux.softplus; use_bias = false),
Lux.Dense(3 => 1, Lux.softplus; use_bias = false),
)
@SymbolicNeuralNetwork NN, θ = chainI.e. getmetadata(θ, Symbolics.VariableDefaultValue)(but I think maybe you meant something else as that is just a a normal vector) Happy to update further with your suggestions. |
|
Also, I agree on your point that the docs probably should just show the current methods for accessing stuff within models, those are already really neat. I added comments in the docstrings that these functions are mostly for internal use in package development. |
|
Thanks for adding that! |
|
Sounds good, let's go with this (which should be enough for the PEtab integration) in this PR and then we can look at a follow-up. It seems like these tests @parameters p1 [neuralnetwork = true] p2 [neuralnetwork = false] p3 [neuralnetworkps = true] p4 [neuralnetworkps = false] p5 p6
@test ModelingToolkitNeuralNets.hasneuralnetwork(p1)
@test ModelingToolkitNeuralNets.hasneuralnetwork(p2)
@test ModelingToolkitNeuralNets.hasneuralnetworkps(p3)
@test ModelingToolkitNeuralNets.hasneuralnetworkps(p4)fails on LTS specifically. The function is rather simple hasneuralnetwork(p::Union{Symbolics.Num, Symbolics.Arr, Symbolics.CallAndWrap}) = hasneuralnetwork(Symbolics.unwrap(p))
function hasneuralnetwork(p::Symbolics.SymbolicT)
hasmetadata(p, NeuralNetworkParameter)
endso not sure why. @AayushSabharwal, do you have any idea? I tried having Codex have a look at it, and it claimed that the metadata might not get added properly, but I also think the syntax for accessing metadata should be same on v1.10. |
Would close #109.
Basically, add functions
that can be used like
These can then be used by downstream packages like PEtab and similar to adapt workflows to handle UDEs better.
If we want to go with something like this, @AayushSabharwal might want to do a double check that I have implemented the metadata correctly.