diff --git a/.gitignore b/.gitignore index b4109385..7557ed0b 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ release/ /docs/source/finat.rst /docs/source/finat.ufl.rst /docs/source/gem.rst + +firedrake_fiat.egg-info/ diff --git a/FIAT/reference_element.py b/FIAT/reference_element.py index 87794b50..3e41ee3f 100644 --- a/FIAT/reference_element.py +++ b/FIAT/reference_element.py @@ -615,8 +615,10 @@ def get_dimension(self): def compute_barycentric_coordinates(self, points, entity=None, rescale=False): """Returns the barycentric coordinates of a list of points on the complex.""" - if len(points) == 0: + + if isinstance(points, numpy.ndarray) and len(points) == 0: return points + if entity is None: entity = (self.get_spatial_dimension(), 0) entity_dim, entity_id = entity @@ -640,8 +642,11 @@ def compute_barycentric_coordinates(self, points, entity=None, rescale=False): h = 1 / numpy.linalg.norm(A, axis=1) b *= h A *= h[:, None] - out = numpy.dot(points, A.T) - return numpy.add(out, b, out=out) + # out = numpy.dot(points, A.T) + out = points @ A.T + + # return numpy.add(out, b) + return out + b def compute_bubble(self, points, entity=None): """Returns the lowest-order bubble on an entity evaluated at the given @@ -1406,6 +1411,41 @@ def extrinsic_orientation_permutation_map(self): def is_macrocell(self): return any(c.is_macrocell() for c in self.cells) + def compute_axis_barycentric_coordinates(self, points, entity=None, rescale=False): + """Compute barycentric coordinates on each axis (factor) of a tensor-product cell. + + Parameters + ---------- + points: numpy.ndarray or GEM.Node + The reference coordinates of the points. + + Returns + ------- + numpy.ndarray + A flattened array of shape ``(total_bary_coords, )`` and dtype object if points are GEM nodes, + otherwise dtype numeric. The i-th entry contains the barycentric coordinates + on the i-th factor cell. If factor i is a simplex of dimension d, this will + have shape ``(npoints, d+1)``. If factor i is a hypercube of dimension d, + this will have shape ``(npoints, 2*d)``. + """ + + if isinstance(points, numpy.ndarray) and len(points) == 0: + return points + + axis_dims = [c.get_spatial_dimension() for c in self.cells] + point_slices = TensorProductCell._split_slices(axis_dims) + + result = numpy.empty(len(self.cells), dtype=object) + for k, (factor, s) in enumerate(zip(self.cells, point_slices)): + result[k] = factor.compute_barycentric_coordinates(points[..., s], entity, rescale) + + # Flatten the array + # We cannot construct the flat array directly since we may not know upfront the total number + # of barycentric coordinates (e.g., in a simplex it is d+1, in a hypercube it is 2*d) + flat_result = numpy.array([bary[j] for bary in result for j in range(bary.shape[0])]) + + return flat_result + class Hypercube(Cell): """Abstract class for a reference hypercube""" @@ -1423,6 +1463,8 @@ def __init__(self, dimension, product): self.product = product self.unflattening_map = compute_unflattening_map(pt) + self.facet_perm = compute_facet_permutation(self.unflattening_map, self.product) + def get_dimension(self): """Returns the subelement dimension of the cell. Same as the spatial dimension.""" @@ -1521,6 +1563,27 @@ def __ge__(self, other): def __le__(self, other): return self.product <= other + def compute_barycentric_coordinates(self, points, entity=None, rescale=False): + """Returns the barycentric coordinates of a list of points on the hypercube. + + Parameters + ---------- + points: numpy.ndarray or GEM.Node + The reference coordinates of the points. + + Returns + ------- + List of numpy.ndarray or GEM.ComponentTensor + Returns a list of barycentric coordinates in local facet order such that for any point + lying on local facet `lf` of the cell, the barycentric coordinate at index `lf` vanishes. + """ + if isinstance(points, numpy.ndarray) and len(points) == 0: + return points + + tp_bary_coords = self.product.compute_axis_barycentric_coordinates(points, entity, rescale) + + return tp_bary_coords[self.facet_perm] + class UFCHypercube(Hypercube): """Reference UFC Hypercube @@ -1839,6 +1902,60 @@ def compute_unflattening_map(topology_dict): return unflattening_map +def compute_facet_permutation(unflattening_map, product): + """ + Returns a permutation mapping each facet of a `~.Hypercube` to the index of the + barycentric coordinate that vanishes on it. + + The order of barycentric coordinates returned by `compute_axis_barycentric_coordinates` + is determined by axis structure, not by facet numbering. Reordering them by this permutation + yields the invariant: the i-th barycentric coordinate vanishes on the i-th facet. + """ + # First compute axis offsets into the flattened barycentric coordinate array. + axis_offsets = [] + offset = 0 + for axis_cell in product.cells: + axis_offsets.append(offset) + offset += axis_cell.get_dimension() + 1 + + # Initialise the integer permutation array + sd = len(product.cells) + num_facets = 2 * sd + perm = numpy.zeros(num_facets, dtype=int) + + for f in range(num_facets): + # Recover the tensor-product representation of the facet as given by the unflattening map + dim_tuple, tp_entity = unflattening_map[(sd - 1, f)] + + # Determine the axis that's orthogonal to the facet + # E.g., in a quad: + # if dim_tuple = (0,1) -> facet has dimension 0 on the first component -> fixed at x = 0 or x = 1 + # if dim_tuple = (1,0) -> facet has dimension 0 on the second component -> fixed at y = 0 or y = 1 + axis = next( + i for i, d in enumerate(dim_tuple) + if d == product.cells[i].get_dimension() - 1 + ) + + # Determine the index of the endpoint that produces the facet + # which gives the local facet number in the axis space + entity_shape = tuple( + len(c.get_topology()[d]) + for c, d in zip(product.cells, dim_tuple) + ) + tuple_ei = numpy.unravel_index(tp_entity, entity_shape) + local_facet = tuple_ei[axis] + + # For a simplex (UFCInterval, UFCTriangle), the barycentric coordinate that vanishes on local facet i + # corresponds to the ID of the vertex that doesn't belong to that facet + all_vertices = set(product.cells[axis].get_topology()[0].keys()) + facet_vertices = set(product.cells[axis].get_topology()[0][local_facet]) + bary_index = next(iter(all_vertices - facet_vertices)) + + perm[f] = axis_offsets[axis] + bary_index + + return perm + + def max_complex(complexes): max_cell = max(complexes) if all(max_cell >= b for b in complexes): diff --git a/gem/gem.py b/gem/gem.py index 5c9231dc..54fb060c 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """GEM is the intermediate language of TSFC for describing tensor-valued mathematical expressions and tensor operations. It is similar to Einstein's notation. @@ -20,6 +22,8 @@ from operator import attrgetter from numbers import Integral, Number +from types import EllipsisType + import numpy from numpy import asarray @@ -32,7 +36,7 @@ 'Variable', 'Sum', 'Product', 'Division', 'FloorDiv', 'Remainder', 'Power', 'MathFunction', 'MinValue', 'MaxValue', 'Comparison', 'LogicalNot', 'LogicalAnd', 'LogicalOr', 'Conditional', - 'Index', 'VariableIndex', 'Indexed', 'ComponentTensor', + 'Index', 'VariableIndex', 'ListIndex', 'Indexed', 'ComponentTensor', 'IndexSum', 'ListTensor', 'Concatenate', 'Delta', 'OrientationVariableIndex', 'index_sum', 'partial_indexed', 'reshape', 'view', 'indices', 'as_gem', 'FlexiblyIndexed', @@ -81,12 +85,53 @@ def is_equal(self, other): self.children = other.children return result - def __getitem__(self, indices): - try: - indices = tuple(indices) - except TypeError: - indices = (indices, ) - return Indexed(self, indices) + def __getitem__( + self, + key: IndexT | tuple[IndexT, ...], + ) -> ComponentTensor | Indexed: + """A generalised interface for indexing GEM tensors""" + if not isinstance(key, tuple): + key = (key,) + + # Expand ellipsis -> fill in remaining dimensions with slice(None) + if any(k is Ellipsis for k in key): + if key.count(Ellipsis) > 1: + raise NotImplementedError("Multiple ellipses are not supported.") + ellipsis_pos = key.index(Ellipsis) + remaining_dims = len(self.shape) - (len(key) - 1) + if remaining_dims < 0: + raise IndexError("Too many indices provided.") + key = ( + key[:ellipsis_pos] + + (slice(None), ) * remaining_dims + + key[ellipsis_pos + 1:] + ) + + has_slice = any(isinstance(k, slice) for k in key) + has_array = any(isinstance(k, (numpy.ndarray, list)) for k in key) + + if has_slice and has_array: + raise NotImplementedError("Mixed slice and array indexing is not supported.") + + # Slice indexing -> delegate to view() + if has_slice: + # view expects one slice for each axis/dim of the tensor + if len(key) != len(self.shape): + raise IndexError("Expects the number of slices to match the gem.Node tensor rank") + return view(self, *key) + + # Support a list or array of integer indices + # Old approach: build a ListTensor out of Indexed nodes, one for each element of the permutation + if has_array: + pos = next(i for i, k in enumerate(key) if isinstance(k, (numpy.ndarray, list))) + arr = numpy.asarray(key[pos]) + list_index = ListIndex(arr) + new_key = key[:pos] + (list_index,) + key[pos+1:] + indexed = Indexed(self, new_key) + return ComponentTensor(indexed, (list_index.free_index,)) # convert free index back to shape + + # Point indexing + return Indexed(self, key) def __neg__(self): return componentwise(Product, minus, self) @@ -117,6 +162,7 @@ def __matmul__(self, other): raise ValueError("Both objects must have shape for matmul") elif self.shape[-1] != other.shape[0]: raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul") + *i, k = indices(len(self.shape)) _, *j = indices(len(other.shape)) expr = Product(Indexed(self, (*i, k)), Indexed(other, (k, *j))) @@ -678,6 +724,41 @@ def __reduce__(self): return type(self), (self.expression,) +class ListIndex(IndexBase): + """A lookup index in the form of an index array""" + + __slots__ = ('index_array', 'free_index',) + + def __init__(self, index_array): + assert isinstance(index_array, numpy.ndarray) + assert numpy.issubdtype(index_array.dtype, numpy.integer) + self.index_array = index_array + self.free_index = Index(extent=len(self.index_array)) + + def __eq__(self, other): + if type(self) is not type(other): + return False + return numpy.array_equal(self.index_array, other.index_array) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash((type(self), self.index_array.tobytes())) + + def __str__(self): + return str(self.index_array) + + def __repr__(self): + return "%r(%s)" % (type(self), self.index_array) + + def __reduce__(self): + return type(self), (self.index_array, ) + + +IndexT = int | Index | VariableIndex | ListIndex | slice | EllipsisType | list | numpy.ndarray + + class Indexed(Scalar): __slots__ = ('children', 'multiindex', 'indirect_children') __back__ = ('multiindex',) @@ -693,6 +774,9 @@ def __new__(cls, aggregate, multiindex): assert isinstance(index, IndexBase) if isinstance(index, Index): index.set_extent(extent) + elif isinstance(index, ListIndex): + if numpy.any(index.index_array < 0) or numpy.any(index.index_array >= extent): + raise IndexError("Invalid index in ListIndex") elif isinstance(index, int) and not (0 <= index < extent): raise IndexError("Invalid literal index") @@ -740,6 +824,8 @@ def __new__(cls, aggregate, multiindex): new_indices.append(i) elif isinstance(i, VariableIndex): new_indices.extend(i.expression.free_indices) + elif isinstance(i, ListIndex): + new_indices.append(i.free_index) self.free_indices = unique(aggregate.free_indices + tuple(new_indices)) return self @@ -752,6 +838,8 @@ def index_ordering(self): free_indices.append(i) elif isinstance(i, VariableIndex): free_indices.extend(i.expression.free_indices) + elif isinstance(i, ListIndex): + free_indices.append(i.free_index) return tuple(free_indices) diff --git a/gem/interpreter.py b/gem/interpreter.py index 13eeb44a..d3f71f75 100644 --- a/gem/interpreter.py +++ b/gem/interpreter.py @@ -263,27 +263,89 @@ def _evaluate_conditional(e, self): def _evaluate_indexed(e, self): """Indexing maps shape to free indices""" val = self(e.children[0]) - fids = tuple(i for i in e.multiindex if isinstance(i, gem.Index)) + fids = tuple( + i if isinstance(i, gem.Index) else i.free_index + for i in e.multiindex + if isinstance(i, (gem.Index, gem.ListIndex)) + ) + all_fids = val.fids + fids idx = [] # First pick up all the existing free indices - for _ in val.fids: - idx.append(slice(None)) + for fid in val.fids: + # idx.append(slice(None)) + shape = tuple(fid.extent if f is fid else 1 for f in all_fids) + idx.append(numpy.arange(fid.extent).reshape(shape)) + # Now grab the shape axes for i in e.multiindex: if isinstance(i, gem.Index): # Free index, want entire extent - idx.append(slice(None)) + # idx.append(slice(None)) + shape = tuple(i.extent if f is i else 1 for f in all_fids) + idx.append(numpy.arange(i.extent).reshape(shape)) elif isinstance(i, gem.VariableIndex): # Variable index, evaluate inner expression result, = self(i.expression) assert not result.tshape idx.append(result[()]) + elif isinstance(i, gem.ListIndex): + # ListIndex, use the index array + shape = tuple(i.free_index.extent if f is i.free_index else 1 for f in all_fids) + idx.append(i.index_array.reshape(shape)) else: # Fixed index, just pick that value idx.append(i) assert len(idx) == len(val.tshape) - return Result(val[idx], val.fids + fids) + return Result(val[idx], all_fids) + + +@_evaluate.register(gem.FlexiblyIndexed) +def _evaluate_flexiblyindexed(e, self): + val = self(e.children[0]) + + all_fids = val.fids + e.free_indices + + idx = [] + + # Collect existing free indices + for fid in val.fids: + shape = tuple(fid.extent if f is fid else 1 for f in all_fids) + fidx_array = numpy.arange(fid.extent).reshape(shape) + idx.append(fidx_array) + + # Compute a flat index for each dim from dim2idxs + for dim, (offset, idxs) in zip(e.children[0].shape, e.dim2idxs): + if isinstance(offset, gem.Node): + offset_result = self(offset) + assert not offset_result.tshape + offset_val = offset_result[()] + else: + offset_val = offset + + dim_flat_idx_components = [] + for i, s in idxs: + # Index i may be one of: Index, VariableIndex, int + if isinstance(i, gem.Index): + # Free index contributes an array index given by the extent + shape = tuple(i.extent if f is i else 1 for f in all_fids) + i_vals = numpy.arange(i.extent).reshape(shape) * s + dim_flat_idx_components.append(i_vals) + elif isinstance(i, gem.VariableIndex): + # Variable index gives a single integer index + # obtained by evaluating its inner expression + result, = self(i.expression) + assert not result.tshape + dim_flat_idx_components.append(result[()]*s) + else: + # Fixed index is just an integer so use that + dim_flat_idx_components.append(i*s) + + # From dim_flat_idx_components compute the flat index into dim + dim_idx_array = offset_val + sum(flat_index_component for flat_index_component in dim_flat_idx_components) + idx.append(dim_idx_array) + + return Result(val[idx], all_fids) @_evaluate.register(gem.ComponentTensor) diff --git a/gem/optimise.py b/gem/optimise.py index caf254e0..d57f3ce1 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -13,7 +13,7 @@ reuse_if_untouched_arg, traversal) from gem.gem import (Node, Failure, Identity, Constant, Literal, Zero, Product, Sum, Comparison, Conditional, Division, - Index, VariableIndex, Indexed, FlexiblyIndexed, + Index, VariableIndex, ListIndex, Indexed, FlexiblyIndexed, IndexSum, ComponentTensor, ListTensor, Delta, partial_indexed, one) @@ -145,6 +145,13 @@ def replace_indices_indexed(node, self, subst): child = Literal(sub, dtype=child.dtype) if isinstance(child, Constant) else ListTensor(sub) multiindex = tuple(i for i in multiindex if not isinstance(i, Integral)) + elif isinstance(child, ListTensor) and len(multiindex) == 1 and isinstance(multiindex[0], ListIndex): + # ListIndex into a ListTensor: apply the permutation at compile time + # by reordering the ListTensor elements, eliminating the free index. + list_index = multiindex[0] + child = ListTensor(child.array[list_index.index_array]) + multiindex = (list_index.free_index,) + if multiindex == node.multiindex and child == node.children[0]: return node else: diff --git a/test/FIAT/unit/test_reference_element.py b/test/FIAT/unit/test_reference_element.py index d166ba6e..e61b6a55 100644 --- a/test/FIAT/unit/test_reference_element.py +++ b/test/FIAT/unit/test_reference_element.py @@ -455,6 +455,80 @@ def test_flatten_maintains_ufc_status(cell): assert ufc_status == is_ufc(flat_cell) +@pytest.mark.parametrize(('cell', 'point'), + [(interval_x_interval, [0.25, 0.6]), + (triangle_x_interval, [0.25, 0.25, 0.5]), + (quadrilateral_x_interval, [0.25, 0.25, 0.5])]) +def test_tp_axis_bary_coords(cell, point, epsilon=1e-12): + point = np.asarray(point) + axis_bary_coords = cell.compute_axis_barycentric_coordinates(point) + + assert type(axis_bary_coords) is np.ndarray + + offset = 0 + bary_offset = 0 + for factor in cell.cells: + sd = factor.get_spatial_dimension() + coords_k = point[offset: offset + sd] + offset += sd + expected = factor.compute_barycentric_coordinates(coords_k) + n_bary = len(expected) + assert np.allclose(axis_bary_coords[bary_offset: bary_offset + n_bary], expected, atol=epsilon) + bary_offset += n_bary + + +@pytest.mark.parametrize(('cell', 'point'), + [(quadrilateral, [0.0, 0.3]), + (quadrilateral, [1.0, 0.3]), + (quadrilateral, [0.3, 0.0]), + (quadrilateral, [0.3, 1.0]), + (hexahedron, [0.0, 0.3, 0.4]), + (hexahedron, [1.0, 0.3, 0.4]), + (hexahedron, [0.3, 0.0, 0.4]), + (hexahedron, [0.3, 1.0, 0.4]), + (hexahedron, [0.3, 0.4, 0.0]), + (hexahedron, [0.3, 0.4, 1.0]),]) +def test_hypercube_bary_coords_are_in_facet_order(cell, point, epsilon=1e-12): + point = np.asarray(point) + + facet_dim = cell.get_spatial_dimension() - 1 + point_entity_ids = cell.point_entity_ids([point]) + facet_hits = [fid for fid, pts in point_entity_ids[facet_dim].items() if len(pts) > 0] + assert len(facet_hits) == 1 + + facet_id = facet_hits[0] + bary_coords = cell.compute_barycentric_coordinates(point) + assert np.isclose(bary_coords[facet_id], 0.0, atol=epsilon) + + mask = np.ones(len(bary_coords), dtype=bool) + mask[facet_id] = False + assert np.all(bary_coords[mask] > epsilon) + + +@pytest.mark.parametrize(('cell', 'point'), + [(interval, [0.5]), + (triangle, [0.25, 0.25]), + (tetrahedron, [0.25, 0.25, 0.25]), + (quadrilateral, [0.25, 0.5]), + (hexahedron, [0.25, 0.5, 0.25]),]) +def test_bary_coords_gem(cell, point): + import gem + from gem.interpreter import evaluate + + point = np.asarray(point) + sd = cell.get_spatial_dimension() + + coords = gem.Variable('X', (sd,)) + bindings = {coords: point} + + bary_gem = cell.compute_barycentric_coordinates(coords) + results, = evaluate((gem.as_gem(bary_gem),), bindings=bindings) + results = results.arr + + bary_numpy = cell.compute_barycentric_coordinates(point) + assert np.allclose(results, bary_numpy) + + if __name__ == '__main__': import os pytest.main(os.path.abspath(__file__))