Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 0 additions & 28 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import finat.ufl
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
tsfc_interface, utils)
from firedrake.formmanipulation import split_form
from firedrake.adjoint_utils import annotate_assemble
from firedrake.ufl_expr import extract_unique_domain
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
Expand Down Expand Up @@ -570,36 +569,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
rank = len(expr.arguments())
if rank > 2:
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
# If argument numbers have been swapped => Adjoint.
arg_operand = ufl.algorithms.extract_arguments(operand)
is_adjoint = (arg_operand and arg_operand[0].number() == 0)

# Get the target space
V = v.function_space().dual()

# Dual interpolation from mixed source
if is_adjoint and len(V) > 1:
cur = 0
sub_operands = []
components = numpy.reshape(operand, (-1,))
for Vi in V:
sub_operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape)))
cur += Vi.value_size

# Component-split of the primal operands interpolated into the dual argument-split
split_interp = sum(reconstruct_interp(sub_operands[i], v=vi) for (i,), vi in split_form(v))
return assemble(split_interp, tensor=tensor)

# Dual interpolation into mixed target
if is_adjoint and len(arg_operand[0].function_space()) > 1 and rank == 1:
V = arg_operand[0].function_space()
tensor = tensor or firedrake.Cofunction(V.dual())

# Argument-split of the Interpolate gets assembled into the corresponding sub-tensor
for (i,), sub_interp in split_form(expr):
assemble(sub_interp, tensor=tensor.subfunctions[i])
return tensor

# Get the interpolator
interp_data = expr.interp_data.copy()
default_missing_val = interp_data.pop('default_missing_val', None)
Expand Down
50 changes: 49 additions & 1 deletion firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy
import collections

from ufl import as_vector, split
from ufl import as_tensor, as_vector, split
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms import expand_derivatives
Expand All @@ -14,6 +14,7 @@
from firedrake.petsc import PETSc
from firedrake.functionspace import MixedFunctionSpace
from firedrake.cofunction import Cofunction
from firedrake.ufl_expr import Coargument
from firedrake.matrix import AssembledMatrix


Expand Down Expand Up @@ -133,6 +134,17 @@ def argument(self, o):
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
return self._arg_cache.setdefault(o, as_vector(args))

def coargument(self, o):
V = o.function_space()

if len(V) == 1:
# Not on a mixed space, just return ourselves.
return o

indices = self.blocks[o.number()]
W = subspace(V, indices)
return Coargument(W, number=o.number(), part=o.part())

def cofunction(self, o):
V = o.function_space()

Expand Down Expand Up @@ -171,6 +183,42 @@ def matrix(self, o):
bcs = ()
return AssembledMatrix(tuple(args), bcs, submat)

def zero_base_form(self, o):
return ZeroBaseForm(tuple(map(self, o.arguments())))

def interpolate(self, o, operand):
if isinstance(operand, Zero):
return self(ZeroBaseForm(o.arguments()))

dual_arg, _ = o.argument_slots()
if len(dual_arg.arguments()) == 1 or len(dual_arg.arguments()[-1].function_space()) == 1:
# The dual argument has been contracted or does not need to be split
return o._ufl_expr_reconstruct_(operand, dual_arg)

if not isinstance(dual_arg, Coargument):
raise NotImplementedError(f"I do not know how to split an Interpolate with a {type(dual_arg).__name__}.")

indices = self.blocks[dual_arg.number()]
V = dual_arg.function_space()

# Split the target (dual) argument
sub_dual_arg = self(dual_arg)
W = sub_dual_arg.function_space()

# Unflatten the expression into the target shape
cur = 0
components = []
for i, Vi in enumerate(V):
if i in indices:
components.extend(operand[i] for i in range(cur, cur+Vi.value_size))
cur += Vi.value_size

operand = as_tensor(numpy.reshape(components, W.value_shape))
if isinstance(operand, Zero):
return self(ZeroBaseForm(o.arguments()))

return o._ufl_expr_reconstruct_(operand, sub_dual_arg)


SplitForm = collections.namedtuple("SplitForm", ["indices", "form"])

Expand Down
Loading