Skip to content

Closest point projection #198

@jorgensd

Description

@jorgensd

Would be neat to have. The example below uses SciPy; I will rewrite it as a hand-written Newton solver.

import numpy as np
from mpi4py import MPI
import ufl
import basix.ufl
import dolfinx
from scipy.optimize import minimize
import numpy.typing as npt


def project_point_to_element(
    mesh: dolfinx.mesh.Mesh,
    cells: npt.NDArray[np.int64],
    target_points: npt.NDArray[np.float64 | np.float32],
    method=None,
    tol: float | None = None,
):
    """
    Project points onto cells of a (potentially higher order) mesh.

    For every (cell, point) pair this minimises ``0.5 * |x(X) - p|^2`` over the
    reference coordinates ``X`` of the reference cell. Here ``x(X)`` is the
    push-forward of ``X`` through the scalar coordinate element. The function
    returns ``x(X)`` at the optimum. For curved cells the optimiser may return
    a local rather than a global minimiser, depending on the initial guess.

    Args:
        mesh: {py:class}`dolfinx.mesh.Mesh`, the mesh containing the cells.
        cells: {py:class}`numpy.ndarray`, the indices of the cells to project onto.
        target_points: The points to project, reshaped to ``(-1, 3)``. Pass one
            of the following:
            - one point per cell;
            - a single point, which is projected onto every cell;
            - any number of points together with a single cell.
        method: str, the {py:func}`scipy.optimize.minimize` method. Defaults to
            ``"SLSQP"`` for triangles/tetrahedra, which need the constraint
            ``sum(X) <= 1``, and to ``"L-BFGS-B"`` otherwise.
        tol: float, the tolerance for the optimizer. Defaults to the square root
            of the machine epsilon of the mesh geometry dtype.

    Returns:
        A ``(num_pairs, 3)`` array holding the closest point found on each
        cell. Components beyond the geometric dimension are zero.

    Raises:
        ValueError: If the number of cells and points cannot be paired up.
        RuntimeError: If the optimizer reports failure for a pair.
    """
    gdim = mesh.geometry.dim
    tdim = mesh.topology.dim

    # Scalar element of the (blocked) coordinate element. Its basis functions
    # weight the node coordinates of a cell.
    element = mesh.ufl_domain().ufl_coordinate_element().sub_elements[0]

    cells = np.asarray(cells, dtype=np.int64).reshape(-1)
    target_points = np.asarray(target_points).reshape(-1, 3)

    # Pair up cells and points explicitly, so a length mismatch is never
    # silently truncated (which would leave unmatched output rows as zeros).
    num_cells, num_points = cells.shape[0], target_points.shape[0]
    if num_cells != num_points:
        if num_points == 1:
            target_points = np.repeat(target_points, num_cells, axis=0)
        elif num_cells == 1:
            cells = np.repeat(cells, num_points)
        else:
            raise ValueError(
                f"Cannot pair {num_cells} cells with {num_points} target points."
            )

    # Geometry node coordinates of each paired cell, restricted to gdim components.
    node_coords = mesh.geometry.x[mesh.geometry.dofmap[cells]][:, :, :gdim]

    cell_type = mesh.topology.cell_type
    if cell_type in (dolfinx.mesh.CellType.triangle, dolfinx.mesh.CellType.tetrahedron):
        # Reference simplex: X_i >= 0 is enforced via the bounds below, and
        # sum(X) <= 1 is enforced as an inequality constraint.
        method = method or "SLSQP"
        constraints = (
            {
                "type": "ineq",
                "fun": lambda x: 1.0 - np.sum(x),
                "jac": lambda x: -np.ones_like(x),
            },
        )
    else:
        # Interval/quadrilateral/hexahedron: the unit box is fully described by
        # the bounds. NOTE(review): other cell types (e.g. prisms) are also
        # treated as boxes here.
        method = method or "L-BFGS-B"
        constraints = ()
    bounds = [(0.0, 1.0)] * tdim

    # Start at the reference centroid for simplices (same guess used for boxes).
    initial_guess = np.full(tdim, 1 / (tdim + 1), dtype=np.float64)
    tol = np.sqrt(np.finfo(mesh.geometry.x.dtype).eps) if tol is None else tol

    def objective_and_grad(x_ref, coord, target):
        """Return ``0.5*|x(X) - p|^2`` and its gradient with respect to ``X``."""
        # One tabulation gives both the basis values (row 0) and the first
        # derivatives (rows 1..tdim).
        tab = element.tabulate(1, x_ref.reshape(1, tdim))
        diff = tab[0, 0, :] @ coord - target
        # Row j of `tangents` is the tangent vector dx/dX_j.
        tangents = tab[1 : tdim + 1, 0, :] @ coord
        return 0.5 * float(diff @ diff), tangents @ diff

    closest_points = np.zeros((cells.shape[0], 3), dtype=mesh.geometry.x.dtype)
    for i, (coord, target_point) in enumerate(zip(node_coords, target_points)):
        # Compare in the geometric dimension so that meshes with gdim < 3 work.
        target = target_point[:gdim]
        res = minimize(
            objective_and_grad,
            initial_guess,
            args=(coord, target),
            method=method,
            jac=True,
            bounds=bounds,
            constraints=constraints,
            tol=tol,
        )
        if not res.success:
            raise RuntimeError(
                f"Optimization failed for {cells[i]} and {target_point=}: {res.message}"
            )
        basis = element.tabulate(0, res.x.reshape(1, tdim))[0, 0, :]
        closest_points[i, :gdim] = basis @ coord
    return closest_points


# --- Example Usage ---
if __name__ == "__main__":
    import time

    import dolfinx.io  # VTXWriter lives in the io submodule

    def _project_and_write(msh, point, filename):
        """Project `point` onto cell 0 of `msh`, print the result and timing, then write `msh` to VTX."""
        start = time.perf_counter()
        result = project_point_to_element(msh, np.array([0], dtype=np.int32), point)
        end = time.perf_counter()
        print("Closest Point on Element:", result, f"Time taken: {end - start:.6e} seconds")

        with dolfinx.io.VTXWriter(msh.comm, filename, msh) as writer:
            writer.write(0.0)

    # Curved quadratic triangle in 3D (6 nodes, basix P2 ordering).
    curved_nodes = np.array(
        [
            [0.0, 0.0, 0.0],  # 0: vertex
            [1.0, 0.0, 0.0],  # 1: vertex
            [0.0, 1.0, 0.0],  # 2: vertex
            [0.6, 0.6, 0.0],  # 3: edge 1-2
            [0.1, 0.5, 0.2],  # 4: edge 0-2 (curved upward in z)
            [0.5, 0.1, 0.2],  # 5: edge 0-1 (curved upward in z)
        ]
    )
    cells = np.array([[0, 1, 2, 3, 4, 5]], dtype=np.int64)  # single curved triangle
    c_el = ufl.Mesh(basix.ufl.element("Lagrange", "triangle", 2, shape=(3,)))
    mesh = dolfinx.mesh.create_mesh(MPI.COMM_WORLD, cells=cells, x=curved_nodes, e=c_el)

    # z = -0.2: the point lies below the element (whose edges bulge up to z = 0.2).
    point_to_project = np.array([0.3, 0.3, -0.2])
    _project_and_write(mesh, point_to_project, "closest_point.bp")

    # Curved quadratic tetrahedron (10 nodes, basix P2 ordering).
    curved_nodes_tet = np.array(
        [
            [0.0, 0.0, 0.0],  # 0: vertex
            [1.0, 0.0, 0.0],  # 1: vertex
            [0.0, 1.0, 0.0],  # 2: vertex
            [0.0, 0.0, 1.0],  # 3: vertex
            [0.0, 0.5, 0.5],  # 4: edge 2-3
            [0.5, 0.0, 0.5],  # 5: edge 1-3
            [0.5, 0.5, 0.0],  # 6: edge 1-2
            [0.0, 0.0, 0.5],  # 7: edge 0-3
            [0.0, 0.5, 0.2],  # 8: edge 0-2 (curved)
            [0.5, 0.0, 0.2],  # 9: edge 0-1 (curved)
        ],
        dtype=np.float64,
    )
    cells_tet = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=np.int64)

    domain_tet = ufl.Mesh(basix.ufl.element("Lagrange", "tetrahedron", 2, shape=(3,)))
    mesh_tet = dolfinx.mesh.create_mesh(
        MPI.COMM_WORLD, cells=cells_tet, x=curved_nodes_tet, e=domain_tet
    )
    _project_and_write(mesh_tet, point_to_project, "closest_point_tet.bp")

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions