From 3bd935e36bbed7e0f5c4e5aebb83ce7eafd71c95 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Oct 2025 00:23:21 +0100 Subject: [PATCH 01/20] MixedInterpolator --- firedrake/assemble.py | 28 --- firedrake/interpolation.py | 164 ++++++++++++++---- .../firedrake/regression/test_interpolate.py | 47 +++++ 3 files changed, 173 insertions(+), 66 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 497e596edf..2c3707cb74 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -18,7 +18,6 @@ import finat.ufl from firedrake import (extrusion_utils as eutils, matrix, parameters, solving, tsfc_interface, utils) -from firedrake.formmanipulation import split_form from firedrake.adjoint_utils import annotate_assemble from firedrake.ufl_expr import extract_unique_domain from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit @@ -570,36 +569,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args): rank = len(expr.arguments()) if rank > 2: raise ValueError("Cannot assemble an Interpolate with more than two arguments") - # If argument numbers have been swapped => Adjoint. - arg_operand = ufl.algorithms.extract_arguments(operand) - is_adjoint = (arg_operand and arg_operand[0].number() == 0) - # Get the target space V = v.function_space().dual() - # Dual interpolation from mixed source - if is_adjoint and len(V) > 1: - cur = 0 - sub_operands = [] - components = numpy.reshape(operand, (-1,)) - for Vi in V: - sub_operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape))) - cur += Vi.value_size - - # Component-split of the primal operands interpolated into the dual argument-split - split_interp = sum(reconstruct_interp(sub_operands[i], v=vi) for (i,), vi in split_form(v)) - return assemble(split_interp, tensor=tensor) - - # Dual interpolation into mixed target - if is_adjoint and len(arg_operand[0].function_space()) > 1 and rank == 1: - V = arg_operand[0].function_space() - tensor = tensor or firedrake.Cofunction(V.dual()) - - # Argument-split of the Interpolate gets assembled into the corresponding sub-tensor - for (i,), sub_interp in split_form(expr): - assemble(sub_interp, tensor=tensor.subfunctions[i]) - return tensor - # Get the interpolator interp_data = expr.interp_data.copy() default_missing_val = interp_data.pop('default_missing_val', None) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 40f35e18a7..0c568cd797 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -263,7 +263,18 @@ class Interpolator(abc.ABC): def __new__(cls, expr, V, **kwargs): if isinstance(expr, ufl.Interpolate): + # Mixed spaces are handled well only by the primal 1-form. + # Are we a 2-form or a dual 1-form? + arguments = expr.arguments() + if any(not isinstance(a, Coargument) for a in arguments): + # Do we have mixed source or target spaces? + spaces = [a.function_space() for a in arguments] + if len(spaces) < 2: + spaces.append(V) + if any(len(space) > 1 for space in spaces): + return object.__new__(MixedInterpolator) expr, = expr.ufl_operands + target_mesh = as_domain(V) source_mesh = extract_unique_domain(expr) or target_mesh submesh_interp_implemented = \ @@ -309,9 +320,10 @@ def __init__( target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh vom_onto_other_vom = ((source_mesh is not target_mesh) + and isinstance(self, SameMeshInterpolator) and isinstance(source_mesh.topology, VertexOnlyMeshTopology) and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) - if not isinstance(self, SameMeshInterpolator) or vom_onto_other_vom: + if isinstance(self, CrossMeshInterpolator) or vom_onto_other_vom: # For bespoke interpolation, we currently rely on different assembly procedures: # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form) # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form) @@ -369,7 +381,7 @@ def _interpolate(self, *args, **kwargs): """ pass - def assemble(self, tensor=None, default_missing_val=None): + def assemble(self, tensor=None, **kwargs): """Assemble the operator (or its action).""" from firedrake.assemble import assemble needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate @@ -383,13 +395,11 @@ def assemble(self, tensor=None, default_missing_val=None): if needs_adjoint: # Out-of-place Hermitian transpose petsc_mat.hermitianTranspose(out=res) - elif res: - petsc_mat.copy(res) + elif tensor: + petsc_mat.copy(tensor.petscmat) else: res = petsc_mat - if tensor is None: - tensor = firedrake.AssembledMatrix(arguments, self.bcs, res) - return tensor + return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res) else: # Assembling the action cofunctions = () @@ -401,11 +411,11 @@ def assemble(self, tensor=None, default_missing_val=None): cofunctions = (dual_arg,) if needs_adjoint and len(arguments) == 0: - Iu = self._interpolate(default_missing_val=default_missing_val) + Iu = self._interpolate(**kwargs) return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) else: return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, - default_missing_val=default_missing_val) + **kwargs) class DofNotDefinedError(Exception): @@ -975,33 +985,10 @@ def callable(): return callable else: loops = [] - if len(V) == 1: - expressions = (expr,) - else: - if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(V) - and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(V, operand.subfunctions))): - # Use subfunctions if they match the target shapes - operands = operand.subfunctions - else: - # Unflatten the expression into the shapes of the mixed components - offset = 0 - operands = [] - for Vsub in V: - if len(Vsub.value_shape) == 0: - operands.append(operand[offset]) - else: - components = [operand[offset + j] for j in range(Vsub.value_size)] - operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))) - offset += Vsub.value_size - - # Split the dual argument - if isinstance(dual_arg, Cofunction): - duals = dual_arg.subfunctions - elif isinstance(dual_arg, Coargument): - duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()] - else: - duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))] - expressions = map(expr._ufl_expr_reconstruct_, operands, duals) + expressions = split_interpolate_target(expr) + + if access == op2.INC: + loops.append(tensor.zero) # Interpolate each sub expression into each function space for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions): @@ -1074,8 +1061,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parameters['scalar_type'] = utils.ScalarType callables = () - if access == op2.INC: - callables += (tensor.zero,) # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple # contributions from the facet DOFs of the dual argument. @@ -1720,3 +1705,106 @@ def _wrap_dummy_mat(self): def duplicate(self, mat=None, op=None): return self._wrap_dummy_mat() + + +def split_interpolate_target(expr: ufl.Interpolate): + """Split an Interpolate into the components (subfunctions) of the target space.""" + dual_arg, operand = expr.argument_slots() + V = dual_arg.function_space().dual() + if len(V) == 1: + return (expr,) + # Split the target (dual) argument + if isinstance(dual_arg, Cofunction): + duals = dual_arg.subfunctions + elif isinstance(dual_arg, ufl.Coargument): + duals = [Coargument(Vsub, dual_arg.number()) for Vsub in dual_arg.function_space()] + else: + duals = [vi for _, vi in sorted(firedrake.formmanipulation.split_form(dual_arg))] + # Split the operand into the target shapes + if (isinstance(operand, firedrake.Function) and len(operand.subfunctions) == len(V) + and all(fsub.ufl_shape == Vsub.value_shape for Vsub, fsub in zip(V, operand.subfunctions))): + # Use subfunctions if they match the target shapes + operands = operand.subfunctions + else: + # Unflatten the expression into the target shapes + cur = 0 + operands = [] + components = numpy.reshape(operand, (-1,)) + for Vi in V: + operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape))) + cur += Vi.value_size + expressions = tuple(map(expr._ufl_expr_reconstruct_, operands, duals)) + return expressions + + +class MixedInterpolator(Interpolator): + """A reusable interpolation object between MixedFunctionSpaces. + + Parameters + ---------- + expr + The underlying ufl.Interpolate or the operand to the ufl.Interpolate. + V + The :class:`.FunctionSpace` or :class:`.Function` to + interpolate into. + bcs + A list of boundary conditions. + **kwargs + Any extra kwargs are passed on to the sub Interpolators. + For details see :class:`firedrake.interpolation.Interpolator`. + """ + def __init__(self, expr, V, bcs=None, **kwargs): + super(MixedInterpolator, self).__init__(expr, V, bcs=bcs, **kwargs) + expr = self.ufl_interpolate + bcs = bcs or () + self.arguments = expr.arguments() + + # Split the target (dual) argument + dual_split = split_interpolate_target(expr) + self.sub_interpolators = {} + for i, form in enumerate(dual_split): + # Split the source (primal) argument + for j, sub_interp in firedrake.formmanipulation.split_form(form): + j = max(j) if j else 0 + # Ensure block sparsity + vi, operand = sub_interp.argument_slots() + if not isinstance(operand, ufl.classes.Zero): + Vtarget = vi.function_space().dual() + adjoint = vi.number() == 1 if isinstance(vi, Coargument) else True + + args = sub_interp.arguments() + Vsource = args[0 if adjoint else 1].function_space() + sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}] + + indices = (j, i) if adjoint else (i, j) + Isub = Interpolator(sub_interp, Vtarget, bcs=sub_bcs, **kwargs) + self.sub_interpolators[indices] = Isub + + self.callable = self._callable + + def _callable(self): + """Assemble the operator.""" + shape = tuple(len(a.function_space()) for a in self.arguments) + Isubs = self.sub_interpolators + blocks = numpy.reshape([Isubs[ij].callable().handle if ij in Isubs else PETSc.Mat() + for ij in numpy.ndindex(shape)], shape) + petscmat = PETSc.Mat().createNest(blocks) + tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat) + return tensor.M + + def _interpolate(self, output=None, adjoint=False, **kwargs): + """Assemble the action.""" + tensor = output + rank = len(self.arguments) + if rank == 1: + # Assemble the action + if tensor is None: + V_dest = self.arguments[0].function_space().dual() + tensor = firedrake.Function(V_dest) + for k, fsub in enumerate(tensor.subfunctions): + fsub.assign(sum(Isub.assemble(**kwargs) for (i, j), Isub in self.sub_interpolators.items() if i == k)) + return tensor + elif rank == 0: + # Assemble the double action + result = sum(Isub.assemble(**kwargs) for (i, j), Isub in self.sub_interpolators.items()) + return tensor.assign(result) if tensor else result diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index 8a310e4508..47b3dc7a6d 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -519,3 +519,50 @@ def test_interpolate_logical_not(): a = assemble(interpolate(conditional(Not(x < .2), 1, 0), V)) b = assemble(interpolate(conditional(x >= .2, 1, 0), V)) assert np.allclose(a.dat.data, b.dat.data) + + +@pytest.mark.parametrize("mode", ("forward", "adjoint")) +def test_mixed_matrix(mode): + nx = 3 + mesh = UnitSquareMesh(nx, nx) + + V1 = VectorFunctionSpace(mesh, "CG", 2) + V2 = FunctionSpace(mesh, "CG", 1) + V3 = FunctionSpace(mesh, "CG", 1) + V4 = FunctionSpace(mesh, "DG", 1) + + Z = V1 * V2 + W = V3 * V3 * V4 + + if mode == "forward": + I = Interpolate(TrialFunction(Z), TestFunction(W.dual())) + a = assemble(I) + assert a.arguments()[0].function_space() == W.dual() + assert a.arguments()[1].function_space() == Z + assert a.petscmat.getSize() == (W.dim(), Z.dim()) + assert a.petscmat.getType() == "nest" + + u = Function(Z) + u.subfunctions[0].sub(0).assign(1) + u.subfunctions[0].sub(1).assign(2) + u.subfunctions[1].assign(3) + result_matfree = assemble(Interpolate(u, TestFunction(W.dual()))) + elif mode == "adjoint": + I = Interpolate(TestFunction(Z), TrialFunction(W.dual())) + a = assemble(I) + assert a.arguments()[1].function_space() == W.dual() + assert a.arguments()[0].function_space() == Z + assert a.petscmat.getSize() == (Z.dim(), W.dim()) + assert a.petscmat.getType() == "nest" + + u = Function(W.dual()) + u.subfunctions[0].assign(1) + u.subfunctions[1].assign(2) + u.subfunctions[2].assign(3) + result_matfree = assemble(Interpolate(TestFunction(Z), u)) + else: + raise ValueError(f"Unrecognized mode {mode}") + + result_explicit = assemble(action(a, u)) + for x, y in zip(result_explicit.subfunctions, result_matfree.subfunctions): + assert np.allclose(x.dat.data, y.dat.data) From 76ec36770f649f7418e4eb42315f307f88ed356e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Oct 2025 07:26:07 +0100 Subject: [PATCH 02/20] Fixup --- firedrake/formmanipulation.py | 6 ++++++ firedrake/interpolation.py | 8 ++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index c24bda8cc1..cae0d63bec 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -171,6 +171,12 @@ def matrix(self, o): bcs = () return AssembledMatrix(tuple(args), bcs, submat) + def interpolate(self, o, operand): + if isinstance(operand, Zero): + return ZeroBaseForm(o.arguments()) + + return o._ufl_expr_reconstruct_(operand) + SplitForm = collections.namedtuple("SplitForm", ["indices", "form"]) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 0c568cd797..dbc8a5c40b 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -1767,8 +1767,8 @@ def __init__(self, expr, V, bcs=None, **kwargs): for j, sub_interp in firedrake.formmanipulation.split_form(form): j = max(j) if j else 0 # Ensure block sparsity - vi, operand = sub_interp.argument_slots() - if not isinstance(operand, ufl.classes.Zero): + if not isinstance(sub_interp, ufl.ZeroBaseForm): + vi, operand = sub_interp.argument_slots() Vtarget = vi.function_space().dual() adjoint = vi.number() == 1 if isinstance(vi, Coargument) else True @@ -1780,9 +1780,9 @@ def __init__(self, expr, V, bcs=None, **kwargs): Isub = Interpolator(sub_interp, Vtarget, bcs=sub_bcs, **kwargs) self.sub_interpolators[indices] = Isub - self.callable = self._callable + self.callable = self._get_callable - def _callable(self): + def _get_callable(self): """Assemble the operator.""" shape = tuple(len(a.function_space()) for a in self.arguments) Isubs = self.sub_interpolators From f1080ea3f9d3cb0477b7b48f1367a65a4ba0de28 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Oct 2025 09:11:07 +0100 Subject: [PATCH 03/20] Interpolate: support fieldsplit --- firedrake/formmanipulation.py | 30 +++++++++++- firedrake/interpolation.py | 91 +++++++++++++---------------------- 2 files changed, 61 insertions(+), 60 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index cae0d63bec..7fdadffe3c 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -2,7 +2,7 @@ import numpy import collections -from ufl import as_vector, split +from ufl import as_tensor, as_vector, split from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm from ufl.algorithms.map_integrands import map_integrand_dags from ufl.algorithms import expand_derivatives @@ -14,6 +14,7 @@ from firedrake.petsc import PETSc from firedrake.functionspace import MixedFunctionSpace from firedrake.cofunction import Cofunction +from firedrake.ufl_expr import Coargument from firedrake.matrix import AssembledMatrix @@ -175,7 +176,32 @@ def interpolate(self, o, operand): if isinstance(operand, Zero): return ZeroBaseForm(o.arguments()) - return o._ufl_expr_reconstruct_(operand) + dual_arg, _ = o.argument_slots() + V = dual_arg.function_space() + if len(V) == 1: + return o._ufl_expr_reconstruct_(operand, dual_arg) + + # Split the target (dual) argument + if isinstance(dual_arg, Coargument): + indices = self.blocks[dual_arg.number()] + W = subspace(dual_arg.function_space(), indices) + dual_arg = Coargument(W, dual_arg.number()) + else: + raise NotImplementedError() + + # Unflatten the expression into the target shapes + cur = 0 + operands = [] + components = numpy.reshape(operand, (-1,)) + for i, Vi in enumerate(V): + if i in indices: + operands.extend(components[cur:cur+Vi.value_size]) + cur += Vi.value_size + + operand = as_tensor(numpy.reshape(operands, W.value_shape)) + if isinstance(operand, Zero): + return ZeroBaseForm(o.arguments()) + return o._ufl_expr_reconstruct_(operand, dual_arg) SplitForm = collections.namedtuple("SplitForm", ["indices", "form"]) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index dbc8a5c40b..7be230360a 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -985,14 +985,18 @@ def callable(): return callable else: loops = [] - expressions = split_interpolate_target(expr) if access == op2.INC: loops.append(tensor.zero) # Interpolate each sub expression into each function space - for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions): - loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) + tensors = list(tensor) + if len(tensors) == 1: + split = [((0,), expr)] + else: + split = firedrake.formmanipulation.split_form(expr) + for (i,), sub_expr in split: + loops.extend(_interpolator(V[i], tensors[i], sub_expr, subset, arguments, access, bcs=bcs)) if bcs and rank == 1: loops.extend(partial(bc.apply, f) for bc in bcs) @@ -1707,36 +1711,6 @@ def duplicate(self, mat=None, op=None): return self._wrap_dummy_mat() -def split_interpolate_target(expr: ufl.Interpolate): - """Split an Interpolate into the components (subfunctions) of the target space.""" - dual_arg, operand = expr.argument_slots() - V = dual_arg.function_space().dual() - if len(V) == 1: - return (expr,) - # Split the target (dual) argument - if isinstance(dual_arg, Cofunction): - duals = dual_arg.subfunctions - elif isinstance(dual_arg, ufl.Coargument): - duals = [Coargument(Vsub, dual_arg.number()) for Vsub in dual_arg.function_space()] - else: - duals = [vi for _, vi in sorted(firedrake.formmanipulation.split_form(dual_arg))] - # Split the operand into the target shapes - if (isinstance(operand, firedrake.Function) and len(operand.subfunctions) == len(V) - and all(fsub.ufl_shape == Vsub.value_shape for Vsub, fsub in zip(V, operand.subfunctions))): - # Use subfunctions if they match the target shapes - operands = operand.subfunctions - else: - # Unflatten the expression into the target shapes - cur = 0 - operands = [] - components = numpy.reshape(operand, (-1,)) - for Vi in V: - operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape))) - cur += Vi.value_size - expressions = tuple(map(expr._ufl_expr_reconstruct_, operands, duals)) - return expressions - - class MixedInterpolator(Interpolator): """A reusable interpolation object between MixedFunctionSpaces. @@ -1754,39 +1728,40 @@ class MixedInterpolator(Interpolator): For details see :class:`firedrake.interpolation.Interpolator`. """ def __init__(self, expr, V, bcs=None, **kwargs): + bcs = bcs or () super(MixedInterpolator, self).__init__(expr, V, bcs=bcs, **kwargs) + expr = self.ufl_interpolate - bcs = bcs or () self.arguments = expr.arguments() - - # Split the target (dual) argument - dual_split = split_interpolate_target(expr) - self.sub_interpolators = {} - for i, form in enumerate(dual_split): - # Split the source (primal) argument - for j, sub_interp in firedrake.formmanipulation.split_form(form): - j = max(j) if j else 0 - # Ensure block sparsity - if not isinstance(sub_interp, ufl.ZeroBaseForm): - vi, operand = sub_interp.argument_slots() - Vtarget = vi.function_space().dual() - adjoint = vi.number() == 1 if isinstance(vi, Coargument) else True - - args = sub_interp.arguments() - Vsource = args[0 if adjoint else 1].function_space() - sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}] - - indices = (j, i) if adjoint else (i, j) - Isub = Interpolator(sub_interp, Vtarget, bcs=sub_bcs, **kwargs) - self.sub_interpolators[indices] = Isub - + rank = len(self.arguments) + if rank < 2: + dual_arg, operand = expr.argument_slots() + # Split the dual argument + dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) + # Create the Jacobian to split into blocks + expr = expr._ufl_expr_reconstruct_(operand, firedrake.TrialFunction(dual_arg.function_space())) + + Isub = {} + for indices, form in firedrake.formmanipulation.split_form(expr): + if not isinstance(form, ufl.ZeroBaseForm): + args = form.arguments() + vi, operand = form.argument_slots() + Vtarget = vi.function_space().dual() + Vsource = args[1-vi.number()].function_space() + sub_bcs = [bc for bc in self.bcs if bc.function_space() in {Vsource, Vtarget}] + if rank == 1: + # Take the action of each sub-cofunction against each block + form = action(form, dual_split[indices[1:]]) + Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs) + + self.sub_interpolators = Isub self.callable = self._get_callable def _get_callable(self): """Assemble the operator.""" + Isub = self.sub_interpolators shape = tuple(len(a.function_space()) for a in self.arguments) - Isubs = self.sub_interpolators - blocks = numpy.reshape([Isubs[ij].callable().handle if ij in Isubs else PETSc.Mat() + blocks = numpy.reshape([Isub[ij].callable().handle if ij in Isub else PETSc.Mat() for ij in numpy.ndindex(shape)], shape) petscmat = PETSc.Mat().createNest(blocks) tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat) From 997e63886f43be79bb55f6b81582581afee58e85 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Oct 2025 12:57:59 +0100 Subject: [PATCH 04/20] cleanup --- firedrake/formmanipulation.py | 26 +++++++--- firedrake/interpolation.py | 92 +++++++++++++++++++++-------------- 2 files changed, 74 insertions(+), 44 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 7fdadffe3c..e1f9ddb885 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -134,6 +134,17 @@ def argument(self, o): args.extend(Zero() for j in numpy.ndindex(V[i].value_shape)) return self._arg_cache.setdefault(o, as_vector(args)) + def coargument(self, o): + V = o.function_space() + + if len(V) == 1: + # Not on a mixed space, just return ourselves. + return o + + indices = self.blocks[o.number()] + W = subspace(V, indices) + return Coargument(W, number=o.number(), part=o.part()) + def cofunction(self, o): V = o.function_space() @@ -178,29 +189,30 @@ def interpolate(self, o, operand): dual_arg, _ = o.argument_slots() V = dual_arg.function_space() - if len(V) == 1: + if len(V) == 1 or len(dual_arg.arguments()) == 1: return o._ufl_expr_reconstruct_(operand, dual_arg) # Split the target (dual) argument if isinstance(dual_arg, Coargument): + dual_arg = self(dual_arg) indices = self.blocks[dual_arg.number()] - W = subspace(dual_arg.function_space(), indices) - dual_arg = Coargument(W, dual_arg.number()) else: raise NotImplementedError() # Unflatten the expression into the target shapes cur = 0 - operands = [] - components = numpy.reshape(operand, (-1,)) + cindices = [] for i, Vi in enumerate(V): if i in indices: - operands.extend(components[cur:cur+Vi.value_size]) + cindices.extend(range(cur, cur+Vi.value_size)) cur += Vi.value_size - operand = as_tensor(numpy.reshape(operands, W.value_shape)) + W = dual_arg.function_space() + components = [operand[i] for i in cindices] + operand = as_tensor(numpy.reshape(components, W.value_shape)) if isinstance(operand, Zero): return ZeroBaseForm(o.arguments()) + return o._ufl_expr_reconstruct_(operand, dual_arg) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 7be230360a..65b6d70065 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -263,7 +263,7 @@ class Interpolator(abc.ABC): def __new__(cls, expr, V, **kwargs): if isinstance(expr, ufl.Interpolate): - # Mixed spaces are handled well only by the primal 1-form. + # MixedFunctionSpace is only implemented for the primal 1-form. # Are we a 2-form or a dual 1-form? arguments = expr.arguments() if any(not isinstance(a, Coargument) for a in arguments): @@ -989,14 +989,21 @@ def callable(): if access == op2.INC: loops.append(tensor.zero) - # Interpolate each sub expression into each function space - tensors = list(tensor) - if len(tensors) == 1: - split = [((0,), expr)] + if rank == 0 and len(V) > 1: + dual_arg, operand = expr.argument_slots() + interp = expr._ufl_expr_reconstruct_(operand, V) + interp_split = dict(firedrake.formmanipulation.split_form(interp)) + dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) + expressions = {i: action(interp_split[i], dual_split[i]) for i in dual_split} + elif len(V) > 1: + expressions = dict(firedrake.formmanipulation.split_form(expr)) else: - split = firedrake.formmanipulation.split_form(expr) - for (i,), sub_expr in split: - loops.extend(_interpolator(V[i], tensors[i], sub_expr, subset, arguments, access, bcs=bcs)) + expressions = {(0,): expr} + + # Interpolate each sub expression into each function space + for (i,), sub_expr in expressions.items(): + sub_tensor = tensor[i] if (rank == 1 and len(V) > 1) else tensor + loops.extend(_interpolator(V[i], sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) if bcs and rank == 1: loops.extend(partial(bc.apply, f) for bc in bcs) @@ -1728,9 +1735,7 @@ class MixedInterpolator(Interpolator): For details see :class:`firedrake.interpolation.Interpolator`. """ def __init__(self, expr, V, bcs=None, **kwargs): - bcs = bcs or () super(MixedInterpolator, self).__init__(expr, V, bcs=bcs, **kwargs) - expr = self.ufl_interpolate self.arguments = expr.arguments() rank = len(self.arguments) @@ -1738,48 +1743,61 @@ def __init__(self, expr, V, bcs=None, **kwargs): dual_arg, operand = expr.argument_slots() # Split the dual argument dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) - # Create the Jacobian to split into blocks - expr = expr._ufl_expr_reconstruct_(operand, firedrake.TrialFunction(dual_arg.function_space())) + # Create the Jacobian be split into blocks + expr = expr._ufl_expr_reconstruct_(operand, V) Isub = {} for indices, form in firedrake.formmanipulation.split_form(expr): - if not isinstance(form, ufl.ZeroBaseForm): + if isinstance(form, ufl.ZeroBaseForm): + # Ensure block sparsity + continue + vi, _ = form.argument_slots() + Vtarget = vi.function_space().dual() + if bcs and rank != 0: args = form.arguments() - vi, operand = form.argument_slots() - Vtarget = vi.function_space().dual() Vsource = args[1-vi.number()].function_space() - sub_bcs = [bc for bc in self.bcs if bc.function_space() in {Vsource, Vtarget}] - if rank == 1: - # Take the action of each sub-cofunction against each block - form = action(form, dual_split[indices[1:]]) - Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs) + sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}] + else: + sub_bcs = None + if rank < 2: + # Take the action of each sub-cofunction against each block + form = action(form, dual_split[indices[1:]]) - self.sub_interpolators = Isub + Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs) + + self._sub_interpolators = Isub self.callable = self._get_callable + def __getitem__(self, item): + return self._sub_interpolators[item] + + def __iter__(self): + return iter(self._sub_interpolators) + def _get_callable(self): """Assemble the operator.""" - Isub = self.sub_interpolators shape = tuple(len(a.function_space()) for a in self.arguments) - blocks = numpy.reshape([Isub[ij].callable().handle if ij in Isub else PETSc.Mat() - for ij in numpy.ndindex(shape)], shape) + blocks = numpy.full(shape, PETSc.Mat(), dtype=object) + for i in self: + blocks[i] = self[i].callable().handle petscmat = PETSc.Mat().createNest(blocks) tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat) return tensor.M - def _interpolate(self, output=None, adjoint=False, **kwargs): + def _interpolate(self, *function, output=None, adjoint=False, **kwargs): """Assemble the action.""" - tensor = output rank = len(self.arguments) + if rank == 0: + result = sum(self[i].assemble(**kwargs) for i in self) + return output.assign(result) if output else result + + if output is None: + output = firedrake.Function(self.arguments[-1].function_space().dual()) + if rank == 1: - # Assemble the action - if tensor is None: - V_dest = self.arguments[0].function_space().dual() - tensor = firedrake.Function(V_dest) - for k, fsub in enumerate(tensor.subfunctions): - fsub.assign(sum(Isub.assemble(**kwargs) for (i, j), Isub in self.sub_interpolators.items() if i == k)) - return tensor - elif rank == 0: - # Assemble the double action - result = sum(Isub.assemble(**kwargs) for (i, j), Isub in self.sub_interpolators.items()) - return tensor.assign(result) if tensor else result + for k, sub_tensor in enumerate(output.subfunctions): + sub_tensor.assign(sum(self[i, j].assemble(**kwargs) for (i, j) in self if i == k)) + elif rank == 2: + for k, sub_tensor in enumerate(output.subfunctions): + sub_tensor.assign(sum(self[i, j]._interpolate(*function, adjoint=adjoint, **kwargs) for (i, j) in self if i == k)) + return output From e400fc99cc263c5695d0c60ac62818a5e03555ba Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Oct 2025 15:16:39 +0100 Subject: [PATCH 05/20] cleanup --- firedrake/interpolation.py | 63 ++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 65b6d70065..b88215c2e4 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -262,21 +262,17 @@ class Interpolator(abc.ABC): """ def __new__(cls, expr, V, **kwargs): - if isinstance(expr, ufl.Interpolate): - # MixedFunctionSpace is only implemented for the primal 1-form. - # Are we a 2-form or a dual 1-form? - arguments = expr.arguments() - if any(not isinstance(a, Coargument) for a in arguments): - # Do we have mixed source or target spaces? - spaces = [a.function_space() for a in arguments] - if len(spaces) < 2: - spaces.append(V) - if any(len(space) > 1 for space in spaces): - return object.__new__(MixedInterpolator) - expr, = expr.ufl_operands + if not isinstance(expr, ufl.Interpolate): + expr = interpolate(expr, V if isinstance(V, ufl.FunctionSpace) else V.function_space()) + + spaces = [a.function_space() for a in expr.arguments()] + has_mixed_spaces = any(len(space) > 1 for space in spaces) + if len(spaces) == 2 and has_mixed_spaces: + return object.__new__(MixedInterpolator) + operand, = expr.ufl_operands target_mesh = as_domain(V) - source_mesh = extract_unique_domain(expr) or target_mesh + source_mesh = extract_unique_domain(operand) or target_mesh submesh_interp_implemented = \ all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) and \ target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] and \ @@ -284,8 +280,11 @@ def __new__(cls, expr, V, **kwargs): if target_mesh is source_mesh or submesh_interp_implemented: return object.__new__(SameMeshInterpolator) else: + needs_adjoint = not isinstance(expr.arguments()[0], Coargument) if isinstance(target_mesh.topology, VertexOnlyMeshTopology): return object.__new__(SameMeshInterpolator) + elif has_mixed_spaces and needs_adjoint: + return object.__new__(MixedInterpolator) else: return object.__new__(CrossMeshInterpolator) @@ -301,8 +300,7 @@ def __init__( matfree: bool = True ): if not isinstance(expr, ufl.Interpolate): - fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - expr = interpolate(expr, fs) + expr = interpolate(expr, V if isinstance(V, ufl.FunctionSpace) else V.function_space()) dual_arg, operand = expr.argument_slots() self.ufl_interpolate = expr self.expr = operand @@ -414,8 +412,7 @@ def assemble(self, tensor=None, **kwargs): Iu = self._interpolate(**kwargs) return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) else: - return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, - **kwargs) + return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, **kwargs) class DofNotDefinedError(Exception): @@ -989,21 +986,33 @@ def callable(): if access == op2.INC: loops.append(tensor.zero) - if rank == 0 and len(V) > 1: - dual_arg, operand = expr.argument_slots() + dual_arg, operand = expr.argument_slots() + # Any arguments in the operand may be from a MixedFunctoinSpace + # We need to split the target space V and generate separate kernels + if len(V) == 1: + expressions = {(0,): expr} + elif isinstance(dual_arg, Coargument): + # Split in the coargument + expressions = dict(firedrake.formmanipulation.split_form(expr)) + else: + # Split in the cofunction: split_form can only split in the coargument + # Replace the cofunction with a coargument to construct the Jacobian interp = expr._ufl_expr_reconstruct_(operand, V) + # Split the Jacobian into blocks interp_split = dict(firedrake.formmanipulation.split_form(interp)) + # Split the cofunction dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) - expressions = {i: action(interp_split[i], dual_split[i]) for i in dual_split} - elif len(V) > 1: - expressions = dict(firedrake.formmanipulation.split_form(expr)) - else: - expressions = {(0,): expr} + # Combine the splits by taking their action + expressions = {i: action(interp_split[i], dual_split[i[-1:]]) for i in interp_split} # Interpolate each sub expression into each function space - for (i,), sub_expr in expressions.items(): - sub_tensor = tensor[i] if (rank == 1 and len(V) > 1) else tensor - loops.extend(_interpolator(V[i], sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) + for indices, sub_expr in expressions.items(): + if isinstance(sub_expr, ufl.ZeroBaseForm): + continue + arguments = sub_expr.arguments() + sub_space = sub_expr.argument_slots()[0].function_space().dual() + sub_tensor = tensor[indices[0]] if rank == 1 else tensor + loops.extend(_interpolator(sub_space, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) if bcs and rank == 1: loops.extend(partial(bc.apply, f) for bc in bcs) From fee366abc0ef081c22d1d819fdd73d9c29407f59 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Oct 2025 16:56:56 +0100 Subject: [PATCH 06/20] Implement missing functionality in CrossMeshInterpolator --- firedrake/interpolation.py | 92 +++++-------------- .../regression/test_interpolate_cross_mesh.py | 11 ++- 2 files changed, 28 insertions(+), 75 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index b88215c2e4..2e31d9919b 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -262,8 +262,9 @@ class Interpolator(abc.ABC): """ def __new__(cls, expr, V, **kwargs): + V_target = V if isinstance(V, ufl.FunctionSpace) else V.function_space() if not isinstance(expr, ufl.Interpolate): - expr = interpolate(expr, V if isinstance(V, ufl.FunctionSpace) else V.function_space()) + expr = interpolate(expr, V_target) spaces = [a.function_space() for a in expr.arguments()] has_mixed_spaces = any(len(space) > 1 for space in spaces) @@ -280,10 +281,9 @@ def __new__(cls, expr, V, **kwargs): if target_mesh is source_mesh or submesh_interp_implemented: return object.__new__(SameMeshInterpolator) else: - needs_adjoint = not isinstance(expr.arguments()[0], Coargument) if isinstance(target_mesh.topology, VertexOnlyMeshTopology): return object.__new__(SameMeshInterpolator) - elif has_mixed_spaces and needs_adjoint: + elif has_mixed_spaces or len(V_target) > 1: return object.__new__(MixedInterpolator) else: return object.__new__(CrossMeshInterpolator) @@ -506,8 +506,6 @@ def __init__( self.src_mesh = src_mesh self.dest_mesh = dest_mesh - self.sub_interpolators = [] - # Create a VOM at the nodes of V_dest in src_mesh. We don't include halo # node coordinates because interpolation doesn't usually include halos. # NOTE: it is very important to set redundant=False, otherwise the @@ -515,53 +513,17 @@ def __init__( # QUESTION: Should any of the below have annotation turned off? ufl_scalar_element = V_dest.ufl_element() if isinstance(ufl_scalar_element, finat.ufl.MixedElement): - if all( - ufl_scalar_element.sub_elements[0] == e - for e in ufl_scalar_element.sub_elements - ): - # For a VectorElement or TensorElement the correct - # VectorFunctionSpace equivalent is built from the scalar - # sub-element. - ufl_scalar_element = ufl_scalar_element.sub_elements[0] - if ufl_scalar_element.reference_value_shape != (): - raise NotImplementedError( - "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." - ) - else: - # Build and save an interpolator for each sub-element - # separately for MixedFunctionSpaces. NOTE: since we can't have - # expressions for MixedFunctionSpaces we know that the input - # argument ``expr`` must be a Function. V_dest can be a Function - # or a FunctionSpace, and subfunctions works for both. - if self.nargs == 1: - # Arguments don't have a subfunctions property so I have to - # make them myself. NOTE: this will not be correct when we - # start allowing interpolators created from an expression - # with arguments, as opposed to just being the argument. - expr_subfunctions = [ - firedrake.TestFunction(V_src_sub_func) - for V_src_sub_func in self.expr.function_space().subspaces - ] - elif self.nargs > 1: - raise NotImplementedError( - "Can't yet create an interpolator from an expression with multiple arguments." - ) - else: - expr_subfunctions = self.expr.subfunctions - if len(expr_subfunctions) != len(V_dest.subspaces): - raise NotImplementedError( - "Can't interpolate from a non-mixed function space into a mixed function space." - ) - for input_sub_func, target_subspace in zip( - expr_subfunctions, V_dest.subspaces - ): - self.sub_interpolators.append( - interpolate( - input_sub_func, target_subspace, subset=subset, - access=access, allow_missing_dofs=allow_missing_dofs - ) - ) - return + if type(ufl_scalar_element) == finat.ufl.MixedElement: + raise NotImplementedError("Need a MixedInterpolator") + + # For a VectorElement or TensorElement the correct + # VectorFunctionSpace equivalent is built from the scalar + # sub-element. + ufl_scalar_element = ufl_scalar_element.sub_elements[0] + if ufl_scalar_element.reference_value_shape != (): + raise NotImplementedError( + "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." + ) from firedrake.assemble import assemble V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element) @@ -672,21 +634,6 @@ def _interpolate( else: output = firedrake.Function(V_dest) - if len(self.sub_interpolators): - # MixedFunctionSpace case - for sub_interpolate, f_src_sub_func, output_sub_func in zip( - self.sub_interpolators, f_src.subfunctions, output.subfunctions - ): - if f_src is self.expr: - # f_src is already contained in self.point_eval_interpolate, - # so the sub_interpolators are already prepared to interpolate - # without needing to be given a Function - assert not self.nargs - assemble(sub_interpolate, tensor=output_sub_func) - else: - assemble(action(sub_interpolate, f_src_sub_func), tensor=output_sub_func) - return output - if not adjoint: if f_src is self.expr: # f_src is already contained in self.point_eval_interpolate @@ -1748,7 +1695,9 @@ def __init__(self, expr, V, bcs=None, **kwargs): expr = self.ufl_interpolate self.arguments = expr.arguments() rank = len(self.arguments) - if rank < 2: + + needs_action = len([a for a in self.arguments if isinstance(a, Coargument)]) == 0 + if needs_action: dual_arg, operand = expr.argument_slots() # Split the dual argument dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) @@ -1768,7 +1717,7 @@ def __init__(self, expr, V, bcs=None, **kwargs): sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}] else: sub_bcs = None - if rank < 2: + if needs_action: # Take the action of each sub-cofunction against each block form = action(form, dual_split[indices[1:]]) @@ -1805,8 +1754,9 @@ def _interpolate(self, *function, output=None, adjoint=False, **kwargs): if rank == 1: for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i, j].assemble(**kwargs) for (i, j) in self if i == k)) + sub_tensor.assign(sum(self[i].assemble(**kwargs) for i in self if i[0] == k)) elif rank == 2: for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i, j]._interpolate(*function, adjoint=adjoint, **kwargs) for (i, j) in self if i == k)) + sub_tensor.assign(sum(self[i]._interpolate(*function, adjoint=adjoint, **kwargs) + for i in self if i[0] == k)) return output diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index 52c1746d74..5bb19fd1c2 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -299,12 +299,15 @@ def test_interpolate_unitsquare_mixed(): assert not np.allclose(f_src.dat.data_ro[0], cofunc_src.dat.data_ro[0]) assert not np.allclose(f_src.dat.data_ro[1], cofunc_src.dat.data_ro[1]) - # Can't go from non-mixed to mixed + # Interpolate from non-mixed to mixed V_src_2 = VectorFunctionSpace(m_src, "CG", 1) assert V_src_2.value_shape == V_src.value_shape - f_src_2 = Function(V_src_2) - with pytest.raises(NotImplementedError): - assemble(interpolate(f_src_2, V_dest)) + f_src_2 = Function(V_src_2).interpolate(SpatialCoordinate(m_src)) + result_mixed = assemble(interpolate(f_src_2, V_dest)) + + for i in range(len(V_dest)): + expected = assemble(interpolate(f_src_2[i], V_dest[i])) + assert np.allclose(result_mixed.dat.data_ro[i], expected.dat.data_ro) @pytest.mark.parallel([1, 3]) From 3aeb517cb053decfd1749923d62ba75a5401745e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Oct 2025 17:17:44 +0100 Subject: [PATCH 07/20] Test the MixedInterpolator 0-form across different meshes --- firedrake/interpolation.py | 6 +++--- tests/firedrake/regression/test_interpolate_cross_mesh.py | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 2e31d9919b..de830844cd 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -934,7 +934,7 @@ def callable(): loops.append(tensor.zero) dual_arg, operand = expr.argument_slots() - # Any arguments in the operand may be from a MixedFunctoinSpace + # Arguments in the operand are allowed to be from a MixedFunctionSpace # We need to split the target space V and generate separate kernels if len(V) == 1: expressions = {(0,): expr} @@ -1701,7 +1701,7 @@ def __init__(self, expr, V, bcs=None, **kwargs): dual_arg, operand = expr.argument_slots() # Split the dual argument dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) - # Create the Jacobian be split into blocks + # Create the Jacobian to be split into blocks expr = expr._ufl_expr_reconstruct_(operand, V) Isub = {} @@ -1719,7 +1719,7 @@ def __init__(self, expr, V, bcs=None, **kwargs): sub_bcs = None if needs_action: # Take the action of each sub-cofunction against each block - form = action(form, dual_split[indices[1:]]) + form = action(form, dual_split[indices[-1:]]) Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs) diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index 5bb19fd1c2..f6c22e48e8 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -305,10 +305,16 @@ def test_interpolate_unitsquare_mixed(): f_src_2 = Function(V_src_2).interpolate(SpatialCoordinate(m_src)) result_mixed = assemble(interpolate(f_src_2, V_dest)) + expected_zero_form = 0 for i in range(len(V_dest)): expected = assemble(interpolate(f_src_2[i], V_dest[i])) assert np.allclose(result_mixed.dat.data_ro[i], expected.dat.data_ro) + expected_zero_form += assemble(action(cofunc_dest.subfunctions[i], expected)) + + mixed_zero_form = assemble(interpolate(f_src_2, cofunc_dest)) + assert np.isclose(mixed_zero_form, expected_zero_form) + @pytest.mark.parallel([1, 3]) def test_exact_refinement(): From aa87c36940888ea35847f0d5bc87763c3dbb57a5 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Oct 2025 17:56:49 +0100 Subject: [PATCH 08/20] Fixup --- firedrake/interpolation.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index de830844cd..be2a5d3cb5 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -266,9 +266,9 @@ def __new__(cls, expr, V, **kwargs): if not isinstance(expr, ufl.Interpolate): expr = interpolate(expr, V_target) - spaces = [a.function_space() for a in expr.arguments()] - has_mixed_spaces = any(len(space) > 1 for space in spaces) - if len(spaces) == 2 and has_mixed_spaces: + arguments = expr.arguments() + has_mixed_arguments = any(len(a.function_space()) > 1 for a in arguments) + if len(arguments) == 2 and has_mixed_arguments: return object.__new__(MixedInterpolator) operand, = expr.ufl_operands @@ -283,7 +283,7 @@ def __new__(cls, expr, V, **kwargs): else: if isinstance(target_mesh.topology, VertexOnlyMeshTopology): return object.__new__(SameMeshInterpolator) - elif has_mixed_spaces or len(V_target) > 1: + elif has_mixed_arguments or len(V_target) > 1: return object.__new__(MixedInterpolator) else: return object.__new__(CrossMeshInterpolator) @@ -514,12 +514,12 @@ def __init__( ufl_scalar_element = V_dest.ufl_element() if isinstance(ufl_scalar_element, finat.ufl.MixedElement): if type(ufl_scalar_element) == finat.ufl.MixedElement: - raise NotImplementedError("Need a MixedInterpolator") + raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") # For a VectorElement or TensorElement the correct # VectorFunctionSpace equivalent is built from the scalar # sub-element. - ufl_scalar_element = ufl_scalar_element.sub_elements[0] + ufl_scalar_element, = set(ufl_scalar_element.sub_elements) if ufl_scalar_element.reference_value_shape != (): raise NotImplementedError( "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." @@ -865,10 +865,10 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): Vrow = arguments[0].function_space() Vcol = arguments[1].function_space() if len(Vrow) > 1 or len(Vcol) > 1: - raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") + raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): - raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") + raise NotImplementedError("Can only interpolate onto a VertexOnlyMesh") if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: @@ -936,7 +936,7 @@ def callable(): dual_arg, operand = expr.argument_slots() # Arguments in the operand are allowed to be from a MixedFunctionSpace # We need to split the target space V and generate separate kernels - if len(V) == 1: + if len(arguments) == 2: expressions = {(0,): expr} elif isinstance(dual_arg, Coargument): # Split in the coargument @@ -1696,6 +1696,7 @@ def __init__(self, expr, V, bcs=None, **kwargs): self.arguments = expr.arguments() rank = len(self.arguments) + # We need a Coargument in order to split the Interpolate needs_action = len([a for a in self.arguments if isinstance(a, Coargument)]) == 0 if needs_action: dual_arg, operand = expr.argument_slots() @@ -1705,6 +1706,7 @@ def __init__(self, expr, V, bcs=None, **kwargs): expr = expr._ufl_expr_reconstruct_(operand, V) Isub = {} + # Split in the arguments of the Interpolate for indices, form in firedrake.formmanipulation.split_form(expr): if isinstance(form, ufl.ZeroBaseForm): # Ensure block sparsity From fa418e0efeee762a6fe70f2c05cf46c051a1bfe9 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 1 Oct 2025 18:29:54 +0100 Subject: [PATCH 09/20] Update firedrake/interpolation.py --- firedrake/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index be2a5d3cb5..654cdbd04f 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -513,7 +513,7 @@ def __init__( # QUESTION: Should any of the below have annotation turned off? ufl_scalar_element = V_dest.ufl_element() if isinstance(ufl_scalar_element, finat.ufl.MixedElement): - if type(ufl_scalar_element) == finat.ufl.MixedElement: + if type(ufl_scalar_element) is finat.ufl.MixedElement: raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") # For a VectorElement or TensorElement the correct From 774872e42056073fa49dfb887ff61c01a8b5ac5f Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 11:16:22 +0100 Subject: [PATCH 10/20] Apply suggestions from code review --- firedrake/interpolation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 654cdbd04f..2f75ae2213 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -937,6 +937,7 @@ def callable(): # Arguments in the operand are allowed to be from a MixedFunctionSpace # We need to split the target space V and generate separate kernels if len(arguments) == 2: + # Matrix case assumes that the spaces are not mixed expressions = {(0,): expr} elif isinstance(dual_arg, Coargument): # Split in the coargument @@ -954,11 +955,13 @@ def callable(): # Interpolate each sub expression into each function space for indices, sub_expr in expressions.items(): + sub_tensor = tensor[indices[0]] if rank == 1 else tensor if isinstance(sub_expr, ufl.ZeroBaseForm): + if access == op2.WRITE: + sub_tensor.zero() continue arguments = sub_expr.arguments() sub_space = sub_expr.argument_slots()[0].function_space().dual() - sub_tensor = tensor[indices[0]] if rank == 1 else tensor loops.extend(_interpolator(sub_space, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) if bcs and rank == 1: From fa679b3c07bc5a41ab64c04e919ad7df33dc2c8a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 11:20:18 +0100 Subject: [PATCH 11/20] Apply suggestions from code review --- firedrake/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 2f75ae2213..c6a3db9ef1 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -958,7 +958,7 @@ def callable(): sub_tensor = tensor[indices[0]] if rank == 1 else tensor if isinstance(sub_expr, ufl.ZeroBaseForm): if access == op2.WRITE: - sub_tensor.zero() + loops.append(sub_tensor.zero) continue arguments = sub_expr.arguments() sub_space = sub_expr.argument_slots()[0].function_space().dual() From f72726f4eb80a4ff986d078d419acccb64413c9d Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 11:58:49 +0100 Subject: [PATCH 12/20] tidy --- firedrake/formmanipulation.py | 7 +++++-- firedrake/interpolation.py | 22 +++++++++++++--------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index e1f9ddb885..4806a62ddd 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -183,9 +183,12 @@ def matrix(self, o): bcs = () return AssembledMatrix(tuple(args), bcs, submat) + def zero_base_form(self, o): + return ZeroBaseForm(tuple(map(self, o.arguments()))) + def interpolate(self, o, operand): if isinstance(operand, Zero): - return ZeroBaseForm(o.arguments()) + return self(ZeroBaseForm(o.arguments())) dual_arg, _ = o.argument_slots() V = dual_arg.function_space() @@ -211,7 +214,7 @@ def interpolate(self, o, operand): components = [operand[i] for i in cindices] operand = as_tensor(numpy.reshape(components, W.value_shape)) if isinstance(operand, Zero): - return ZeroBaseForm(o.arguments()) + return self(ZeroBaseForm(o.arguments())) return o._ufl_expr_reconstruct_(operand, dual_arg) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index c6a3db9ef1..9e65d5dbdd 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -933,7 +933,6 @@ def callable(): if access == op2.INC: loops.append(tensor.zero) - dual_arg, operand = expr.argument_slots() # Arguments in the operand are allowed to be from a MixedFunctionSpace # We need to split the target space V and generate separate kernels if len(arguments) == 2: @@ -956,13 +955,7 @@ def callable(): # Interpolate each sub expression into each function space for indices, sub_expr in expressions.items(): sub_tensor = tensor[indices[0]] if rank == 1 else tensor - if isinstance(sub_expr, ufl.ZeroBaseForm): - if access == op2.WRITE: - loops.append(sub_tensor.zero) - continue - arguments = sub_expr.arguments() - sub_space = sub_expr.argument_slots()[0].function_space().dual() - loops.extend(_interpolator(sub_space, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) + loops.extend(_interpolator(sub_tensor, sub_expr, subset, access, bcs=bcs)) if bcs and rank == 1: loops.extend(partial(bc.apply, f) for bc in bcs) @@ -976,10 +969,21 @@ def callable(loops, f): @utils.known_pyop2_safe -def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): +def _interpolator(tensor, expr, subset, access, bcs=None): + if isinstance(expr, ufl.ZeroBaseForm): + if access is op2.INC: + return () + elif access is op2.WRITE: + return (tensor.zero,) + V = expr.arguments()[-1].function_space().dual() + expr = interpolate(ufl.zero(V.value_shape), V) + if not isinstance(expr, ufl.Interpolate): raise ValueError("Expecting to interpolate a ufl.Interpolate") + + arguments = expr.arguments() dual_arg, operand = expr.argument_slots() + V = dual_arg.function_space().dual() try: to_element = create_element(V.ufl_element()) From 4ebbb3f92ba417af11b93fc4734b95f4e6209ecd Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 12:19:36 +0100 Subject: [PATCH 13/20] Zero subset --- firedrake/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 9e65d5dbdd..3ac332937a 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -974,7 +974,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None): if access is op2.INC: return () elif access is op2.WRITE: - return (tensor.zero,) + return (partial(tensor.zero, subset=subset),) V = expr.arguments()[-1].function_space().dual() expr = interpolate(ufl.zero(V.value_shape), V) From bad96e54d600eeb097692a542a6edc9e2741c87c Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 16:48:03 +0100 Subject: [PATCH 14/20] Apply suggestions from code review --- firedrake/formmanipulation.py | 5 +++-- firedrake/interpolation.py | 18 ++++++++++-------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 4806a62ddd..e23b8b801c 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -192,7 +192,8 @@ def interpolate(self, o, operand): dual_arg, _ = o.argument_slots() V = dual_arg.function_space() - if len(V) == 1 or len(dual_arg.arguments()) == 1: + if len(dual_arg.arguments()) == 1 or len(dual_arg.arguments()[-1].function_space()) == 1: + # The dual argument has been contracted or does not need to be split return o._ufl_expr_reconstruct_(operand, dual_arg) # Split the target (dual) argument @@ -200,7 +201,7 @@ def interpolate(self, o, operand): dual_arg = self(dual_arg) indices = self.blocks[dual_arg.number()] else: - raise NotImplementedError() + raise NotImplementedError(f"I do not know how to split an Interpolate with a {type(dual_arg).__name__}.") # Unflatten the expression into the target shapes cur = 0 diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 3ac332937a..9810901c10 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -379,7 +379,7 @@ def _interpolate(self, *args, **kwargs): """ pass - def assemble(self, tensor=None, **kwargs): + def assemble(self, tensor=None, default_missing_val=None): """Assemble the operator (or its action).""" from firedrake.assemble import assemble needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate @@ -409,10 +409,11 @@ def assemble(self, tensor=None, **kwargs): cofunctions = (dual_arg,) if needs_adjoint and len(arguments) == 0: - Iu = self._interpolate(**kwargs) + Iu = self._interpolate(default_missing_val=default_missing_val) return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) else: - return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, **kwargs) + return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, + default_missing_val=default_missing_val) class DofNotDefinedError(Exception): @@ -514,7 +515,7 @@ def __init__( ufl_scalar_element = V_dest.ufl_element() if isinstance(ufl_scalar_element, finat.ufl.MixedElement): if type(ufl_scalar_element) is finat.ufl.MixedElement: - raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") + raise TypeError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") # For a VectorElement or TensorElement the correct # VectorFunctionSpace equivalent is built from the scalar @@ -865,7 +866,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): Vrow = arguments[0].function_space() Vcol = arguments[1].function_space() if len(Vrow) > 1 or len(Vcol) > 1: - raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") + raise TypeError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): raise NotImplementedError("Can only interpolate onto a VertexOnlyMesh") @@ -983,7 +984,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None): arguments = expr.arguments() dual_arg, operand = expr.argument_slots() - V = dual_arg.function_space().dual() + V = dual_arg.arguments()[0].function_space() try: to_element = create_element(V.ufl_element()) @@ -1733,7 +1734,7 @@ def __init__(self, expr, V, bcs=None, **kwargs): Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs) self._sub_interpolators = Isub - self.callable = self._get_callable + self.callable = self._assemble_matnest def __getitem__(self, item): return self._sub_interpolators[item] @@ -1741,10 +1742,11 @@ def __getitem__(self, item): def __iter__(self): return iter(self._sub_interpolators) - def _get_callable(self): + def _assemble_matnest(self): """Assemble the operator.""" shape = tuple(len(a.function_space()) for a in self.arguments) blocks = numpy.full(shape, PETSc.Mat(), dtype=object) + # Assemble the sparse block matrix for i in self: blocks[i] = self[i].callable().handle petscmat = PETSc.Mat().createNest(blocks) From 7286229e2c220f1036edc8d09e6b9a1573320d4c Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 16:48:42 +0100 Subject: [PATCH 15/20] Update firedrake/interpolation.py --- firedrake/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 9810901c10..b68f186918 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -413,7 +413,7 @@ def assemble(self, tensor=None, default_missing_val=None): return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) else: return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, - default_missing_val=default_missing_val) + default_missing_val=default_missing_val) class DofNotDefinedError(Exception): From 345e8e92d8e7e5aa2eedf4c84d0bcffaa423ea7b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 16:55:13 +0100 Subject: [PATCH 16/20] lint --- firedrake/formmanipulation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index e23b8b801c..267b26c138 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -193,7 +193,7 @@ def interpolate(self, o, operand): dual_arg, _ = o.argument_slots() V = dual_arg.function_space() if len(dual_arg.arguments()) == 1 or len(dual_arg.arguments()[-1].function_space()) == 1: - # The dual argument has been contracted or does not need to be split + # The dual argument has been contracted or does not need to be split return o._ufl_expr_reconstruct_(operand, dual_arg) # Split the target (dual) argument @@ -203,7 +203,7 @@ def interpolate(self, o, operand): else: raise NotImplementedError(f"I do not know how to split an Interpolate with a {type(dual_arg).__name__}.") - # Unflatten the expression into the target shapes + # Unflatten the expression into the target shape cur = 0 cindices = [] for i, Vi in enumerate(V): From 782fcb72f4b594157dc9c65ad719c41057d3ad76 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 17:01:15 +0100 Subject: [PATCH 17/20] Fix up --- firedrake/formmanipulation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 267b26c138..c7f162c39d 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -191,7 +191,6 @@ def interpolate(self, o, operand): return self(ZeroBaseForm(o.arguments())) dual_arg, _ = o.argument_slots() - V = dual_arg.function_space() if len(dual_arg.arguments()) == 1 or len(dual_arg.arguments()[-1].function_space()) == 1: # The dual argument has been contracted or does not need to be split return o._ufl_expr_reconstruct_(operand, dual_arg) From 555a1eadcce1bda3920c2ebed6dcb660fdd281cf Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 17:11:06 +0100 Subject: [PATCH 18/20] cleanup --- firedrake/formmanipulation.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index c7f162c39d..eb830493e1 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -195,28 +195,29 @@ def interpolate(self, o, operand): # The dual argument has been contracted or does not need to be split return o._ufl_expr_reconstruct_(operand, dual_arg) - # Split the target (dual) argument - if isinstance(dual_arg, Coargument): - dual_arg = self(dual_arg) - indices = self.blocks[dual_arg.number()] - else: + if not isinstance(dual_arg, Coargument): raise NotImplementedError(f"I do not know how to split an Interpolate with a {type(dual_arg).__name__}.") + indices = self.blocks[dual_arg.number()] + V = dual_arg.function_space() + + # Split the target (dual) argument + sub_dual_arg = self(dual_arg) + W = sub_dual_arg.function_space() + # Unflatten the expression into the target shape cur = 0 - cindices = [] + components = [] for i, Vi in enumerate(V): if i in indices: - cindices.extend(range(cur, cur+Vi.value_size)) + components.extend(operand[i] for i in range(cur, cur+Vi.value_size)) cur += Vi.value_size - W = dual_arg.function_space() - components = [operand[i] for i in cindices] operand = as_tensor(numpy.reshape(components, W.value_shape)) if isinstance(operand, Zero): return self(ZeroBaseForm(o.arguments())) - return o._ufl_expr_reconstruct_(operand, dual_arg) + return o._ufl_expr_reconstruct_(operand, sub_dual_arg) SplitForm = collections.namedtuple("SplitForm", ["indices", "form"]) From 50fe34ccfdb381eb11aeda544a61fceb678aab92 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 18:30:42 +0100 Subject: [PATCH 19/20] Comments --- firedrake/interpolation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index b68f186918..4552cd6ee0 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -972,10 +972,13 @@ def callable(loops, f): @utils.known_pyop2_safe def _interpolator(tensor, expr, subset, access, bcs=None): if isinstance(expr, ufl.ZeroBaseForm): + # Zero simplification, avoid code-generation if access is op2.INC: return () elif access is op2.WRITE: return (partial(tensor.zero, subset=subset),) + # Unclear how to avoid codegen for MIN and MAX + # Reconstruct the expression as an Interpolate V = expr.arguments()[-1].function_space().dual() expr = interpolate(ufl.zero(V.value_shape), V) From 085aee927973c47db363c1c03a2d2cb0335fb82e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Oct 2025 22:18:57 +0100 Subject: [PATCH 20/20] DROP BEFORE MERGE --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8b03f6bcc7..21036a2afc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "mpi4py>3; python_version >= '3.13'", "mpi4py; python_version < '3.13'", # TODO RELEASE: use releases - "fenics-ufl @ git+https://github.com/FEniCS/ufl.git", + "fenics-ufl @ git+https://github.com/FEniCS/ufl.git@main", "firedrake-fiat @ git+https://github.com/firedrakeproject/fiat.git", "h5py>3.12.1", "libsupermesh",