From 757ea70e83e36970c35652d110a63f73e09a3b74 Mon Sep 17 00:00:00 2001
From: Ady0333
Date: Fri, 20 Mar 2026 09:25:34 +0530
Subject: [PATCH] Add NeuralOperatorROMs submodule: DeepONet-based
 non-intrusive ROM
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Proof of concept for GSoC 2026 "Reduced Order Modelling with Neural
Operators". Adds a new NeuralOperatorROMs submodule that implements a
DeepONet-based non-intrusive reduced order model for parametric PDEs:

- Snapshots.jl: parameter sampling (LHS), FEM solve loop, free-DOF
  coordinate extraction via cell connectivity
- DeepONet.jl: branch-trunk architecture as a Lux.jl layer with a
  precomputed trunk matrix for O(N·p) online inference
- Training.jl: MSE training loop with normalization, early stopping,
  and Zygote AD
- Reconstruction.jl: predicted DOFs → Gridap FEFunction reconstruction

Integration points:
- src/GridapROMs.jl: includes the NeuralOperatorROMs module
- src/Exports.jl: re-exports all public symbols via @publish
- Project.toml: adds Lux, Optimisers, Zygote dependencies
- test/runtests.jl: adds the NeuralOperatorROMs test suite (319 tests)
- examples/poisson_deeponet.jl: end-to-end demo on parametric Poisson

Signed-off-by: Aditya
---
 Project.toml                                 |   6 +
 examples/poisson_deeponet.jl                 | 139 ++++++++++++++
 src/Exports.jl                               |  14 ++
 src/GridapROMs.jl                            |   2 +
 src/NeuralOperatorROMs/DeepONet.jl           | 100 ++++++++++
 src/NeuralOperatorROMs/NeuralOperatorROMs.jl |  41 ++++
 src/NeuralOperatorROMs/Reconstruction.jl     |  47 +++++
 src/NeuralOperatorROMs/Snapshots.jl          | 138 +++++++++++++
 src/NeuralOperatorROMs/Training.jl           | 192 +++++++++++++++++++
 test/NeuralOperatorROMs/runtests.jl          | 147 ++++++++++++++
 test/runtests.jl                             |   2 +
 11 files changed, 828 insertions(+)
 create mode 100644 examples/poisson_deeponet.jl
 create mode 100644 src/NeuralOperatorROMs/DeepONet.jl
 create mode 100644 src/NeuralOperatorROMs/NeuralOperatorROMs.jl
 create mode 100644 src/NeuralOperatorROMs/Reconstruction.jl
 create mode 100644 src/NeuralOperatorROMs/Snapshots.jl
 create mode 100644 src/NeuralOperatorROMs/Training.jl
 create mode 100644 test/NeuralOperatorROMs/runtests.jl

diff --git a/Project.toml b/Project.toml
index 7dabf08e..316d96b5 100755
--- a/Project.toml
+++ b/Project.toml
@@ -35,7 +35,10 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 ArraysOfArrays = "0.6"
@@ -64,7 +67,10 @@ SparseMatricesCSR = "0.6"
 StaticArrays = "1"
 Statistics = "1"
 StatsBase = "0.34"
+Lux = "1"
+Optimisers = "0.3, 0.4"
 UnPack = "1"
+Zygote = "0.6"
 julia = "1.9"
 
 [extras]
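Note: the O(N·p) claim in the message above is the crux of the design — once the trunk is frozen, online prediction is a single dense mat-vec. A minimal back-of-the-envelope sketch with stand-in arrays (not the submodule's API):

```julia
# Offline (once): evaluate the trunk at the N fixed DOF coordinates → T.
# Online (per query μ): one dense mat-vec, no assembly, no linear solve.
N, p = 10_000, 32                # free DOFs, latent width
T = randn(Float32, N, p)         # stand-in for the precomputed trunk matrix
b = randn(Float32, p)            # stand-in for branch(μ), the latent code
û = T * b                        # N predicted DOF values, O(N·p) flops
```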
diff --git a/examples/poisson_deeponet.jl b/examples/poisson_deeponet.jl
new file mode 100644
index 00000000..14c71fd9
--- /dev/null
+++ b/examples/poisson_deeponet.jl
@@ -0,0 +1,139 @@
+# End-to-end PoC: Neural Operator ROM for Parametric Poisson Equation
+#
+# Problem: -∇·(κ(μ) ∇u) = f on [0,1]², u = 0 on ∂Ω
+#
+# The diffusivity coefficient κ depends on the parameters:
+#   κ(x; μ) = μ₁ + μ₂ · sin(π·x₁) · sin(π·x₂)
+#
+# We train a DeepONet to learn the map μ → u_h(μ) (the DOF vector),
+# then predict solutions for new parameters without running the FEM solver.
+
+module PoissonDeepONet
+
+using Gridap
+using GridapROMs
+using LinearAlgebra
+using Random
+using Statistics
+
+function main(;n_train=60,n_test=10,epochs=300,seed=42)
+  rng = Random.MersenneTwister(seed)
+
+  println("="^60)
+  println("Neural Operator ROM for Parametric Poisson")
+  println("="^60)
+
+  # ── 1. Gridap FEM setup ──────────────────────────────────────────
+
+  domain = (0,1,0,1)
+  partition = (16,16)
+  model = CartesianDiscreteModel(domain,partition)
+
+  order = 1
+  reffe = ReferenceFE(lagrangian,Float64,order)
+  V = TestFESpace(model,reffe;conformity=:H1,dirichlet_tags="boundary")
+  U = TrialFESpace(V,0.0)
+
+  Ω = Triangulation(model)
+  dΩ = Measure(Ω,2*order)
+
+  N_dofs = num_free_dofs(V)
+  println("\nMesh: $(partition[1])×$(partition[2]), Free DOFs: $N_dofs")
+
+  # ── 2. Parametric solver ─────────────────────────────────────────
+  # Black-box function: μ → FEFunction.
+  # This is what makes neural operator ROMs non-intrusive —
+  # we only need input/output pairs, not the PDE operators.
+
+  function solve_poisson(μ)
+    κ(x) = μ[1] + μ[2] * sin(π*x[1]) * sin(π*x[2])
+    f(x) = 1.0
+    a(u,v) = ∫( κ ⊙ (∇(u)⋅∇(v)) )dΩ
+    l(v) = ∫( f*v )dΩ
+    op = AffineFEOperator(a,l,U,V)
+    return Gridap.solve(op)
+  end
+
+  # ── 3. Collect snapshots ─────────────────────────────────────────
+
+  param_bounds = [(0.1,5.0),(0.0,4.0)]  # μ₁ ∈ [0.1,5], μ₂ ∈ [0,4]
+
+  println("\nGenerating $n_train training snapshots...")
+  train_params = sample_parameters(param_bounds,n_train;rng)
+  t_snap = @elapsed train_data = collect_snapshots(
+    solve_poisson,train_params;trial=U
+  )
+  println("  Snapshot generation: $(round(t_snap,digits=2))s")
+  println("  Solution matrix:     $(size(train_data.solutions))")
+  println("  Coordinate matrix:   $(size(train_data.coordinates))")
+
+  # ── 4. Build and train DeepONet ──────────────────────────────────
+
+  d_param = length(first(train_params))
+  spatial_dim = size(train_data.coordinates,1)
+
+  deeponet = build_deeponet(;
+    param_dim=d_param,
+    n_dofs=N_dofs,
+    spatial_dim=spatial_dim,
+    latent_dim=32,
+    branch_width=64,
+    trunk_width=64,
+    n_branch_layers=3,
+    n_trunk_layers=3,
+  )
+
+  println("\nTraining DeepONet (latent_dim=32, width=64, 3 layers)...")
+  config = TrainingConfig(;epochs=epochs,lr=1e-3,batch_size=16,
+                          patience=80,verbose=true)
+  t_train = @elapsed result = train_operator(train_data,deeponet;config,rng)
+  println("  Training time:  $(round(t_train,digits=2))s")
+  println("  Best epoch:     $(result.best_epoch)")
+  println("  Final val loss: $(round(result.val_losses[result.best_epoch],digits=6))")
+
+  # ── 5. Evaluate on test parameters ───────────────────────────────
+
+  println("\nEvaluating on $n_test test parameters...")
+  test_params = sample_parameters(param_bounds,n_test;rng)
+
+  relative_errors = Float64[]
+  speedups = Float64[]
+
+  for (i,μ) in enumerate(test_params)
+    # Ground truth via FEM
+    t_fem = @elapsed uh_true = solve_poisson(μ)
+    dofs_true = get_free_dof_values(uh_true)
+
+    # Neural operator prediction (the first call includes JIT
+    # compilation, so the reported speedups are conservative)
+    t_rom = @elapsed uh_pred = reconstruct_fe_function(result,μ,U)
+    dofs_pred = get_free_dof_values(uh_pred)
+
+    # Relative L2 error on the DOF vector
+    err = norm(dofs_pred .- dofs_true) / norm(dofs_true)
+    push!(relative_errors,err)
+    push!(speedups,t_fem / max(t_rom,1e-10))
+  end
+
+  # ── 6. Report ────────────────────────────────────────────────────
+
+  println("\n" * "="^60)
+  println("RESULTS")
+  println("="^60)
+  println("  Mean relative L2 error: $(round(mean(relative_errors),digits=6))")
+  println("  Max relative L2 error:  $(round(maximum(relative_errors),digits=6))")
+  println("  Min relative L2 error:  $(round(minimum(relative_errors),digits=6))")
+  println("  Mean speedup (FEM/ROM): $(round(mean(speedups),digits=1))×")
+  println("  Training samples: $n_train")
+  println("  Test samples:     $n_test")
+  println("  FEM DOFs:         $N_dofs")
+  println("="^60)
+
+  return (;relative_errors,speedups,result)
+end
+
+end # module
+
+# Run if executed directly
+if abspath(PROGRAM_FILE) == @__FILE__
+  PoissonDeepONet.main()
+end
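The demo reports errors on raw DOF vectors; since predictions come back as `FEFunction`s, a true L2(Ω) error can also be computed with the same measure used for assembly. A sketch that would slot into the loop in step 5, reusing `uh_pred`, `uh_true`, and `dΩ` from the example above (not part of the patch):

```julia
# Relative L2(Ω) error between the ROM prediction and the FEM truth
e = uh_pred - uh_true
err_L2 = sqrt(sum(∫( e*e )dΩ) / sum(∫( uh_true*uh_true )dΩ))
```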
diff --git a/src/Exports.jl b/src/Exports.jl
index eb7ce330..d4ba44d5 100755
--- a/src/Exports.jl
+++ b/src/Exports.jl
@@ -146,3 +146,17 @@ using GridapROMs.Extensions: ⊕; export ⊕
 @publish RBTransient TransientHyperReduction
 @publish RBTransient TransientProjection
 @publish RBTransient TransientRBOperator
+
+@publish NeuralOperatorROMs SnapshotData
+@publish NeuralOperatorROMs collect_snapshots
+@publish NeuralOperatorROMs extract_coordinates
+@publish NeuralOperatorROMs sample_parameters
+@publish NeuralOperatorROMs DeepONetLayer
+@publish NeuralOperatorROMs build_deeponet
+@publish NeuralOperatorROMs precompute_trunk_matrix
+@publish NeuralOperatorROMs NormalizationStats
+@publish NeuralOperatorROMs TrainingConfig
+@publish NeuralOperatorROMs TrainingResult
+@publish NeuralOperatorROMs train_operator
+@publish NeuralOperatorROMs reconstruct_fe_function
+@publish NeuralOperatorROMs evaluate_rom
diff --git a/src/GridapROMs.jl b/src/GridapROMs.jl
index 19f47dfa..bc94ea41 100755
--- a/src/GridapROMs.jl
+++ b/src/GridapROMs.jl
@@ -26,6 +26,8 @@ include("RB/RBTransient/RBTransient.jl")
 
 include("Distributed/Distributed.jl")
 
+include("NeuralOperatorROMs/NeuralOperatorROMs.jl")
+
 include("Exports.jl")
 
 end
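Assuming `@publish Mod sym` expands to a re-export, as the existing Exports.jl entries suggest, user code reaches the new API from the top-level package without touching the submodule path. A hypothetical session:

```julia
using GridapROMs  # NeuralOperatorROMs symbols re-exported via @publish

bounds = [(0.1,5.0),(0.0,4.0)]
μs = sample_parameters(bounds,4)  # 4 LHS samples in the box
net = build_deeponet(param_dim=2,n_dofs=225,spatial_dim=2)
```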
diff --git a/src/NeuralOperatorROMs/DeepONet.jl b/src/NeuralOperatorROMs/DeepONet.jl
new file mode 100644
index 00000000..c9e5e299
--- /dev/null
+++ b/src/NeuralOperatorROMs/DeepONet.jl
@@ -0,0 +1,100 @@
+# DeepONet implementation for parametric PDE ROMs using Lux.jl.
+#
+# Key insight for PDE ROMs: the trunk network evaluates at fixed DOF
+# coordinates that don't change between queries. So we precompute the
+# trunk matrix T ∈ R^{N_dofs × p} once after training, and online
+# prediction reduces to:
+#
+#   û(μ) = T · b(μ)
+#
+# where b(μ) = branch(μ) ∈ R^p. This is an O(N·p) matrix-vector product,
+# independent of the FEM assembly and solve cost. (No output bias is
+# used in this PoC.)
+
+"""
+    DeepONetLayer{B,T} <: Lux.AbstractLuxContainerLayer{(:branch,:trunk)}
+
+Custom Lux layer implementing DeepONet for parametric PDE surrogate modelling.
+
+The branch network encodes the parameter vector μ into a latent code b(μ).
+The trunk network encodes spatial coordinates x into basis functions t(x).
+The output at DOF location xᵢ is: ûᵢ = Σₖ bₖ(μ) · tₖ(xᵢ).
+"""
+struct DeepONetLayer{B,T} <: Lux.AbstractLuxContainerLayer{(:branch,:trunk)}
+  branch::B
+  trunk::T
+  latent_dim::Int
+  n_dofs::Int
+end
+
+"""
+    build_deeponet(;
+      param_dim, n_dofs, spatial_dim,
+      latent_dim=32, branch_width=64, trunk_width=64,
+      n_branch_layers=2, n_trunk_layers=2,
+      activation=Lux.gelu
+    ) -> DeepONetLayer
+
+Construct a DeepONet for mapping parameter vectors to FEM DOF vectors.
+
+- `param_dim`: dimension of the parameter space (branch input)
+- `n_dofs`: number of free DOFs in the FEM discretization (output dim)
+- `spatial_dim`: spatial dimension D of the mesh (trunk input)
+- `latent_dim`: dimension p of the shared latent space
+"""
+function build_deeponet(;
+  param_dim::Int,
+  n_dofs::Int,
+  spatial_dim::Int,
+  latent_dim::Int=32,
+  branch_width::Int=64,
+  trunk_width::Int=64,
+  n_branch_layers::Int=2,
+  n_trunk_layers::Int=2,
+  activation=Lux.gelu
+)
+  # Branch: μ ∈ R^d → b(μ) ∈ R^p
+  branch_layers = Any[Dense(param_dim,branch_width,activation)]
+  for _ in 2:n_branch_layers
+    push!(branch_layers,Dense(branch_width,branch_width,activation))
+  end
+  push!(branch_layers,Dense(branch_width,latent_dim))
+  branch = Chain(branch_layers...)
+
+  # Trunk: x ∈ R^D → t(x) ∈ R^p
+  trunk_layers = Any[Dense(spatial_dim,trunk_width,activation)]
+  for _ in 2:n_trunk_layers
+    push!(trunk_layers,Dense(trunk_width,trunk_width,activation))
+  end
+  push!(trunk_layers,Dense(trunk_width,latent_dim))
+  trunk = Chain(trunk_layers...)
+
+  return DeepONetLayer(branch,trunk,latent_dim,n_dofs)
+end
+
+"""
+    precompute_trunk_matrix(model, coord_matrix, ps, st) -> Matrix{Float32}
+
+Evaluate the trunk network at all DOF coordinates once.
+Returns T ∈ R^{N_free × latent_dim} so that prediction is T * b(μ).
+
+Intended for validation and online inference. During training the trunk
+must instead be evaluated inside the loss, so that its weights receive
+gradients (see `train_operator`).
+"""
+function precompute_trunk_matrix(model::DeepONetLayer,coord_matrix::AbstractMatrix,ps,st)
+  # coord_matrix: D × N_free (Float32)
+  trunk_out,_ = Lux.apply(model.trunk,coord_matrix,ps.trunk,st.trunk)
+  # trunk_out: latent_dim × N_free
+  return Matrix(transpose(trunk_out))  # N_free × latent_dim
+end
+
+# Forward pass: branch encodes μ, then multiply with the precomputed trunk matrix.
+# Input:  x = (mu, trunk_matrix) packed as a tuple
+# Output: û ∈ R^{N_dofs × batch}
+function (l::DeepONetLayer)(x::Tuple,ps,st)
+  mu,trunk_matrix = x
+  # mu: d_param × batch
+  b,new_st_branch = Lux.apply(l.branch,mu,ps.branch,st.branch)
+  # b: latent_dim × batch
+
+  # û = T * b → (N_dofs × batch)
+  u_hat = trunk_matrix * b
+
+  new_st = (branch=new_st_branch,trunk=st.trunk)
+  return u_hat,new_st
+end
diff --git a/src/NeuralOperatorROMs/NeuralOperatorROMs.jl b/src/NeuralOperatorROMs/NeuralOperatorROMs.jl
new file mode 100644
index 00000000..37eefa6d
--- /dev/null
+++ b/src/NeuralOperatorROMs/NeuralOperatorROMs.jl
@@ -0,0 +1,41 @@
+module NeuralOperatorROMs
+
+using Gridap
+using Gridap.FESpaces
+using Gridap.Geometry
+using Gridap.CellData
+using Gridap.ReferenceFEs
+
+using Lux
+using Optimisers
+using Zygote
+using Random
+using LinearAlgebra
+using Statistics
+
+include("Snapshots.jl")
+
+include("DeepONet.jl")
+
+include("Training.jl")
+
+include("Reconstruction.jl")
+
+export SnapshotData
+export collect_snapshots
+export extract_coordinates
+export sample_parameters
+
+export DeepONetLayer
+export build_deeponet
+export precompute_trunk_matrix
+
+export NormalizationStats
+export TrainingConfig
+export TrainingResult
+export train_operator
+
+export reconstruct_fe_function
+export evaluate_rom
+
+end # module
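A quick shape check of the layer contract defined above (sizes are arbitrary; this mirrors the forward-pass test later in the patch):

```julia
using GridapROMs, Lux, Random

rng = Random.MersenneTwister(0)
net = build_deeponet(param_dim=2,n_dofs=81,spatial_dim=2,latent_dim=16)
ps,st = Lux.setup(rng,net)

X  = rand(Float32,2,81)                    # D × N_free trunk inputs
T  = precompute_trunk_matrix(net,X,ps,st)  # 81 × 16
μb = rand(Float32,2,4)                     # batch of 4 parameter vectors
û,_ = Lux.apply(net,(μb,T),ps,st)          # 81 × 4 predicted DOF values
```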
diff --git a/src/NeuralOperatorROMs/Reconstruction.jl b/src/NeuralOperatorROMs/Reconstruction.jl
new file mode 100644
index 00000000..fdebd9da
--- /dev/null
+++ b/src/NeuralOperatorROMs/Reconstruction.jl
@@ -0,0 +1,47 @@
+# Reconstruct a Gridap FEFunction from neural operator predictions.
+#
+# This closes the loop: the neural operator outputs a raw DOF vector,
+# and we wrap it back into a Gridap FEFunction so it can be used for
+# visualization (writevtk), error computation, and postprocessing
+# with the full Gridap ecosystem.
+
+"""
+    evaluate_rom(result::TrainingResult, μ::AbstractVector) -> Vector{Float64}
+
+Predict the DOF vector for a new parameter μ using the trained neural operator.
+Returns denormalized DOF values ready for FEFunction reconstruction.
+"""
+function evaluate_rom(result::TrainingResult,μ::AbstractVector)
+  mu_col = Float32.(reshape(μ,length(μ),1))
+  mu_n = normalize_data(result.input_norm,mu_col)
+
+  u_hat_n,_ = Lux.apply(
+    result.model,(mu_n,result.trunk_matrix),result.params,result.state
+  )
+  u_hat = denormalize_data(result.output_norm,vec(u_hat_n))
+  return Float64.(u_hat)
+end
+
+"""
+    reconstruct_fe_function(
+      result::TrainingResult,
+      μ::AbstractVector,
+      trial::FESpace
+    ) -> FEFunction
+
+Evaluate the neural operator at parameter μ and reconstruct a Gridap
+FEFunction. The trial space provides the Dirichlet DOF values and the
+mesh/basis function information needed to build a proper FEFunction.
+
+This is the key integration point with Gridap: predicted DOFs go in via
+`FEFunction(trial, free_values)`, the same constructor Gridap uses
+internally after solving a linear system.
+"""
+function reconstruct_fe_function(
+  result::TrainingResult,
+  μ::AbstractVector,
+  trial::FESpace
+)
+  predicted_dofs = evaluate_rom(result,μ)
+  return FEFunction(trial,predicted_dofs)
+end
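Because `reconstruct_fe_function` returns an ordinary `FEFunction`, standard Gridap postprocessing applies directly. A sketch assuming a trained `result` and the trial space `U` from the example (output file name arbitrary):

```julia
uh = reconstruct_fe_function(result,[2.5,1.0],U)
Ω  = get_triangulation(uh)
writevtk(Ω,"rom_prediction",cellfields=["uh_rom"=>uh])  # inspect in ParaView
```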
+""" +function extract_coordinates(trial::FESpace) + trian = get_triangulation(trial) + all_coords = get_node_coordinates(trian) + D = length(first(all_coords)) + N_free = num_free_dofs(trial) + + # Build free-DOF → node-coordinate mapping via cell connectivity + cell_dof_ids = get_cell_dof_ids(trial) + cell_node_ids = Gridap.Geometry.get_cell_node_ids(trian) + + coord_matrix = zeros(Float64,D,N_free) + filled = falses(N_free) + + for cell in 1:length(cell_dof_ids) + dofs = cell_dof_ids[cell] + nodes = cell_node_ids[cell] + for (local_i,dof_id) in enumerate(dofs) + if dof_id > 0 && !filled[dof_id] + c = all_coords[nodes[local_i]] + for d in 1:D + coord_matrix[d,dof_id] = c[d] + end + filled[dof_id] = true + end + end + end + + return coord_matrix +end + +""" + collect_snapshots( + solver_fn, + param_samples::Vector{<:AbstractVector}; + trial::FESpace + ) -> SnapshotData + +Solve a parametric PDE for each parameter sample and collect DOF vectors. + +Arguments: +- `solver_fn`: a function μ -> uh::FEFunction that solves the PDE at parameter μ +- `param_samples`: vector of parameter vectors [{μ₁}, {μ₂}, ...] +- `trial`: the trial FESpace (needed for coordinate extraction) + +The `solver_fn` encapsulates the entire Gridap problem setup: mesh, spaces, +weak form, and linear solve. This keeps the snapshot collector non-intrusive — +it treats the FEM solver as a black box, which is exactly the philosophy +behind neural operator ROMs. +""" +function collect_snapshots( + solver_fn, + param_samples::Vector{<:AbstractVector}; + trial::FESpace +) + M = length(param_samples) + d_param = length(first(param_samples)) + + # Solve for first sample to determine DOF dimension + uh_first = solver_fn(param_samples[1]) + dofs_first = get_free_dof_values(uh_first) + N_dofs = length(dofs_first) + + # Allocate storage + parameters = zeros(Float64,d_param,M) + solutions = zeros(Float64,N_dofs,M) + + # Store first solution + parameters[:,1] .= param_samples[1] + solutions[:,1] .= dofs_first + + # Solve remaining + for i in 2:M + μ = param_samples[i] + uh = solver_fn(μ) + parameters[:,i] .= μ + solutions[:,i] .= get_free_dof_values(uh) + end + + coordinates = extract_coordinates(trial) + return SnapshotData(parameters,solutions,coordinates) +end + +""" + sample_parameters(bounds::Vector{Tuple{Float64,Float64}}, n::Int) -> Vector{Vector{Float64}} + +Latin hypercube sampling over a box-constrained parameter space. +Each element of `bounds` is (lower, upper) for one parameter dimension. +""" +function sample_parameters(bounds::Vector{Tuple{Float64,Float64}},n::Int) + D = length(bounds) + samples = Vector{Vector{Float64}}(undef,n) + perms = [randperm(n) for _ in 1:D] + for i in 1:n + μ = zeros(D) + for d in 1:D + lo,hi = bounds[d] + μ[d] = lo + (perms[d][i] - rand()) / n * (hi - lo) + end + samples[i] = μ + end + return samples +end diff --git a/src/NeuralOperatorROMs/Training.jl b/src/NeuralOperatorROMs/Training.jl new file mode 100644 index 00000000..e11b04d2 --- /dev/null +++ b/src/NeuralOperatorROMs/Training.jl @@ -0,0 +1,192 @@ +# Training loop for neural operator ROMs. +# +# Follows the standard Lux.jl pattern: explicit parameters, Zygote AD, +# Optimisers.jl for parameter updates. The training data comes from +# SnapshotData produced by collect_snapshots. + +""" + NormalizationStats + +Per-feature mean/std for input and output normalization. +Normalization is critical for training stability — FEM DOF values can span +orders of magnitude, and parameter ranges are problem-dependent. 
+""" +struct NormalizationStats + mean::Vector{Float32} + std::Vector{Float32} +end + +function normalize(s::NormalizationStats,x::AbstractMatrix) + return (x .- s.mean) ./ s.std +end + +function denormalize(s::NormalizationStats,x::AbstractMatrix) + return x .* s.std .+ s.mean +end + +function denormalize(s::NormalizationStats,x::AbstractVector) + return x .* s.std .+ s.mean +end + +function compute_normalization(x::AbstractMatrix) + m = vec(mean(x,dims=2)) + s = vec(std(x,dims=2)) + s[s .< 1f-8] .= 1f0 + return NormalizationStats(Float32.(m),Float32.(s)) +end + +""" + TrainingConfig(; kwargs...) + +Hyperparameters for neural operator training. +""" +struct TrainingConfig + epochs::Int + lr::Float64 + batch_size::Int + validation_split::Float64 + patience::Int + verbose::Bool +end + +function TrainingConfig(; + epochs=500,lr=1e-3,batch_size=16, + validation_split=0.1,patience=50,verbose=true +) + TrainingConfig(epochs,lr,batch_size,validation_split,patience,verbose) +end + +""" + TrainingResult + +Output of `train_operator`: the trained model, its parameters/state, +normalization statistics, the precomputed trunk matrix, and loss history. +""" +struct TrainingResult + model::DeepONetLayer + params::Any + state::Any + trunk_matrix::Matrix{Float32} + input_norm::NormalizationStats + output_norm::NormalizationStats + train_losses::Vector{Float64} + val_losses::Vector{Float64} + best_epoch::Int +end + +""" + train_operator(data::SnapshotData, model::DeepONetLayer; + config=TrainingConfig(), rng=Random.default_rng()) -> TrainingResult + +Train a DeepONet on FEM snapshot data. + +The training loop: +1. Normalizes parameters and DOF vectors +2. Precomputes trunk matrix T from mesh coordinates (recomputed each epoch + since trunk weights change) +3. Minimizes MSE between predicted and true DOF vectors +4. 
Uses early stopping on a held-out validation set +""" +function train_operator( + data::SnapshotData, + model::DeepONetLayer; + config::TrainingConfig=TrainingConfig(), + rng::AbstractRNG=Random.default_rng() +) + M = size(data.parameters,2) + + # Normalization + input_norm = compute_normalization(data.parameters) + output_norm = compute_normalization(data.solutions) + + params_n = Float32.(normalize(input_norm,data.parameters)) + sols_n = Float32.(normalize(output_norm,data.solutions)) + coords_f32 = Float32.(data.coordinates) + + # Train/val split + n_val = max(1,round(Int,M * config.validation_split)) + n_train = M - n_val + perm = randperm(rng,M) + train_idx = perm[1:n_train] + val_idx = perm[n_train+1:end] + + # Initialize model + ps,st = Lux.setup(rng,model) + opt_state = Optimisers.setup(Adam(Float32(config.lr)),ps) + + # Tracking + train_losses = Float64[] + val_losses = Float64[] + best_val_loss = Inf + best_ps = deepcopy(ps) + best_epoch = 0 + patience_counter = 0 + + for epoch in 1:config.epochs + # Precompute trunk matrix with current trunk weights + trunk_matrix = precompute_trunk_matrix(model,coords_f32,ps,st) + + # Mini-batch SGD + shuffle!(rng,train_idx) + epoch_loss = 0.0 + n_batches = 0 + + for batch_start in 1:config.batch_size:n_train + batch_end = min(batch_start + config.batch_size - 1,n_train) + idx = train_idx[batch_start:batch_end] + + mu_batch = params_n[:,idx] + u_batch = sols_n[:,idx] + + (loss,_),grads = Zygote.withgradient(ps) do p + u_hat,_ = Lux.apply(model,(mu_batch,trunk_matrix),p,st) + mse = mean((u_hat .- u_batch).^2) + return mse,nothing + end + + opt_state,ps = Optimisers.update(opt_state,ps,grads[1]) + epoch_loss += loss + n_batches += 1 + end + + avg_train_loss = epoch_loss / n_batches + push!(train_losses,avg_train_loss) + + # Validation + trunk_matrix = precompute_trunk_matrix(model,coords_f32,ps,st) + mu_val = params_n[:,val_idx] + u_val = sols_n[:,val_idx] + u_hat_val,_ = Lux.apply(model,(mu_val,trunk_matrix),ps,st) + val_loss = mean((u_hat_val .- u_val).^2) + push!(val_losses,val_loss) + + if config.verbose && (epoch % 50 == 0 || epoch == 1) + println(" Epoch $epoch/$( config.epochs): " * + "train_loss=$(round(avg_train_loss,digits=6)), " * + "val_loss=$(round(Float64(val_loss),digits=6))") + end + + # Early stopping + if val_loss < best_val_loss + best_val_loss = val_loss + best_ps = deepcopy(ps) + best_epoch = epoch + patience_counter = 0 + else + patience_counter += 1 + if patience_counter >= config.patience + config.verbose && println(" Early stopping at epoch $epoch (best: $best_epoch)") + break + end + end + end + + # Final trunk matrix with best parameters + trunk_matrix = precompute_trunk_matrix(model,coords_f32,best_ps,st) + + return TrainingResult( + model,best_ps,st,trunk_matrix, + input_norm,output_norm, + train_losses,val_losses,best_epoch + ) +end diff --git a/test/NeuralOperatorROMs/runtests.jl b/test/NeuralOperatorROMs/runtests.jl new file mode 100644 index 00000000..1f71ac17 --- /dev/null +++ b/test/NeuralOperatorROMs/runtests.jl @@ -0,0 +1,147 @@ +using Test +using GridapROMs +using GridapROMs.NeuralOperatorROMs +using Gridap +using Lux +using Zygote +using Random +using LinearAlgebra +using Statistics + +@testset "NeuralOperatorROMs" begin + + @testset "snapshot collection" begin + # Minimal Gridap problem: Poisson on [0,1]² + model = CartesianDiscreteModel((0,1,0,1),(8,8)) + reffe = ReferenceFE(lagrangian,Float64,1) + V = TestFESpace(model,reffe;conformity=:H1,dirichlet_tags="boundary") + U = TrialFESpace(V,0.0) + Ω = 
+
+  @testset "coordinate extraction" begin
+    model = CartesianDiscreteModel((0,1,0,1),(4,4))
+    reffe = ReferenceFE(lagrangian,Float64,1)
+    V = TestFESpace(model,reffe;conformity=:H1,dirichlet_tags="boundary")
+    U = TrialFESpace(V,0.0)
+
+    coords = extract_coordinates(U)
+    @test size(coords,1) == 2  # 2D
+    @test all(0.0 .<= coords[1,:] .<= 1.0)
+    @test all(0.0 .<= coords[2,:] .<= 1.0)
+  end
+
+  @testset "DeepONet construction and forward pass" begin
+    rng = Random.MersenneTwister(123)
+
+    net = build_deeponet(;
+      param_dim=2,n_dofs=49,spatial_dim=2,
+      latent_dim=16,branch_width=32,trunk_width=32,
+      n_branch_layers=2,n_trunk_layers=2,
+    )
+
+    ps,st = Lux.setup(rng,net)
+
+    # Mock data
+    coords = Float32.(rand(rng,2,49))
+    trunk_matrix = precompute_trunk_matrix(net,coords,ps,st)
+    @test size(trunk_matrix) == (49,16)
+
+    mu = Float32.(rand(rng,2,5))  # batch of 5
+    u_hat,_ = Lux.apply(net,(mu,trunk_matrix),ps,st)
+    @test size(u_hat) == (49,5)
+
+    # Gradients should flow (through the branch here; the trunk matrix is
+    # a constant in this test, so the trunk is trained via Training.jl)
+    (loss,_),grads = Zygote.withgradient(ps) do p
+      u,_ = Lux.apply(net,(mu,trunk_matrix),p,st)
+      return sum(u.^2),nothing
+    end
+    @test loss > 0
+    @test !isnothing(grads[1])
+  end
+
+  @testset "training smoke test" begin
+    rng = Random.MersenneTwister(456)
+
+    # Synthetic data: simple linear map μ → u = μ₁ * ones(10)
+    M = 30
+    params = randn(rng,1,M)
+    sols = params[1,:]' .* ones(10)  # 10 × M
+
+    data = SnapshotData(
+      Float64.(params),
+      Float64.(sols),
+      Float64.(rand(rng,2,10))  # dummy coordinates
+    )
+
+    net = build_deeponet(;
+      param_dim=1,n_dofs=10,spatial_dim=2,
+      latent_dim=8,branch_width=16,trunk_width=16,
+    )
+
+    config = TrainingConfig(;epochs=100,lr=1e-3,batch_size=8,
+                            patience=100,verbose=false)
+    result = train_operator(data,net;config,rng)
+
+    # Loss should decrease
+    @test result.train_losses[end] < result.train_losses[1]
+    @test length(result.val_losses) > 0
+
+    # Prediction should be in the right ballpark
+    μ_test = [2.0]
+    pred = evaluate_rom(result,μ_test)
+    @test length(pred) == 10
+    expected = 2.0 * ones(10)
+    @test norm(pred .- expected) / norm(expected) < 0.5  # within 50% for a smoke test
+  end
+
+  @testset "FEFunction reconstruction" begin
+    model = CartesianDiscreteModel((0,1,0,1),(4,4))
+    reffe = ReferenceFE(lagrangian,Float64,1)
+    V = TestFESpace(model,reffe;conformity=:H1,dirichlet_tags="boundary")
+    U = TrialFESpace(V,0.0)
+
+    N = num_free_dofs(V)
+    fake_dofs = rand(N)
+
+    # Verify we can construct an FEFunction from a DOF vector
+    uh = FEFunction(U,fake_dofs)
+    recovered = get_free_dof_values(uh)
+    @test recovered ≈ fake_dofs
+  end
+
+  @testset "parameter sampling" begin
+    bounds = [(0.0,1.0),(10.0,20.0),(-5.0,5.0)]
+    samples = sample_parameters(bounds,100)
+    @test length(samples) == 100
+    @test all(length(s) == 3 for s in samples)
+    for s in samples
+      @test 0.0 <= s[1] <= 1.0
+      @test 10.0 <= s[2] <= 20.0
+      @test -5.0 <= s[3] <= 5.0
+    end
+  end
+
+end # top-level testset
diff --git a/test/runtests.jl b/test/runtests.jl
index 63541e7e..0dbb7447 100755
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -16,4 +16,6 @@ using Test
 @testset "moving elasticity" begin include("RBMovingGeometries/moving_elasticity.jl") end
 @testset "moving stokes" begin include("RBMovingGeometries/moving_stokes.jl") end
 
+@testset "NeuralOperatorROMs" begin include("NeuralOperatorROMs/runtests.jl") end
+
 end # module