diff --git a/src/scifem/__init__.py b/src/scifem/__init__.py index c0f50da..c825ede 100644 --- a/src/scifem/__init__.py +++ b/src/scifem/__init__.py @@ -22,6 +22,7 @@ ) from .eval import evaluate_function, find_cell_extrema, compute_extrema from .interpolation import interpolation_matrix, prepare_interpolation_data +from .geometry import closest_point_projection meta = metadata("scifem") __version__ = meta["Version"] @@ -38,6 +39,7 @@ "__email__", "__program_name__", "PointSource", + "closest_point_projection", "assemble_scalar", "create_space_of_simple_functions", "compute_interface_data", diff --git a/src/scifem/geometry.py b/src/scifem/geometry.py new file mode 100644 index 0000000..7b452f3 --- /dev/null +++ b/src/scifem/geometry.py @@ -0,0 +1,221 @@ +import dolfinx +import numpy as np +import numpy.typing as npt +import warnings + + +def project_onto_simplex( + v: npt.NDArray[np.float64 | np.float32], +) -> npt.NDArray[np.float64 | np.float32]: + """ + Exact projection of vector v onto the simplex {x >= 0, sum(x) <= 1}. + + See Algorithm 1: Laurent Condat. Fast Projection onto the Simplex and the l1 Ball. + Mathematical Programming, Series A, 2016, 158 (1), pp.575-585. + ⟨DOI: 10.1007/s10107-015-0946-6⟩. + + + Args: + v: The vector to project onto simplex + + Return: + The projection of v onto the simplex. + + """ + # 1. First try the unconstrained positive quadrant projection + v = v.reshape(-1) + u = np.maximum(v, 0.0) + if np.sum(u) <= 1.0: + return u + + # 2. Otherwise, project exactly onto the slanted face sum(x) = 1 + tdim = len(v) + if tdim == 2: + sort_v = np.array([max(u[0], u[1]), min(u[0], u[1])]) + cssv = np.array([sort_v[0], sort_v[0] + sort_v[1]]) + elif tdim == 3: + sort_v = np.sort(v)[::-1] + cssv = np.array([sort_v[0], sort_v[0] + sort_v[1], sort_v[0] + sort_v[1] + sort_v[2]]) + else: + raise RuntimeError("Projection onto simplex is only implemented for 2D and 3D vectors.") + # cssv = np.cumsum(sort_v) + + # Find the primal-dual root + # sum x_i = a + # Find K: = max (k in [1,.. N] such that sum_{r=1}^k sort_v[r] - a)/k < u_k + # Multiply by k and take the last true entry to get the index rho + K = np.nonzero(sort_v * np.arange(1, tdim + 1) > (cssv - 1.0))[0][-1] + tau = (cssv[K] - 1.0) / (K + 1.0) + + return np.maximum(v - tau, 0.0) + + +def closest_point_projection( + mesh: dolfinx.mesh.Mesh, + cells: npt.NDArray[np.int32], + target_points: npt.NDArray[np.float64 | np.float32], + tol_x: float | None = None, + tol_dist: float = 1e-10, + tol_grad: float = 1e-10, + max_iter: int = 2000, + max_ls_iter: int = 250, +) -> tuple[npt.NDArray[np.float64 | np.float32], npt.NDArray[np.float64 | np.float32]]: + """ + Projects a 3D point onto a cell in a potentially higher order mesh. + + Uses the Goldstein-Levitin-Polyak Gradient projection method, where + potential simplex constraints are handled by an exact projection using a + primal-dual root finding method. See: + - Held, M., Wolfe, P., Crowder, H.: Validation of subgradient optimization (1974) + - Laurent Condat. Fast Projection onto the Simplex and the l1 Ball. (2016) + - Dimitri P. Bertsekas, "On the Goldstein-Levitin-Polyak gradient projection method," (1976) + + Args: + mesh: {py:class}`dolfinx.mesh.Mesh`, the mesh containing the cell. + cells: {py:class}`numpy.ndarray`, the local indices of the cells to project onto. + target_point: (3,) numpy array, the 3D point to project. + tol_x: Tolerance for changes between iterates in the reference coordinates. + If None, uses the square root of machine precision. + tol_dist: Tolerance used to determine if the projected point is close enough to + the target point to stop optimization. + max_iter: int, the maximum number of iterations for the projected gradient method. + max_ls_iter: int, the maximum number of line search iterations. + + Returns: + A tuple of arrays containing the closest points (in physical space) + and reference coordinates for each cell to each target point. + """ + dtype = mesh.geometry.x.dtype + eps = np.finfo(dtype).eps + tol_x = np.sqrt(eps) if tol_x is None else tol_x + roundoff_tol = 100 * eps + + # Extract scalar element of mesh + element = mesh.ufl_domain().ufl_coordinate_element().sub_elements[0] + tdim = mesh.topology.dim + # Get the coordinates of the nodes for the specified cell + node_coords = mesh.geometry.x[mesh.geometry.dofmap[cells]][:, :, : mesh.geometry.dim] + target_points = target_points.reshape(-1, 3) + # cmap = mesh.geometry.cmap + + # Constraints and Bounds + cell_type = mesh.topology.cell_type + + # Set initial guess and tolerance for solver + initial_guess = np.full(mesh.topology.dim, 1 / (mesh.topology.dim + 1), dtype=dtype) + closest_points = np.zeros((target_points.shape[0], 3), dtype=dtype) + reference_points = np.zeros((target_points.shape[0], mesh.topology.dim), dtype=dtype) + is_simplex = cell_type in [ + dolfinx.mesh.CellType.triangle, + dolfinx.mesh.CellType.tetrahedron, + ] + + if is_simplex: + + def project(x): + return project_onto_simplex(x) + else: + + def project(x): + return np.clip(x, 0.0, 1.0) + + for i, (coord, target_point) in enumerate(zip(node_coords, target_points)): + coord = coord.reshape(-1, mesh.geometry.dim) + x_k = initial_guess.copy() + + for k in range(max_iter): + x_old = x_k.copy() + + # Evaluate basis functions and first order derivatives at current reference point + tab = element.tabulate(1, x_k.reshape(1, tdim)) + + # Push forward to physical space + surface_point = np.dot(tab[0, 0, :], coord) + + # Compute current objective function + diff = surface_point - target_point + current_dist_sq = 0.5 * np.linalg.norm(diff) ** 2 + + # Compute the gradient in reference coordinates using the Jacobian (tangent vectors) + tangents = np.dot(tab[1 : tdim + 1, 0, :], coord) + g = np.dot(tangents, diff) + + # Check for convergence in gradient norm, scaled by the Jacobian to account + # for stretching of the reference space + jac_norm = np.linalg.norm(tangents) + scaled_tol_grad = tol_grad * max(jac_norm, 1.0) + if np.linalg.norm(g) < scaled_tol_grad: + break + + # 3. Goldstein-Polyak-Levitin Projected Line Search + # Bertsekas (1976) Eq. (14) - Armijo Rule along the Projection Arc + sigma = 0.1 # Sufficient decrease parameter (0 < sigma < 0.5) + beta = 0.5 # Reduction factor (0 < beta < 1) + alpha = 1.0 # Initial step size + + x_new_prev = np.full(tdim, -1, dtype=dtype) + target_reached = False + for li in range(max_ls_iter): + # Apply the exact analytical simplex projection + x_new = project(x_k - alpha * g) + + if np.linalg.norm(x_new - x_new_prev) < eps: + # The projection is pinned to a boundary. + # Changing alpha further will not change the physical point! + break + x_new_prev = x_new.copy() + + # The actual physical step we took after hitting the geometric walls + actual_step = x_new - x_k + + # Evaluate distance at the projected point + tab_new = element.tabulate(0, x_new.reshape(1, tdim)) + S_new = np.dot(tab_new[0, 0, :], coord) + new_dist_sq = 0.5 * np.linalg.norm(S_new - target_point) ** 2 + if new_dist_sq < 0.5 * tol_dist**2: + # We are close enough to the target point, no need for further line search + target_reached = True + break + + # Bertsekas Eq. (14) condition: + # f(x_new) <= f(x_k) + sigma * grad_f(x_k)^T * (x_new - x_k) + # Note: g is grad_f(x_k) + if new_dist_sq <= current_dist_sq + sigma * np.dot(g, actual_step) + roundoff_tol: + # Condition satisfied + break + + # Reduction step (Backtracking) + alpha *= beta + + if li == max_ls_iter - 1: + warnings.warn( + f"Line search failed to converge after {max_ls_iter} iterations " + + f"for cell {cells[i]} and {target_point=}." + ) + x_k[:] = x_new + + if target_reached: + break + + # 4. Check for convergence + if np.linalg.norm(x_k - x_old) < tol_x: + break + if new_dist_sq < 0.5 * tol_dist**2: + print("Projected point is within tolerance of target point, stopping optimization.") + break + + assert np.allclose(project(x_k), x_k), "Projection failed to satisfy constraints" + + # Final coordinate extraction + tab_final = element.tabulate(0, x_k.reshape(1, tdim)) + closest_points[i] = np.dot(tab_final[0, 0, :], coord) + reference_points[i] = x_k + + if k == max_iter - 1: + raise RuntimeError( + f"Projected gradient method failed to converge after {max_iter} iterations ", + f"for cell {cells[i]} and {target_point=} and final iterate {x_k=} ", + f"and final point {closest_points[i]} with final distance ", + f"{np.linalg.norm(closest_points[i] - target_point)}.", + ) + return closest_points, reference_points diff --git a/tests/test_geometry.py b/tests/test_geometry.py new file mode 100644 index 0000000..e805181 --- /dev/null +++ b/tests/test_geometry.py @@ -0,0 +1,197 @@ +from mpi4py import MPI +import dolfinx +import numpy as np +import numpy.typing as npt +from scipy.optimize import minimize +from scifem import closest_point_projection +from scifem.geometry import project_onto_simplex +import ufl +import basix.ufl +import pytest + + +def scipy_project_point_to_element( + mesh: dolfinx.mesh.Mesh, + cells: npt.NDArray[np.int64], + target_points: npt.NDArray[np.float64 | np.float32], + method=None, + tol: float | None = None, +): + """ + Projects a 3D point onto a cell in a potentially higher order mesh. + + Args: + mesh: {py:class}`dolfinx.mesh.Mesh`, the mesh containing the cell. + cells: {py:class}`numpy.ndarray`, the indices of the cells to project onto. + target_point: (3,) numpy array, the 3D point to project. + method: str, the optimization method to use. + tol: float, the tolerance for the optimizer. + + Returns: + dict: A dictionary containing the reference coordinates, closest 3D point, and distance. + """ + + # Extract scalar element of mesh + element = mesh.ufl_domain().ufl_coordinate_element().sub_elements[0] + + # Get the coordinates of the nodes for the specified cell + node_coords = mesh.geometry.x[mesh.geometry.dofmap[cells]][:, :, : mesh.geometry.dim] + target_points = target_points.reshape(-1, 3) + # cmap = mesh.geometry.cmap + + # Constraints and Bounds + cell_type = mesh.topology.cell_type + if ( + cell_type == dolfinx.mesh.CellType.triangle + or cell_type == dolfinx.mesh.CellType.tetrahedron + ): + method = method or "SLSQP" + constraint = {"type": "ineq", "fun": lambda x: 1.0 - np.sum(x)} + else: + method = method or "L-BFGS-B" + constraint = {} + bounds = [(0.0, 1.0) for _ in range(mesh.topology.dim)] + + # Set initial guess and tolerance for solver + initial_guess = np.full(mesh.topology.dim, 1 / (mesh.topology.dim + 1), dtype=np.float64) + tol = np.sqrt(np.finfo(mesh.geometry.x.dtype).eps) if tol is None else tol + closest_points = np.zeros((target_points.shape[0], 3), dtype=mesh.geometry.x.dtype) + for i, (coord, target_point) in enumerate(zip(node_coords, target_points)): + coord = coord.reshape(-1, mesh.geometry.dim) + + def S(x_ref): + N_vals = element.tabulate(0, x_ref.reshape(1, mesh.topology.dim))[0, 0, :] + return np.dot(N_vals, coord) + + def dSdx_ref(x_ref): + """Evaluate jacobian (tangent vectors) at the given reference coordinates.""" + dN = element.tabulate(1, x_ref.reshape(1, mesh.topology.dim))[ + 1 : mesh.topology.dim + 1, 0, : + ] + return np.dot(dN, coord) + + def objective(x_ref): + surface_point = S(x_ref) + diff = surface_point - target_point + return 0.5 * np.linalg.norm(diff) ** 2 + + def objective_grad(x_ref): + diff = S(x_ref) - target_point + tangents = dSdx_ref(x_ref) + return np.dot(tangents, diff) + + res = minimize( + objective, + initial_guess, + method=method, + jac=objective_grad, + bounds=bounds, + constraints=constraint, + tol=tol, + options={"disp": False, "ftol": tol, "maxiter": 250}, + ) + closest_points[i] = S(res.x) + assert res.success, f"Optimization failed for {cells[i]} and {target_point=}: {res.message}" + return closest_points, res.x + + +@pytest.mark.parametrize("order", [1, 2]) +def test_2D_manifold(order): + comm = MPI.COMM_SELF + + # Curved quadratic triangle in 3D (6 nodes) + curved_nodes = np.array( + [ + [0.0, 0.0, 0.0], # Node 0: Vertex + [1.0, 0.0, 0.0], # Node 1: Vertex + [0.0, 1.0, 0.0], # Node 2: Vertex + [0.6, 0.6, 0.0], # Node 4: Edge 1-2 + [0.1, 0.5, 0.2], # Node 5: Edge 2-0 (curved upward in Z) + [0.5, 0.1, 0.2], # Node 3: Edge 0-1 (curved upward in Z) + ] + ) + cells = np.array([[0, 1, 2, 3, 4, 5]], dtype=np.int64) # Single curved triangle element + c_el = ufl.Mesh(basix.ufl.element("Lagrange", "triangle", order, shape=(3,))) + if order == 1: + curved_nodes = curved_nodes[:3] # Use only vertices for linear case + cells = cells[:, :3] + mesh = dolfinx.mesh.create_mesh(comm, cells=cells, x=curved_nodes, e=c_el) + + tol = 1e-7 + tol_dist = 1e-7 + theta = np.linspace(0, 4 * np.pi, 250) + rand = np.random.RandomState(42) + R = rand.rand(len(theta)) + z = rand.rand(len(theta)) * 0.5 # Add some random z variation + points = np.vstack([R * np.cos(theta), R * np.sin(theta), z]).T + + for point_to_project in points: + (result_scipy, ref_scipy) = scipy_project_point_to_element( + mesh, np.array([0], dtype=np.int32), point_to_project, tol=tol + ) + + result, ref_coords = closest_point_projection( + mesh, np.array([0], dtype=np.int32), point_to_project, tol_x=tol, tol_dist=tol_dist + ) + # Check that we are within the bounds of the simplex + ref_proj = project_onto_simplex(ref_coords[0]) + np.testing.assert_allclose(ref_proj, ref_coords[0]) + + dist_scipy = 0.5 * np.sum(result_scipy - point_to_project) ** 2 + dist_ours = 0.5 * np.sum(result - point_to_project) ** 2 + if not np.isclose(dist_ours, dist_scipy, atol=tol_dist, rtol=1e-2): + assert np.linalg.norm(ref_coords - ref_scipy) < 1e-2 + else: + assert np.isclose(dist_ours, dist_scipy, atol=tol, rtol=1e-2) + + +@pytest.mark.parametrize("order", [1, 2]) +def test_3D_curved_cell(order): + comm = MPI.COMM_SELF + + curved_nodes_tet = np.array( + [ + [0.0, 0.0, 0.0], # 0: Vertex + [1.0, 0.0, 0.0], # 1: Vertex + [0.0, 1.0, 0.0], # 2: Vertex + [0.0, 0.0, 1.0], # 3: Vertex + [0.0, 0.5, 0.5], # 4: Edge 2-3 + [0.5, 0.0, 0.5], # 5: Edge 1-3 + [0.5, 0.5, 0.0], # 6: Edge 1-2 + [0.0, 0.0, 0.5], # 9: Edge 0-3 + [0.0, 0.5, 0.2], # 8: Edge 0-2 (Curved) + [0.5, 0.0, 0.2], # 7: Edge 0-1 (Curved) + ], + dtype=np.float64, + ) + cells_tet = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=np.int64) + + domain_tet = ufl.Mesh(basix.ufl.element("Lagrange", "tetrahedron", order, shape=(3,))) + if order == 1: + curved_nodes_tet = curved_nodes_tet[:4] # Use only vertices for linear case + cells_tet = cells_tet[:, :4] + mesh = dolfinx.mesh.create_mesh(comm, cells=cells_tet, x=curved_nodes_tet, e=domain_tet) + + rand = np.random.RandomState(32) + points = rand.rand(100, 3) - 0.5 * rand.rand(100, 3) + tol = 1e-7 + tol_dist = 1e-7 + + for point_to_project in points: + result_scipy, ref_scipy = scipy_project_point_to_element( + mesh, np.array([0], dtype=np.int32), point_to_project, tol=tol + ) + + result, ref_coords = closest_point_projection( + mesh, np.array([0], dtype=np.int32), point_to_project, tol_x=tol, tol_dist=tol_dist + ) + # Check that we are within the bounds of the simplex + ref_proj = project_onto_simplex(ref_coords[0]) + np.testing.assert_allclose(ref_proj, ref_coords[0]) + + dist_scipy = 0.5 * np.sum(result_scipy - point_to_project) ** 2 + dist_ours = 0.5 * np.sum(result - point_to_project) ** 2 + if not np.isclose(dist_ours, dist_scipy, atol=tol_dist, rtol=1e-2): + assert np.linalg.norm(ref_coords - ref_scipy) < 1e-2 + else: + assert np.isclose(dist_ours, dist_scipy, atol=tol, rtol=1e-2)