From 26622235580ae44b11be92756691261b922da076 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 16 Dec 2025 14:25:49 +0100 Subject: [PATCH 1/6] gloabal sensitivity example --- examples/scripts/Project.toml | 1 + .../scripts/soil_heat_global_sensitivity.jl | 57 +++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 examples/scripts/soil_heat_global_sensitivity.jl diff --git a/examples/scripts/Project.toml b/examples/scripts/Project.toml index 03b35b927..9dfaf41d8 100644 --- a/examples/scripts/Project.toml +++ b/examples/scripts/Project.toml @@ -2,6 +2,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Checkpointing = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" GeoMakie = "db073c08-6b98-4ee5-b6a4-5efafb3259c6" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" diff --git a/examples/scripts/soil_heat_global_sensitivity.jl b/examples/scripts/soil_heat_global_sensitivity.jl new file mode 100644 index 000000000..cf2a9aaac --- /dev/null +++ b/examples/scripts/soil_heat_global_sensitivity.jl @@ -0,0 +1,57 @@ +using Terrarium + +using CUDA +using Dates +using Rasters, NCDatasets +using Statistics + +using CairoMakie, GeoMakie +using Enzyme, Checkpointing + +import RingGrids +import SpeedyWeather + +# run on GPU if available +arch = CUDA.functional() ? Terrarium.GPU() : Terrarium.CPU() + +# Load land-sea mask at ~1° resolution +land_sea_frac = convert.(Float32, dropdims(Raster("inputs/era5-land_land_sea_mask_N72.nc"), dims=Ti)) +land_sea_frac_field = RingGrids.FullGaussianField(Matrix(land_sea_frac), input_as=Matrix) +heatmap(land_sea_frac_field) + +# Set up grids +land_mask = land_sea_frac_field .> 0.5 # select only grid points with > 50% land +grid = ColumnRingGrid(arch, Float64, ExponentialSpacing(N=30), land_mask.grid, land_mask) +lon, lat = RingGrids.get_londlatds(grid.rings) + +# Initial conditions +initializer = FieldInitializers( + # steady-ish state initial condition for temperature + temperature = (x,z) -> -1 - 0.01*z, + # fully saturated soil pores + saturation_water_ice = 1.0, +) +model = SoilModel(grid; initializer) +# constant surface temperature of 1°C +bcs = PrescribedSurfaceTemperature(:T_ub, 1.0) +integrator = initialize(model, ForwardEuler(), boundary_conditions=bcs) + +# spin up a little +@time run!(integrator, period=Day(5), Δt=900.0) + +# Enzyme prep +scheme = Revolve(1) +dintegrator = make_zero(integrator) +N_t = 200 + +autodiff(set_runtime_activity(Enzyme.Reverse), run!, Const, Duplicated(integrator, dintegrator), Const(scheme), Const(N_t)) + +function run_sim!(integrater, N_t) + run!(integrater, steps=N_t, Δt=900.0) + return nothing +end + +N_t = 1 +autodiff(set_runtime_activity(Enzyme.Reverse), run_sim!, Const, Duplicated(integrator, dintegrator), Const(N_t)) + + From c657502e6fd133c9b25c48d8a6a6d7edd1eeff8c Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 16 Dec 2025 14:36:42 +0100 Subject: [PATCH 2/6] add some comments --- examples/scripts/soil_heat_global_sensitivity.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/scripts/soil_heat_global_sensitivity.jl b/examples/scripts/soil_heat_global_sensitivity.jl index cf2a9aaac..84d4c66a3 100644 --- a/examples/scripts/soil_heat_global_sensitivity.jl +++ b/examples/scripts/soil_heat_global_sensitivity.jl @@ -44,6 +44,7 @@ scheme = Revolve(1) dintegrator = make_zero(integrator) N_t = 200 +# this uses checkpointing autodiff(set_runtime_activity(Enzyme.Reverse), run!, Const, Duplicated(integrator, dintegrator), Const(scheme), Const(N_t)) function run_sim!(integrater, N_t) @@ -51,6 +52,7 @@ function run_sim!(integrater, N_t) return nothing end +# no checkpointing N_t = 1 autodiff(set_runtime_activity(Enzyme.Reverse), run_sim!, Const, Duplicated(integrator, dintegrator), Const(N_t)) From 1829be9ee16f27047f0210c030a12fcfd419fb5e Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 16 Dec 2025 15:12:00 +0100 Subject: [PATCH 3/6] no mask in example --- examples/scripts/soil_heat_global_sensitivity.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/scripts/soil_heat_global_sensitivity.jl b/examples/scripts/soil_heat_global_sensitivity.jl index 84d4c66a3..dde4bad52 100644 --- a/examples/scripts/soil_heat_global_sensitivity.jl +++ b/examples/scripts/soil_heat_global_sensitivity.jl @@ -21,7 +21,9 @@ heatmap(land_sea_frac_field) # Set up grids land_mask = land_sea_frac_field .> 0.5 # select only grid points with > 50% land -grid = ColumnRingGrid(arch, Float64, ExponentialSpacing(N=30), land_mask.grid, land_mask) + +# for now let's do it actually without the mask to keep the example simple +grid = ColumnRingGrid(arch, Float64, ExponentialSpacing(N=30), land_mask.grid) #, land_mask.grid, land_mask) lon, lat = RingGrids.get_londlatds(grid.rings) # Initial conditions From 000cf94039e7883effd6dbb706225a32d9acb6d4 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 16 Dec 2025 16:15:44 +0100 Subject: [PATCH 4/6] working example in Julia 1.10 --- examples/scripts/soil_heat_global_sensitivity.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/scripts/soil_heat_global_sensitivity.jl b/examples/scripts/soil_heat_global_sensitivity.jl index dde4bad52..33e651ba4 100644 --- a/examples/scripts/soil_heat_global_sensitivity.jl +++ b/examples/scripts/soil_heat_global_sensitivity.jl @@ -1,3 +1,4 @@ +# currently only works in Julia 1.10 with Enzyme using Terrarium using CUDA @@ -26,6 +27,10 @@ land_mask = land_sea_frac_field .> 0.5 # select only grid points with > 50% land grid = ColumnRingGrid(arch, Float64, ExponentialSpacing(N=30), land_mask.grid) #, land_mask.grid, land_mask) lon, lat = RingGrids.get_londlatds(grid.rings) +# alternative with ColumnGrid +#grid = ColumnGrid(arch, Float64, ExponentialSpacing(N=30)) #, land_mask.grid, land_mask) +#grid = ColumnGrid(ExponentialSpacing()) #, land_mask.grid, land_mask) + # Initial conditions initializer = FieldInitializers( # steady-ish state initial condition for temperature From b552093b93b04ef0478e21dca56a4e298c462563 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 16 Dec 2025 16:23:22 +0100 Subject: [PATCH 5/6] minor clean up --- examples/scripts/soil_heat_global_sensitivity.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/examples/scripts/soil_heat_global_sensitivity.jl b/examples/scripts/soil_heat_global_sensitivity.jl index 33e651ba4..cc3cc44ce 100644 --- a/examples/scripts/soil_heat_global_sensitivity.jl +++ b/examples/scripts/soil_heat_global_sensitivity.jl @@ -53,14 +53,3 @@ N_t = 200 # this uses checkpointing autodiff(set_runtime_activity(Enzyme.Reverse), run!, Const, Duplicated(integrator, dintegrator), Const(scheme), Const(N_t)) - -function run_sim!(integrater, N_t) - run!(integrater, steps=N_t, Δt=900.0) - return nothing -end - -# no checkpointing -N_t = 1 -autodiff(set_runtime_activity(Enzyme.Reverse), run_sim!, Const, Duplicated(integrator, dintegrator), Const(N_t)) - - From 285bea0724e654da330f8dda991d8a3b5186717f Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 16 Dec 2025 17:55:23 +0100 Subject: [PATCH 6/6] update script --- examples/scripts/soil_heat_global_sensitivity.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/scripts/soil_heat_global_sensitivity.jl b/examples/scripts/soil_heat_global_sensitivity.jl index cc3cc44ce..c16eb8767 100644 --- a/examples/scripts/soil_heat_global_sensitivity.jl +++ b/examples/scripts/soil_heat_global_sensitivity.jl @@ -20,6 +20,11 @@ land_sea_frac = convert.(Float32, dropdims(Raster("inputs/era5-land_land_sea_mas land_sea_frac_field = RingGrids.FullGaussianField(Matrix(land_sea_frac), input_as=Matrix) heatmap(land_sea_frac_field) +# Load ERA-5 2 meter air temperature at ~1° resolution +Tair_raster = Raster("inputs/external/era5-land/2m_temperature/era5_land_2m_temperature_2023_N72.nc") +Tsurf_0 = convert.(Float32, replace_missing(Tair_raster, NaN)) .- 273.15f0 +# heatmap(Tair_raster[:,:,1]) + # Set up grids land_mask = land_sea_frac_field .> 0.5 # select only grid points with > 50% land @@ -27,9 +32,9 @@ land_mask = land_sea_frac_field .> 0.5 # select only grid points with > 50% land grid = ColumnRingGrid(arch, Float64, ExponentialSpacing(N=30), land_mask.grid) #, land_mask.grid, land_mask) lon, lat = RingGrids.get_londlatds(grid.rings) -# alternative with ColumnGrid -#grid = ColumnGrid(arch, Float64, ExponentialSpacing(N=30)) #, land_mask.grid, land_mask) -#grid = ColumnGrid(ExponentialSpacing()) #, land_mask.grid, land_mask) +# Construct input sources +Tair_forcing = InputSource(grid, rebuild(Tair_raster, name=:Tair)) +Tsurf_0 = Tair_raster[Ti(1)][findall(land_mask)] # Initial conditions initializer = FieldInitializers( @@ -40,8 +45,8 @@ initializer = FieldInitializers( ) model = SoilModel(grid; initializer) # constant surface temperature of 1°C -bcs = PrescribedSurfaceTemperature(:T_ub, 1.0) -integrator = initialize(model, ForwardEuler(), boundary_conditions=bcs) +boundary_conditions = PrescribedSurfaceTemperature(:Tair) +integrator = initialize(model, ForwardEuler(), Tair_forcing; boundary_conditions) # spin up a little @time run!(integrator, period=Day(5), Δt=900.0)