diff --git a/docs/source/interpolation.rst b/docs/source/interpolation.rst index 116b34cbef..50441902ff 100644 --- a/docs/source/interpolation.rst +++ b/docs/source/interpolation.rst @@ -261,17 +261,11 @@ which indeed contracts into a number. Interpolation across meshes --------------------------- -The interpolation API supports interpolation between meshes where the target -function space has finite elements (as given in the list of -:ref:`supported elements `) - -* **Lagrange/CG** (also known as Continuous Galerkin or P elements), -* **Q** (i.e. Lagrange/CG on lines, quadrilaterals and hexahedra), -* **Discontinuous Lagrange/DG** (also known as Discontinuous Galerkin or DP elements) and -* **DQ** (i.e. Discontinuous Lagrange/DG on lines, quadrilaterals and hexahedra). - -Vector, tensor, and mixed function spaces can also be interpolated into from -other meshes as long as they are constructed from these spaces. +The interpolation API supports interpolation across meshes where the target +function space has any finite element which supports interpolation, as specified in the list of +:ref:`supported elements `. Vector, tensor, and mixed function +spaces can also be interpolated into from other meshes as long as they are +constructed from these spaces. .. note:: diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 1238563ba4..de15331684 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -69,7 +69,7 @@ def init_petsc(): from firedrake.deflation import DeflatedSNES, Deflation # noqa: F401 from firedrake.exceptions import ( # noqa: F401 FiredrakeException, ConvergenceError, MismatchingDomainError, - VertexOnlyMeshMissingPointsError, DofNotDefinedError + VertexOnlyMeshMissingPointsError, DofNotDefinedError, DofTypeError, ) from firedrake.function import ( # noqa: F401 Function, PointNotInDomainError, diff --git a/firedrake/exceptions.py b/firedrake/exceptions.py index 8816375138..1a74c6a35e 100644 --- a/firedrake/exceptions.py +++ b/firedrake/exceptions.py @@ -17,6 +17,12 @@ class DofNotDefinedError(FiredrakeException): """ +class DofTypeError(FiredrakeException): + """Raised when an operation is attempted on a degree of freedom (DoF) + type which is not supported. + """ + + class VertexOnlyMeshMissingPointsError(FiredrakeException): """Exception raised when 1 or more points are not found by a :func:`~.VertexOnlyMesh` in its parent mesh. diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 6f0954d8ea..a5304d7b6c 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -19,7 +19,7 @@ from pyop2 import op2 from pyop2.caching import memory_and_disk_cache -from finat.ufl import TensorElement, VectorElement, MixedElement +from finat.ufl import TensorElement, VectorElement, MixedElement, FiniteElementBase from finat.element_factory import create_element from tsfc.driver import compile_expression_dual_evaluation @@ -28,7 +28,7 @@ from firedrake.utils import IntType, ScalarType, cached_property, known_pyop2_safe, tuplify from firedrake.pointeval_utils import runtime_quadrature_element from firedrake.tsfc_interface import extract_numbered_coefficients, _cachedir -from firedrake.ufl_expr import Argument, Coargument, action +from firedrake.ufl_expr import Argument, Coargument, TrialFunction, TestFunction, action from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshTopology, MeshGeometry, MeshTopology, VertexOnlyMesh from firedrake.petsc import PETSc from firedrake.halo import _get_mtype @@ -41,7 +41,8 @@ from firedrake.function import Function from firedrake.cofunction import Cofunction from firedrake.exceptions import ( - DofNotDefinedError, VertexOnlyMeshMissingPointsError, NonUniqueMeshSequenceError + DofNotDefinedError, VertexOnlyMeshMissingPointsError, NonUniqueMeshSequenceError, + DofTypeError, ) from mpi4py import MPI @@ -422,17 +423,6 @@ def __init__(self, expr: Interpolate): else: self.access = op2.WRITE - # TODO check V.finat_element.is_lagrange() once https://github.com/firedrakeproject/fiat/pull/200 is released - target_element = self.target_space.ufl_element() - if not ((isinstance(target_element, MixedElement) - and all(sub.mapping() == "identity" for sub in target_element.sub_elements)) - or target_element.mapping() == "identity"): - # Identity mapping between reference cell and physical coordinates - # implies point evaluation nodes. - raise NotImplementedError( - "Can only cross-mesh interpolate into spaces with point evaluation nodes." - ) - if self.allow_missing_dofs: self.missing_points_behaviour = MissingPointsBehaviour.IGNORE else: @@ -441,26 +431,61 @@ def __init__(self, expr: Interpolate): if self.source_mesh.unique().geometric_dimension != self.target_mesh.unique().geometric_dimension: raise ValueError("Geometric dimensions of source and destination meshes must match.") + # Interpolate into intermediate quadrature space for non-point-evaluation elements + if into_quadrature_space := not self.target_space.finat_element.has_pointwise_dual_basis: + self.original_target_space = self.target_space + r"""The original target space for interpolation, as specified by the user.""" + self.target_space = self.target_space.quadrature_space() + r"""The target space for the cross-mesh interpolation. Must have point-evaluation dofs. + If ``self.original_target_space`` does not have point-evaluation dofs, then this is + an intermediate quadrature space.""" + + self.into_quadrature_space = into_quadrature_space + + @cached_property + def _target_space_element(self) -> FiniteElementBase: + """The element of `self.target_space`. If `self.target_space` is tensor/vector valued, + the base scalar element. + + Returns + ------- + FiniteElementBase + The base element of `self.target_space`. + """ dest_element = self.target_space.ufl_element() if isinstance(dest_element, MixedElement): if isinstance(dest_element, VectorElement | TensorElement): # In this case all sub elements are equal - base_element = dest_element.sub_elements[0] - if base_element.reference_value_shape != (): - raise NotImplementedError( - "Can't yet cross-mesh interpolate onto function spaces made from VectorElements " - "or TensorElements made from sub elements with value shape other than ()." - ) - self.dest_element = base_element + return dest_element.sub_elements[0] else: raise NotImplementedError("Interpolation with MixedFunctionSpace requires MixedInterpolator.") else: # scalar fiat/finat element - self.dest_element = dest_element + return dest_element + + @cached_property + def _target_space_type(self) -> Callable[..., WithGeometry]: + """Returns a callable which returns a function space matching the type of `self.target_space`. + + Returns + ------- + Callable + A callable which returns a :class:`.WithGeometry` matching the type of `self.target_space`. + """ + # Get the correct type of function space + shape = self.target_space.value_shape + if len(shape) == 0: + return FunctionSpace + elif len(shape) == 1: + return partial(VectorFunctionSpace, dim=shape[0]) + else: + symmetry = self.target_space.ufl_element().symmetry() + return partial(TensorFunctionSpace, shape=shape, symmetry=symmetry) - def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: - """Return the symbolic ``Interpolate`` expressions for point evaluation and - re-ordering into the input-ordering VertexOnlyMesh. + @cached_property + def _symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: + """The symbolic ``Interpolate`` expressions for point evaluation of `self.target_space`s + dofs in the source mesh, and the corresponding input-ordering interpolation. Returns ------- @@ -470,14 +495,22 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: Raises ------ - DofNotDefinedError - If any DoFs in the target mesh cannot be defined in the source mesh. + DoFNotDefinedError + If any of the target spaces dofs cannot be defined in the source mesh. + DoFTypeError + If the target space does not have point-evaluation dofs. """ from firedrake.assemble import assemble + if not self.target_space.finat_element.has_pointwise_dual_basis: + raise DofTypeError(f"FunctionSpace {self.target_space} must have point-evaluation dofs.") + # Immerse coordinates of target space point evaluation dofs in src_mesh - target_space_vec = VectorFunctionSpace(self.target_mesh.unique(), self.dest_element) - f_dest_node_coords = assemble(interpolate(self.target_mesh.unique().coordinates, target_space_vec)) - dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.unique().geometric_dimension) + # If `self.into_quadrature_space` is true, then the point evaluation dofs + # are the quadrature points of the original target space. + target_mesh = self.target_space.mesh().unique() + target_space_vec = VectorFunctionSpace(target_mesh, self._target_space_element) + f_dest_node_coords = assemble(interpolate(target_mesh.coordinates, target_space_vec)) + dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, target_mesh.geometric_dimension) try: vom = VertexOnlyMesh( self.source_mesh.unique(), @@ -486,44 +519,53 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: missing_points_behaviour=self.missing_points_behaviour, ) except VertexOnlyMeshMissingPointsError: - raise DofNotDefinedError(f"The given target function space on domain {self.target_mesh} " + raise DofNotDefinedError(f"The given target function space on domain {target_mesh} " "contains degrees of freedom which cannot cannot be defined in the " - f"source function space on domain {self.source_mesh}. " + f"source function space on domain {self.source_mesh.unique()}. " "This may be because the target mesh covers a larger domain than the " "source mesh. To disable this error, set allow_missing_dofs=True.") - # Get the correct type of function space - shape = self.target_space.ufl_function_space().value_shape - if len(shape) == 0: - fs_type = FunctionSpace - elif len(shape) == 1: - fs_type = partial(VectorFunctionSpace, dim=shape[0]) - else: - symmetry = self.target_space.ufl_element().symmetry() - fs_type = partial(TensorFunctionSpace, shape=shape, symmetry=symmetry) - - # Get expression for point evaluation at the dest_node_coords - P0DG_vom = fs_type(vom, "DG", 0) + # Expression for point evaluation at the dest_node_coords + P0DG_vom = self._target_space_type(vom, "DG", 0) point_eval = interpolate(self.operand, P0DG_vom) - # Interpolate into the input-ordering VOM - P0DG_vom_input_ordering = fs_type(vom.input_ordering, "DG", 0) - + # Expression for interpolating into the input-ordering VOM + P0DG_vom_input_ordering = self._target_space_type(vom.input_ordering, "DG", 0) arg = Argument(P0DG_vom, 0 if self.ufl_interpolate.is_adjoint else 1) point_eval_input_ordering = interpolate(arg, P0DG_vom_input_ordering) + return point_eval, point_eval_input_ordering + @cached_property + def _interpolate_from_quadrature(self) -> Interpolate: + """Returns symbolic expression for interpolation from the intermediate quadrature + space into the user-provided target space. Only relevant if `self.into_quadrature_space` is True. + + Returns + ------- + Interpolate + A symbolic interpolate expression. + """ + if self.rank == 2: + if self.ufl_interpolate.is_adjoint: + return interpolate(TestFunction(self.target_space), self.original_target_space) + else: + return interpolate(TrialFunction(self.target_space), self.original_target_space) + elif self.ufl_interpolate.is_adjoint: + return interpolate(TestFunction(self.target_space), self.dual_arg) + def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None): from firedrake.assemble import assemble if bcs: raise NotImplementedError("bcs not implemented for cross-mesh interpolation.") mat_type = mat_type or "aij" - # self.ufl_interpolate.function_space() is None in the 0-form case - V_dest = self.ufl_interpolate.function_space() or self.target_space - f = tensor or Function(V_dest) + if self.into_quadrature_space: + f = Function(self.target_space.dual() if self.ufl_interpolate.is_adjoint else self.target_space) + else: + f = tensor or Function(self.ufl_interpolate.function_space() or self.target_space) - point_eval, point_eval_input_ordering = self._get_symbolic_expressions() + point_eval, point_eval_input_ordering = self._symbolic_expressions P0DG_vom_input_ordering = point_eval_input_ordering.argument_slots()[0].function_space().dual() if self.rank == 2: @@ -532,47 +574,65 @@ def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None) # `self.point_eval_interpolate` and the permutation # given by `self.to_input_ordering_interpolate`. if self.ufl_interpolate.is_adjoint: - symbolic = action(point_eval, point_eval_input_ordering) + interp_expr = action(point_eval, point_eval_input_ordering) else: - symbolic = action(point_eval_input_ordering, point_eval) + interp_expr = action(point_eval_input_ordering, point_eval) def callable() -> PETSc.Mat: - return assemble(symbolic, mat_type=mat_type).petscmat + res = assemble(interp_expr, mat_type=mat_type).petscmat + if self.into_quadrature_space: + source_space = self.ufl_interpolate.function_space() + if self.ufl_interpolate.is_adjoint: + I = AssembledMatrix((Argument(source_space, 0), Argument(self.target_space.dual(), 1)), None, res) + return assemble(action(I, self._interpolate_from_quadrature)).petscmat + else: + I = AssembledMatrix((Argument(self.target_space.dual(), 0), Argument(source_space, 1)), None, res) + return assemble(action(self._interpolate_from_quadrature, I)).petscmat + else: + return res + elif self.ufl_interpolate.is_adjoint: assert self.rank == 1 - # f_src is a cofunction on V_dest.dual - cofunc = self.dual_arg - assert isinstance(cofunc, Cofunction) - - # Our first adjoint operation is to assign the dat values to a - # P0DG cofunction on our input ordering VOM. - f_input_ordering = Cofunction(P0DG_vom_input_ordering.dual()) - f_input_ordering.dat.data_wo[:] = cofunc.dat.data_ro[:] - - # The rest of the adjoint interpolation is the composition - # of the adjoint interpolators in the reverse direction. - # We don't worry about skipping over missing points here - # because we're going from the input ordering VOM to the original VOM - # and all points from the input ordering VOM are in the original. + def callable() -> Cofunction: + if self.into_quadrature_space: + cofunc = assemble(self._interpolate_from_quadrature) + f_target = Cofunction(point_eval.function_space()) + else: + cofunc = self.dual_arg + f_target = f + + assert isinstance(cofunc, Cofunction) + + # Our first adjoint operation is to assign the dat values to a + # P0DG cofunction on our input ordering VOM. + f_input_ordering = Cofunction(P0DG_vom_input_ordering.dual()) + f_input_ordering.dat.data_wo[:] = cofunc.dat.data_ro[:] + + # The rest of the adjoint interpolation is the composition + # of the adjoint interpolators in the reverse direction. + # We don't worry about skipping over missing points here + # because we're going from the input ordering VOM to the original VOM + # and all points from the input ordering VOM are in the original. f_src_at_src_node_coords = assemble(action(point_eval_input_ordering, f_input_ordering)) - assemble(action(point_eval, f_src_at_src_node_coords), tensor=f) - return f + assemble(action(point_eval, f_src_at_src_node_coords), tensor=f_target) + return f_target else: assert self.rank in {0, 1} - # We create the input-ordering Function before interpolating so we can - # set default missing values if required. - f_point_eval_input_ordering = Function(P0DG_vom_input_ordering) - if self.default_missing_val is not None: - f_point_eval_input_ordering.assign(self.default_missing_val) - elif self.allow_missing_dofs: - # If we allow missing points there may be points in the target - # mesh that are not in the source mesh. If we don't specify a - # default missing value we set these to NaN so we can identify - # them later. - f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan def callable() -> Function | Number: + # We create the input-ordering Function before interpolating so we can + # set default missing values if required. + f_point_eval_input_ordering = Function(P0DG_vom_input_ordering) + if self.default_missing_val is not None: + f_point_eval_input_ordering.assign(self.default_missing_val) + elif self.allow_missing_dofs: + # If we allow missing points there may be points in the target + # mesh that are not in the source mesh. If we don't specify a + # default missing value we set these to NaN so we can identify + # them later. + f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan + assemble(action(point_eval_input_ordering, point_eval), tensor=f_point_eval_input_ordering) # We assign these values to the output function if self.allow_missing_dofs and self.default_missing_val is None: @@ -581,12 +641,18 @@ def callable() -> Function | Number: else: f.dat.data_wo[:] = f_point_eval_input_ordering.dat.data_ro[:] + if self.into_quadrature_space: + f_target = Function(self.original_target_space) + assemble(interpolate(f, self.original_target_space), tensor=f_target) + else: + f_target = f + if self.rank == 0: # We take the action of the dual_arg on the interpolated function assert isinstance(self.dual_arg, Cofunction) - return assemble(action(self.dual_arg, f)) + return assemble(action(self.dual_arg, f_target)) else: - return f + return f_target return callable @property diff --git a/tests/firedrake/regression/test_cross_mesh_non_lagrange.py b/tests/firedrake/regression/test_cross_mesh_non_lagrange.py new file mode 100644 index 0000000000..ac0ea40ecb --- /dev/null +++ b/tests/firedrake/regression/test_cross_mesh_non_lagrange.py @@ -0,0 +1,174 @@ +from firedrake import * +import pytest +import numpy as np +from functools import partial + + +def mat_equals(a, b) -> bool: + """Check that two Matrices are equal.""" + a = a.petscmat.copy() + a.axpy(-1.0, b.petscmat) + return a.norm(norm_type=PETSc.NormType.NORM_FROBENIUS) < 1e-14 + + +def fs_shape(V): + shape = V.ufl_function_space().value_shape + if len(shape) == 0: + return FunctionSpace + elif len(shape) == 1: + return partial(VectorFunctionSpace, dim=shape[0]) + elif len(shape) == 2: + return partial(TensorFunctionSpace, shape=shape) + else: + raise ValueError("Invalid function space shape") + + +@pytest.fixture(params=[("RT", 1), ("RT", 2), ("BDM", 1), ("BDM", 2), ("BDFM", 2), + ("HHJ", 0), ("HHJ", 2), ("N1curl", 1), ("N1curl", 2), + ("N2curl", 1), ("N2curl", 2), ("GLS", 1), ("GLS", 2), + ("GLS2", 2), ("Regge", 0), ("Regge", 2)], + ids=lambda x: f"{x[0]}_{x[1]}") +def V(request): + element, degree = request.param + mesh = UnitSquareMesh(16, 16) + return FunctionSpace(mesh, element, degree) + + +@pytest.mark.parallel([1, 3]) +@pytest.mark.parametrize("rank", [1, 2]) +def test_cross_mesh(V, rank): + mesh1 = UnitSquareMesh(5, 5) + mesh2 = V.mesh() + x, y = SpatialCoordinate(mesh1) + x1, y1 = SpatialCoordinate(mesh2) + + shape = V.ufl_function_space().value_shape + if len(shape) == 0: + fs_type = FunctionSpace + expr1 = x * x + y * y + expr2 = x1 * x1 + y1 * y1 + elif len(shape) == 1: + fs_type = partial(VectorFunctionSpace, dim=shape[0]) + expr1 = as_vector([x, y]) + expr2 = as_vector([x1, y1]) + elif len(shape) == 2: + fs_type = partial(TensorFunctionSpace, shape=shape) + expr1 = as_tensor([[x, x*y], [x*y, y]]) + expr2 = as_tensor([[x1, x1*y1], [x1*y1, y1]]) + else: + raise ValueError("Unsupported target space shape") + + V_source = fs_type(mesh1, "CG", 2) + f_source = Function(V_source).interpolate(expr1) + f_direct = Function(V).interpolate(expr2) + + Q = V.quadrature_space() + + if rank == 2: + # Assemble the operator + I1 = interpolate(TrialFunction(V_source), Q) # V_source x Q_target^* -> R + I2 = interpolate(TrialFunction(Q), V) # Q_target x V^* -> R + I_manual = assemble(action(I2, I1)) # V_source x V^* -> R + assert I_manual.arguments() == (TestFunction(V.dual()), TrialFunction(V_source)) + # Direct assembly + I_direct = assemble(interpolate(TrialFunction(V_source), V)) # V_source + assert I_direct.arguments() == (TestFunction(V.dual()), TrialFunction(V_source)) + assert mat_equals(I_manual, I_direct) + + f_interpolated_manual = assemble(action(I_manual, f_source)) + assert np.allclose(f_interpolated_manual.dat.data_ro, f_direct.dat.data_ro) + f_interpolated_direct = assemble(action(I_direct, f_source)) + assert np.allclose(f_interpolated_direct.dat.data_ro, f_direct.dat.data_ro) + elif rank == 1: + # Interp V_source -> Q + I1 = interpolate(f_source, Q) # SameMesh + f_quadrature = assemble(I1) + # Interp Q -> V + I2 = interpolate(f_quadrature, V) # CrossMesh + f_interpolated_manual = assemble(I2) + assert f_interpolated_manual.function_space() == V + assert np.allclose(f_interpolated_manual.dat.data_ro, f_direct.dat.data_ro) + + f_interpolated_direct = assemble(interpolate(f_source, V)) + assert f_interpolated_direct.function_space() == V + assert np.allclose(f_interpolated_direct.dat.data_ro, f_direct.dat.data_ro) + + +@pytest.mark.parallel([1, 3]) +@pytest.mark.parametrize("rank", [0, 1, 2]) +def test_cross_mesh_adjoint(V, rank): + # Can already do Lagrange -> RT adjoint + # V^* -> Q^* -> V_target^* + name = V.ufl_element()._short_name + deg = V.ufl_element().degree() + if name in ["N1curl", "GLS", "RT"] and deg == 1: + exact = False + elif name in ["Regge", "HHJ"] and deg == 0: + exact = False + else: + exact = True + + mesh1 = UnitSquareMesh(2, 2) + x1 = SpatialCoordinate(mesh1) + V_target = fs_shape(V)(mesh1, "CG", 1) + + mesh2 = V.mesh() + x2 = SpatialCoordinate(mesh2) + + if len(V.value_shape) > 1: + expr = outer(x2, x2) + target_expr = outer(x1, x1) + if V.ufl_element().mapping() == "covariant contravariant Piola": + expr = dev(expr) + target_expr = dev(target_expr) + else: + expr = x2 + target_expr = x1 + + oneform_V = inner(expr, TestFunction(V)) * dx # V^* + cofunc_Vtarget_direct = assemble(inner(target_expr, TestFunction(V_target)) * dx) + + if exact: + def close(x, y): + if rank == 0: + return np.isclose(x, y) + else: + return np.allclose(x, y) + else: + def close(x, y): + return np.linalg.norm(x - y) < 0.003 + + Q = V.quadrature_space() + + if rank == 2: + # Assemble the operator + I1 = interpolate(TestFunction(Q), V) # V^* x Q -> R + I2 = interpolate(TestFunction(V_target), Q) # Q^* x V_target -> R + I_manual = assemble(action(I2, I1)) # V^* x V_target -> R + assert I_manual.arguments() == (TestFunction(V_target), TrialFunction(V.dual())) + # Direct assembly + I_direct = assemble(interpolate(TestFunction(V_target), V)) # V^* x V_target -> R + assert I_direct.arguments() == (TestFunction(V_target), TrialFunction(V.dual())) + assert mat_equals(I_manual, I_direct) + + cofunc_Vtarget_manual = assemble(action(I_manual, oneform_V)) + assert close(cofunc_Vtarget_manual.dat.data_ro, cofunc_Vtarget_direct.dat.data_ro) + + cofunc_Vtarget = assemble(action(I_direct, oneform_V)) + assert close(cofunc_Vtarget.dat.data_ro, cofunc_Vtarget_direct.dat.data_ro) + elif rank == 1: + # Interp V^* -> Q^* + I1_adj = interpolate(TestFunction(Q), oneform_V) # SameMesh + cofunc_Q = assemble(I1_adj) + + # Interp Q^* -> V_target^* + I2_adj = interpolate(TestFunction(V_target), cofunc_Q) # CrossMesh + cofunc_Vtarget_manual = assemble(I2_adj) + assert close(cofunc_Vtarget_manual.dat.data_ro, cofunc_Vtarget_direct.dat.data_ro) + + cofunc_Vtarget = assemble(interpolate(TestFunction(V_target), oneform_V)) # V^* -> V_target^* + assert close(cofunc_Vtarget.dat.data_ro, cofunc_Vtarget_direct.dat.data_ro) + elif rank == 0: + res = assemble(interpolate(target_expr, oneform_V)) + actual = assemble(inner(expr, expr) * dx) + assert close(res, actual) diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index 2011dc9466..a04e3edbc0 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -6,7 +6,6 @@ import numpy as np import pytest from ufl import product -import subprocess def allgather(comm, coords): @@ -17,9 +16,9 @@ def allgather(comm, coords): return coords -def unitsquaresetup(): +def unitsquaresetup(dest_quad=True): m_src = UnitSquareMesh(2, 3) - m_dest = UnitSquareMesh(3, 5, quadrilateral=True) + m_dest = UnitSquareMesh(3, 5, quadrilateral=dest_quad) coords = np.array( [[0.56, 0.6], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.726, 0.6584]] ) # fairly arbitrary @@ -49,14 +48,7 @@ def make_high_order(m_low_order, degree): "unitsquare_vfs", "unitsquare_tfs", "unitsquare_N1curl_source", - pytest.param( - "unitsquare_SminusDiv_destination", - marks=pytest.mark.xfail( - # CalledProcessError is so the parallel tests correctly xfail - raises=(subprocess.CalledProcessError, NotImplementedError), - reason="Can only interpolate into spaces with point evaluation nodes", - ), - ), + "unitsquare_RT_N1curl_destination", "unitsquare_Regge_source", # This test fails in complex mode pytest.param("spheresphere", marks=pytest.mark.skipcomplex), @@ -188,14 +180,14 @@ def parameters(request): V_src = FunctionSpace(m_src, "N1curl", 2) # Not point evaluation nodes V_dest = VectorFunctionSpace(m_dest, "CG", 4) V_dest_2 = VectorFunctionSpace(m_dest, "DQ", 2) - elif request.param == "unitsquare_SminusDiv_destination": - m_src, m_dest, coords = unitsquaresetup() + elif request.param == "unitsquare_RT_N1curl_destination": + m_src, m_dest, coords = unitsquaresetup(dest_quad=False) expr_src = 2 * SpatialCoordinate(m_src) expr_dest = 2 * SpatialCoordinate(m_dest) expected = 2 * coords V_src = VectorFunctionSpace(m_src, "CG", 2) - V_dest = FunctionSpace(m_dest, "SminusDiv", 2) # Not point evaluation nodes - V_dest_2 = FunctionSpace(m_dest, "SminusCurl", 2) # Not point evaluation nodes + V_dest = FunctionSpace(m_dest, "RT", 2) # Not point evaluation nodes + V_dest_2 = FunctionSpace(m_dest, "N1curl", 2) # Not point evaluation nodes elif request.param == "unitsquare_Regge_source": m_src, m_dest, coords = unitsquaresetup() expr_src = outer(SpatialCoordinate(m_src), SpatialCoordinate(m_src)) @@ -418,32 +410,6 @@ def test_interpolate_unitsquare_tfs_shape(shape, symmetry): assemble(interpolate(f_src, V_dest)) -def test_interpolate_cross_mesh_not_point_eval(): - m_src = UnitSquareMesh(2, 3) - m_dest = UnitSquareMesh(3, 5, quadrilateral=True) - coords = np.array( - [[0.56, 0.6], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.726, 0.6584]] - ) # fairly arbitrary - # add the coordinates of the mesh vertices to test boundaries - vertices_src = allgather(m_src.comm, m_src.coordinates.dat.data_ro) - coords = np.concatenate((coords, vertices_src)) - vertices_dest = allgather(m_dest.comm, m_dest.coordinates.dat.data_ro) - coords = np.concatenate((coords, vertices_dest)) - dest_eval = PointEvaluator(m_dest, coords) - expr_src = 2 * SpatialCoordinate(m_src) - expr_dest = 2 * SpatialCoordinate(m_dest) - expected = 2 * coords - V_src = FunctionSpace(m_src, "RT", 2) - V_dest = FunctionSpace(m_dest, "RTCE", 2) - atol = 1e-8 # default - # This might not make much mathematical sense, but it should test if we get - # the not implemented error for non-point evaluation nodes! - with pytest.raises(NotImplementedError): - interpolate_function( - m_src, m_dest, V_src, V_dest, dest_eval, expected, expr_src, expr_dest, atol - ) - - def interpolate_function( m_src, m_dest, V_src, V_dest, dest_eval, expected, expr_src, expr_dest, atol ): diff --git a/tests/firedrake/regression/test_interpolation_manual.py b/tests/firedrake/regression/test_interpolation_manual.py index d6dfcf09d6..cedaee54be 100644 --- a/tests/firedrake/regression/test_interpolation_manual.py +++ b/tests/firedrake/regression/test_interpolation_manual.py @@ -102,6 +102,7 @@ def mydata(points): def test_line_integral(): # [test_line_integral 1] + from firedrake.mesh import Mesh, plex_from_cell_list # Start with a simple field exactly represented in the function space over # the unit square domain. m = UnitSquareMesh(2, 2) @@ -113,8 +114,8 @@ def test_line_integral(): # Note that it only has 1 cell cells = np.asarray([[0, 1]]) vertex_coords = np.asarray([[0.0, 0.0], [1.0, 1.0]]) - plex = mesh.plex_from_cell_list(1, cells, vertex_coords, comm=m.comm) - line = mesh.Mesh(plex, dim=2) + plex = plex_from_cell_list(1, cells, vertex_coords, comm=m.comm) + line = Mesh(plex, dim=2) # [test_line_integral 2] x, y = SpatialCoordinate(line) V_line = FunctionSpace(line, "CG", 2)