Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
name = "XESMF"
uuid = "2e0b0046-e7a1-486f-88de-807ee8ffabe5"
authors = ["NumericalEarth and contributors"]
version = "0.1.6"
version = "0.2.0"

[deps]
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
Oceananigans = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09"

[extensions]
XESMFOceananigansExt = "Oceananigans"

[compat]
CondaPkg = "0.2.31"
Oceananigans = "0.98, 0.99, 0.100"
Oceananigans = "0.107, 0.108"
PythonCall = "0.9.27"
SparseArrays = "1"
Test = "1"
Expand Down
203 changes: 203 additions & 0 deletions ext/XESMFOceananigansExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
module XESMFOceananigansExt

using Oceananigans
using Oceananigans.Architectures: architecture, on_architecture
using Oceananigans.Fields: AbstractField, topology, location
using Oceananigans.Grids: AbstractGrid, λnodes, φnodes, Center, Face, total_length
# Always load XESMF _after_ Oceananigans.jl (and MPI.jl) so that we load Julia's
# libmpi before XESFM loads its own.
using XESMF

import Oceananigans.Fields: regrid!
import Oceananigans.Architectures: on_architecture
import XESMF: Regridder, xesmf_coordinates

node_array(ξ::AbstractMatrix, Nx, Ny) = view(ξ, 1:Nx, 1:Ny)

x_node_array(x::AbstractVector, Nx, Ny) = repeat(view(x, 1:Nx), 1, Ny)
x_node_array(x::AbstractMatrix, Nx, Ny) = node_array(x, Nx, Ny)

y_node_array(y::AbstractVector, Nx, Ny) = repeat(transpose(view(y, 1:Ny)), Nx, 1)
y_node_array(y::AbstractMatrix, Nx, Ny) = node_array(y, Nx, Ny)

vertex_array(ξ::AbstractMatrix, Nx, Ny) = view(ξ, 1:Nx+1, 1:Ny+1)

x_vertex_array(x::AbstractVector, Nx, Ny) = repeat(view(x, 1:Nx+1), 1, Ny+1)
x_vertex_array(x::AbstractMatrix, Nx, Ny) = vertex_array(x, Nx, Ny)

y_vertex_array(y::AbstractVector, Nx, Ny) = repeat(transpose(view(y, 1:Ny+1)), Nx+1, 1)
y_vertex_array(y::AbstractMatrix, Nx, Ny) = vertex_array(y, Nx, Ny)

"""
xesmf_coordinates(grid::AbstractGrid, ℓx, ℓy, ℓz)

Extract the coordinates (latitude/longitude) and the coordinates' bounds from
`grid` at locations `ℓx, ℓy, ℓz`.
"""
function xesmf_coordinates(grid::AbstractGrid, ℓx, ℓy, ℓz)
Nx, Ny, Nz = size(grid)

# Do we need to use ℓx and ℓy eventually?
λ = λnodes(grid, Center(), Center(), ℓz, with_halos=true)
φ = φnodes(grid, Center(), Center(), ℓz, with_halos=true)
λv = λnodes(grid, Face(), Face(), ℓz, with_halos=true)
φv = φnodes(grid, Face(), Face(), ℓz, with_halos=true)

# Build data structures expected by xESMF
Nx, Ny, Nz = size(grid)

λ = x_node_array(λ, Nx, Ny)
φ = y_node_array(φ, Nx, Ny)
λv = x_vertex_array(λv, Nx, Ny)
φv = y_vertex_array(φv, Nx, Ny)

# Python's xESMF expects 2D arrays with (x, y) coordinates
# in which y varies in dim=1 and x varies in dim=2
# therefore we transpose the coordinate matrices
coords_dictionary = Dict("lat" => permutedims(φ, (2, 1)), # φ is latitude
"lon" => permutedims(λ, (2, 1)), # λ is longitude
"lat_b" => permutedims(φv, (2, 1)),
"lon_b" => permutedims(λv, (2, 1)))

return coords_dictionary
end

"""
xesmf_coordinates(field::AbstractField)

Extract the coordinates (latitude/longitude) and the coordinates' bounds from
the `field`'s grid.
"""
function xesmf_coordinates(field::AbstractField)
ℓx, ℓy, ℓz = Oceananigans.Fields.instantiated_location(field)
return xesmf_coordinates(field.grid, ℓx, ℓy, ℓz)
end

"""
Regridder(dst_field::AbstractField, src_field::AbstractField; method="conservative")

Return a regridder from `src_field` to `dst_field` using the specified `method`.
The regridder contains a sparse matrix with the regridding weights.
The regridding weights are obtained via xESMF Python package.
xESMF exposes five different regridding algorithms from the ESMF library,
specified with the `method` keyword argument:

* `"bilinear"`: `ESMF.RegridMethod.BILINEAR`
* `"conservative"`: `ESMF.RegridMethod.CONSERVE`
* `"conservative_normed"`: `ESMF.RegridMethod.CONSERVE`
* `"patch"`: `ESMF.RegridMethod.PATCH`
* `"nearest_s2d"`: `ESMF.RegridMethod.NEAREST_STOD`
* `"nearest_d2s"`: `ESMF.RegridMethod.NEAREST_DTOS`

where `conservative_normed` is just the conservative method with the normalization set to
`ESMF.NormType.FRACAREA` instead of the default `norm_type = ESMF.NormType.DSTAREA`.

For more information, see the Python xESMF documentation at:

> https://xesmf.readthedocs.io/en/latest/notebooks/Compare_algorithms.html

Example
=======

To create a regridder for two fields that live on different grids.

```@example regridding
using Oceananigans
using XESMF

z = (-1, 0)
tg = TripolarGrid(; size=(180, 85, 1), z, southernmost_latitude = -80)
llg = LatitudeLongitudeGrid(; size=(170, 80, 1), z,
longitude=(0, 360), latitude=(-82, 90))

src_field = CenterField(tg)
dst_field = CenterField(llg)

regridder = XESMF.Regridder(dst_field, src_field, method="conservative")
```

We can use the above regridder to regrid via [`regrid!`](@ref).
"""
function Regridder(dst_field::AbstractField, src_field::AbstractField; method="conservative")

ℓx, ℓy, ℓz = Oceananigans.Fields.instantiated_location(src_field)

# We only support regridding between centered fields
@assert ℓx isa Center
@assert ℓy isa Center
@assert (ℓx, ℓy, ℓz) == Oceananigans.Fields.instantiated_location(dst_field)

src_Nz = size(src_field)[3]
dst_Nz = size(dst_field)[3]
@assert src_field.grid.z.cᵃᵃᶠ[1:src_Nz+1] == dst_field.grid.z.cᵃᵃᶠ[1:dst_Nz+1]

dst_coordinates = xesmf_coordinates(dst_field)
src_coordinates = xesmf_coordinates(src_field)
periodic = Oceananigans.Grids.topology(src_field.grid, 1) === Periodic ? true : false

regridder = XESMF.Regridder(src_coordinates, dst_coordinates; method, periodic)
weights = regridder.weights

arch = architecture(src_field)

weights = on_architecture(arch, weights)

temp_src = on_architecture(architecture(src_field), regridder.src_temp)
temp_dst = on_architecture(architecture(dst_field), regridder.dst_temp)

return XESMF.Regridder(method, weights, temp_src, temp_dst)
end

on_architecture(on, r::XESMF.Regridder) = XESMF.Regridder(on_architecture(on, r.method),
on_architecture(on, r.weights),
on_architecture(on, r.src_temp),
on_architecture(on, r.dst_temp))

"""
regrid!(dst_field, regrider::XESMF.Regridder, src_field)

Regrid `src_field` onto the grid of field `dst_field` using the regrider `r`.

Example
=======

```@example
using Oceananigans
using XESMF

z = (-1, 0)

tg = TripolarGrid(; size=(360, 170, 1), z, southernmost_latitude = -80)

llg = LatitudeLongitudeGrid(; size=(360, 180, 1), z,
longitude=(0, 360), latitude=(-82, 90))

src_field = CenterField(tg)
dst_field = CenterField(llg)

λ₀, φ₀ = 150, 30 # degrees
width = 12 # degrees
set!(src_field, (λ, φ, z) -> exp(-((λ - λ₀)^2 + (φ - φ₀)^2) / 2width^2))

regridder = XESMF.Regridder(dst_field, src_field, method="conservative")

regrid!(dst_field, regridder, src_field)

first(Field(Integral(dst_field, dims=(1, 2))))
```
"""
function regrid!(dst_field, regridder::XESMF.Regridder, src_field)
Nz = size(src_field.grid)[3]
topo_z = topology(src_field)[3]()
ℓz = location(src_field)[3]()

for k in 1:total_length(ℓz, topo_z, Nz)
src = vec(interior(src_field, :, :, k))
dst = vec(interior(dst_field, :, :, k))
regridder(dst, src)
end

return dst_field
end

end # module
46 changes: 46 additions & 0 deletions test/test_oceananigans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ include("setup_runtests.jl")
using Oceananigans
using Oceananigans.Fields: AbstractField
using SparseArrays
using LinearAlgebra

function x_node_array(x::AbstractVector, Nx, Ny)
return Array(repeat(view(x, 1:Nx), 1, Ny))'
Expand Down Expand Up @@ -119,3 +120,48 @@ end
@test all(dense_ll .== strided_ll)
end
end

gaussian_bump(λ, φ; λ₀=0, φ₀=0, width=10) = exp(-((λ - λ₀)^2 + (φ - φ₀)^2) / 2width^2)

@testset "XESMF extension" begin
@info "Testing XESMF regridding..."

arch = CPU()
z = (-1, 0)
southernmost_latitude = -80
radius = Oceananigans.defaults.planet_radius

llg_coarse = LatitudeLongitudeGrid(arch; z, radius,
size = (176, 88, 1),
longitude = (0, 360),
latitude = (southernmost_latitude, 90))

llg_fine = LatitudeLongitudeGrid(arch; z, radius,
size = (360, 170, 1),
longitude = (0, 360),
latitude = (southernmost_latitude, 90))

tg = TripolarGrid(arch; size=(360, 170, 1), z, southernmost_latitude, radius)

for (src_grid, dst_grid) in ((llg_coarse, llg_fine),
(llg_fine, llg_coarse),
(tg, llg_fine))

@info " Regridding from $(nameof(typeof(src_grid))) to $(nameof(typeof(dst_grid)))"

src_field = CenterField(src_grid)
dst_field = CenterField(dst_grid)

width = 12 # degrees
set!(src_field, (λ, φ, z) -> gaussian_bump(λ, φ; λ₀=150, φ₀=30, width) - 2gaussian_bump(λ, φ; λ₀=270, φ₀=-20, width))

regridder = XESMF.Regridder(dst_field, src_field)
@test regridder.weights isa SparseMatrixCSC

regrid!(dst_field, regridder, src_field)

# ∫ dst_field dA ≈ ∫ src_field dA
@test isapprox(first(Field(Integral(dst_field, dims=(1, 2)))),
first(Field(Integral(src_field, dims=(1, 2)))), rtol=1e-4)
end
end
Loading