From b98cefe3b61624661bd7718474a80e82635542d7 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 7 Apr 2026 22:44:08 +0200 Subject: [PATCH 01/14] Initial implementation for RBF interpolation for adaptivity --- micro_manager/adaptivity/adaptivity.py | 97 ++ micro_manager/adaptivity/global_adaptivity.py | 5 +- micro_manager/adaptivity/local_adaptivity.py | 5 +- micro_manager/adaptivity/model_adaptivity.py | 3 +- micro_manager/config.py | 16 + micro_manager/interpolation.py | 977 ++++++++++++++++++ micro_manager/micro_manager.py | 24 +- 7 files changed, 1115 insertions(+), 12 deletions(-) diff --git a/micro_manager/adaptivity/adaptivity.py b/micro_manager/adaptivity/adaptivity.py index edfb8897..98927414 100644 --- a/micro_manager/adaptivity/adaptivity.py +++ b/micro_manager/adaptivity/adaptivity.py @@ -58,6 +58,43 @@ def __init__( self._max_similarity_dist = 0.0 + self._interpolation = None + self._mappings = [] + self._mapping_configs = [] + mappings = configurator.get_adaptivity_mapping_configs() + for mapping in mappings: + src_fields = mapping['src_fields'] + dst_fields = mapping['dst_fields'] + n_neighbors = mapping['n_neighbors'] + + self._mappings.append((src_fields, dst_fields)) + config = {} + if 'use_pu' in mapping['rbf_config']: + config['use_pu'] = mapping['rbf_config']['use_pu'] + if 'pu_overlap' in mapping['rbf_config']: + config['pu_overlap'] = mapping['rbf_config']['pu_overlap'] + config['pu_cluster_size'] = n_neighbors + if 'basis' in mapping['rbf_config']: + if 'type' in mapping['rbf_config']['basis']: + config['basis'] = mapping['rbf_config']['basis']['type'] + if config['basis'] == 'gauss' and 'eps' in mapping['rbf_config']['basis']: + config['gauss_eps'] = mapping['rbf_config']['basis']['eps'] + + dom_config = {} + dom_config['n_neighbors'] = n_neighbors + if 'max_filling' in mapping['domain_config']: + dom_config['max_filling'] = mapping['domain_config']['max_filling'] + if 'coarsening_factor' in mapping['domain_config']: + dom_config['coarsening_factor'] = mapping['domain_config']['coarsening_factor'] + if 'projection' in mapping['domain_config']: + if 'type' in mapping['domain_config']['projection']: + dom_config['projection_type'] = mapping['domain_config']['projection']['type'] + if 'target_dims' in mapping['domain_config']['projection']: + dom_config['projection_std_dims'] = mapping['domain_config']['projection']['target_dims'] + + config['domain_config'] = dom_config + self._mapping_configs.append(config) + # is_sim_active: 1D array having state (active or inactive) of each micro simulation # Start adaptivity calculation with all sims active # This array is modified in place via the function update_active_sims and update_inactive_sims @@ -199,6 +236,66 @@ def _check_for_deactivation(self, active_id: int, active_ids: list) -> bool: return True return False + def _interpolate_output(self, micro_input, micro_sims_output): + # now all outputs are init to representative + + # We need RBF interp here + # will treat function f1 ... fN described in the config + # fi: X -> Y, X and Y must be subsets of the coupled fields + + # mapping list[tuple[list[str], list[str]]] = list[mapping=tuple[src args, dst args]] + # will aggregate args + # every output may only be used once as interpolation target + targets = [] + for _, target_args in self._mappings: + targets.extend(target_args) + assert len(targets) == len(set(targets)) + + # precompute arg sizes + active_lids = self.get_active_sim_local_ids() + inactive_lids = self.get_inactive_sim_local_ids() + arg_sizes = {} + for name, value in micro_input[active_lids[0]].items(): + arg_sizes[name] = 1 if type(value) != np.ndarray and type(value) != list else len(value) + for name, value in micro_sims_output[active_lids[0]].items(): + arg_sizes[name] = 1 if type(value) != np.ndarray and type(value) != list else len(value) + + # compute interpolation + n_points = len(active_lids) + n_points_inactive = len(inactive_lids) + for m_idx, fun in enumerate(self._mappings): + src_args, dst_args = fun + src_size = np.array([arg_sizes[name] for name in src_args]).sum() + dst_size = np.array([arg_sizes[name] for name in dst_args]).sum() + input_data = np.zeros((n_points, src_size)) + output_data = np.zeros((n_points, dst_size)) + for idx, lid in enumerate(active_lids): + offset = 0 + for src_arg in src_args: + input_data[idx, offset:offset + arg_sizes[src_arg]] = micro_input[lid][src_arg] + offset += arg_sizes[src_arg] + offset = 0 + for dst_arg in dst_args: + output_data[idx, offset:offset + arg_sizes[dst_arg]] = micro_sims_output[lid][dst_arg] + offset += arg_sizes[dst_arg] + input_data_inactive = np.zeros((n_points_inactive, src_size)) + for idx, lid in enumerate(inactive_lids): + offset = 0 + for src_arg in src_args: + input_data_inactive[idx, offset:offset + arg_sizes[src_arg]] = micro_input[lid][src_arg] + offset += arg_sizes[src_arg] + + # use interpolant + self._interpolation.configure(self._mappings[m_idx]) + self._interpolation.set_local_data(input_data, input_data_inactive, output_data) + output_data_inactive = self._interpolation.interpolate() + + for idx, lid in enumerate(inactive_lids): + offset = 0 + for dst_arg in dst_args: + micro_sims_output[lid][dst_arg] = output_data_inactive[idx, offset:offset + arg_sizes[dst_arg]] + offset += arg_sizes[dst_arg] + def _get_similarity_measure( self, similarity_measure: str ) -> Callable[[np.ndarray], np.ndarray]: diff --git a/micro_manager/adaptivity/global_adaptivity.py b/micro_manager/adaptivity/global_adaptivity.py index 0657443d..bc85a53d 100644 --- a/micro_manager/adaptivity/global_adaptivity.py +++ b/micro_manager/adaptivity/global_adaptivity.py @@ -15,6 +15,7 @@ from micro_manager.tools.logging_wrapper import Logger from micro_manager.micro_simulation import MicroSimulationClass from micro_manager.model_manager import ModelManager +from micro_manager.interpolation import RBF_PU from micro_manager.tools.p2p import p2p_comm, get_ranks_of_sims @@ -68,6 +69,7 @@ def __init__( self._global_ids = global_ids self._comm = comm + self._interpolation = RBF_PU(configurator, base_logger, comm, self._rank, self._comm.Get_size()) rank_of_sim = get_ranks_of_sims(global_ids, rank, comm, global_number_of_sims) self._is_sim_on_this_rank = [False] * global_number_of_sims # DECLARATION @@ -240,7 +242,7 @@ def get_inactive_sim_global_ids(self) -> np.ndarray: return np.array(inactive_sim_ids) - def get_full_field_micro_output(self, micro_output: list) -> list: + def get_full_field_micro_output(self, micro_input: list, micro_output: list) -> list: """ Get the full field micro output from active simulations to inactive simulations. @@ -260,6 +262,7 @@ def get_full_field_micro_output(self, micro_output: list) -> list: micro_sims_output = deepcopy(micro_output) self._communicate_micro_output(micro_sims_output) + self._interpolate_output(micro_input, micro_sims_output) self._precice_participant.stop_last_profiling_section() diff --git a/micro_manager/adaptivity/local_adaptivity.py b/micro_manager/adaptivity/local_adaptivity.py index bbd9d614..d1cf8425 100644 --- a/micro_manager/adaptivity/local_adaptivity.py +++ b/micro_manager/adaptivity/local_adaptivity.py @@ -12,6 +12,7 @@ from micro_manager.micro_simulation import MicroSimulationClass from micro_manager.tools.logging_wrapper import Logger from micro_manager.model_manager import ModelManager +from micro_manager.interpolation import RBF_PU class LocalAdaptivityCalculator(AdaptivityCalculator): @@ -49,6 +50,7 @@ def __init__( configurator, num_sims, micro_problem_cls, model_manager, base_logger, rank ) self._comm = comm + self._interpolation = RBF_PU(configurator, base_logger, MPI.COMM_SELF, MPI.COMM_SELF.Get_rank(), MPI.COMM_SELF.Get_size()) # similarity_dists: 2D array having similarity distances between each micro simulation pair # This matrix is modified in place via the function update_similarity_dists @@ -146,7 +148,7 @@ def get_inactive_sim_global_ids(self) -> np.ndarray: inactive_sim_ids = self.get_inactive_sim_local_ids() return inactive_sim_ids - def get_full_field_micro_output(self, micro_output: list) -> list: + def get_full_field_micro_output(self, micro_input: list, micro_output: list) -> list: """ Get the full field micro output from active simulations to inactive simulations. @@ -168,6 +170,7 @@ def get_full_field_micro_output(self, micro_output: list) -> list: micro_sims_output[inactive_id] = deepcopy( micro_sims_output[self._sim_is_associated_to[inactive_id]] ) + self._interpolate_output(micro_input, micro_sims_output) return micro_sims_output diff --git a/micro_manager/adaptivity/model_adaptivity.py b/micro_manager/adaptivity/model_adaptivity.py index 304a67dc..bb50dca3 100644 --- a/micro_manager/adaptivity/model_adaptivity.py +++ b/micro_manager/adaptivity/model_adaptivity.py @@ -453,6 +453,7 @@ def _create_active_mask(self, active_sim_ids: list, size: int) -> np.ndarray: active_sims = np.ones(size) else: mask = np.zeros(size) - mask[active_sim_ids] = 1 + if len(active_sim_ids) > 0: + mask[active_sim_ids] = 1 active_sims = mask return active_sims.astype(bool) diff --git a/micro_manager/config.py b/micro_manager/config.py index 7d6b634b..577633a1 100644 --- a/micro_manager/config.py +++ b/micro_manager/config.py @@ -50,6 +50,7 @@ def __init__(self, config_file_name): self._adaptivity = False self._adaptivity_type = "" self._data_for_adaptivity = dict() + self._adaptivity_mappings = [] self._adaptivity_n = 1 self._adaptivity_history_param = 0.5 self._adaptivity_coarsening_constant = 0.5 @@ -328,6 +329,9 @@ def read_json_micro_manager(self): self._logger.log_info_rank_zero("Adaptivity type: " + self._adaptivity_type) + if self._adaptivity_type == "global": + self._adaptivity_mappings = self._data["simulation_params"]["adaptivity_settings"]["mappings"] + if self._data["simulation_params"]["adaptivity_settings"].get( "lazy_initialization" ): @@ -566,6 +570,7 @@ def read_json_micro_manager(self): ) if self._m_adap: + self._write_data_names.append("model_resolution") self._m_adap_micro_file_names = [ name.replace("/", ".").replace("\\", ".").replace(".py", "") for name in self._data["simulation_params"][ @@ -849,6 +854,17 @@ def get_adaptivity_type(self): """ return self._adaptivity_type + def get_adaptivity_mapping_configs(self): + """ + Get the mapping configurations for the adaptivity interpolation scheme. + + Returns + ------- + adaptivity_mapping_configs : list + List of adaptivity mapping configurations. + """ + return self._adaptivity_mappings + def get_data_for_adaptivity(self): """ Get names of data to be used for similarity distance calculation in adaptivity diff --git a/micro_manager/interpolation.py b/micro_manager/interpolation.py index ee93a894..3b1269c3 100644 --- a/micro_manager/interpolation.py +++ b/micro_manager/interpolation.py @@ -1,6 +1,20 @@ +from abc import ABC, abstractmethod +from copy import deepcopy +from enum import Enum +from functools import partial +from typing import Optional +import sys + +from mpi4py import MPI import numpy as np from sklearn.neighbors import NearestNeighbors +from micro_manager.tools.p2p import create_tag + +# handle compat issue between np version 1 and 2 +if int(np.version.version.split(".")[0]) > 1: + np.alltrue = np.all + class Interpolation: def __init__(self, logger): @@ -81,3 +95,966 @@ def interpolate(self, neighbors: np.ndarray, point: np.ndarray, values): summed_weights += 1 / norm return interpol_val / summed_weights + +class NDtree: + class Mode(Enum): + DISCRETIZE = 0 + INDEX = 1 + + class Node: + def __init__( + self, + mode: "NDtree.Mode", + low: np.ndarray, + high: np.ndarray, + max_depth: int, + max_filling: int, + is_bound: np.ndarray, + ): + """ + Constructs an NDtree node. + + Parameters + ---------- + low : np.ndarray + Lower bound of the node. + high : np.ndarray + Upper bound of the node. + max_depth : int + Remaining maximum depth of the node. + rtol : float + Maximum Error of points to node center + is_bound : np.ndarray + Boolean indicating whether the node is on the boundary. + """ + self._mode : NDtree.Mode = mode + self.low = low + self.high = high + self.max_depth = max_depth + self.max_filling = max_filling + self.is_bound = is_bound + self.children: Optional[list[NDtree.Node]] = None + self.data = [] + self.data_reserve_count = 0 + + @property + def dim(self) -> int: + return self.low.shape[0] + + @property + def num_max_split(self) -> int: + return 2**self.dim + + @property + def filling(self) -> int: + return len(self.data) + + def clear(self): + self.data.clear() + self.data_reserve_count = 0 + + if self.children is None: + return + for node in self.children: + node.clear() + + def propagate_up_reserve_counts(self): + if self.children is None: + return self.data_reserve_count + + for node in self.children: + self.data_reserve_count += node.propagate_up_reserve_counts() + + return self.data_reserve_count + + def find_min_depth_for_n_neighbors(self, n: int, depth: int, p): + if self.data_reserve_count < n: return None + if self.children is None: return None + if not self.is_within(p): return None + + tmp = [node.find_min_depth_for_n_neighbors(n, depth+1, p) for node in self.children] + depths = [] + for d in tmp: + if d is None: continue + depths.append(d) + if len(depths) == 0: return depth + + min_depth = min(depths) + return min_depth + + def get_filled_coords(self, bin_low, bin_high): + assert self._mode == NDtree.Mode.DISCRETIZE + + if self.children is None: + if self.data_reserve_count == 0: + return [] + assert np.allclose(bin_high - bin_low, 1) + return [bin_low] * self.data_reserve_count + + buffer = [] + for i in range(self.num_max_split): + mask = self._idx2mask(i) + inv_mask = np.ones_like(mask) - mask + delta_bin_half = ((bin_high - bin_low) / 2).astype(bin_low.dtype) + bin_low_i = bin_low + mask * delta_bin_half + bin_high_i = bin_high - inv_mask * delta_bin_half + buffer.extend(self.children[i].get_filled_coords(bin_low_i.astype(bin_low.dtype), bin_high_i.astype(bin_low.dtype))) + return buffer + + def split(self): + if self.children is not None: + return + if self.max_depth == 0: + return + + self.children = [None] * self.num_max_split + delta = (self.high - self.low) / 2 + for i in range(self.num_max_split): + new_low = self._idx2coord(delta, self.low, i) + self.children[i] = NDtree.Node( + self._mode, + new_low, + new_low + delta, + self.max_depth - 1, + self.max_filling, + self._idx2mask(i) * self.is_bound, + ) + + for p in self.data: + self._insert_find_child_node(p) + self.data.clear() + + def insert(self, p): + if self._mode == NDtree.Mode.INDEX: + # first insert to sub nodes if available + if self.children is not None: + self._insert_find_child_node(p) + return + # no sub nodes + # insert locally if possible + if self.filling < self.max_filling: + self.data.append(p) + return + # max filling reached, split and insert + self.split() + # insert local if split unsuccessful + if self.children is None: + self.data.append(p) + else: + self._insert_find_child_node(p) + return + + if self._mode == NDtree.Mode.DISCRETIZE: + # split as far as max depth allows + self.split() + # insert here if max depth reached + if self.children is None: + self.data.append(p) + else: + self._insert_find_child_node(p) + + def get_coord_of(self, point, bin_low, bin_high): + assert self._mode == NDtree.Mode.DISCRETIZE + + if self.children is None: + return bin_low + + for i in range(self.num_max_split): + if self.children[i].is_within(point): + mask = self._idx2mask(i) + inv_mask = np.ones_like(mask) - mask + delta_bin_half = (bin_high - bin_low) / 2 + bin_low_i = bin_low + mask * delta_bin_half + bin_high_i = bin_high - inv_mask * delta_bin_half + return self.children[i].get_coord_of(point, bin_low_i, bin_high_i) + + raise RuntimeError("Failed to locate cell of point") + + def is_within(self, point): + return ( + np.alltrue(point >= self.low) and + np.alltrue( + np.logical_or( + np.logical_and(self.is_bound, np.isclose(point, self.high, 1e-10)), + point < self.high + ) + ) + ) + + def get_height(self): + if self.children is None: + return 0 + + heights = [node.get_height() for node in self.children] + return max(heights) + 1 + + def serialize(self): + if self.children is None: + return [2, len(self.data)] + + result = [1] + for node in self.children: + c_result = node.serialize() + result[0] += c_result[0] + result.extend(c_result) + return result + + def deserialize(self, serialized): + if self.children is not None or len(self.data) > 0: + raise RuntimeError("Deserialize called on non empty tree.") + + if serialized[0] == 2: + self.data_reserve_count = serialized[1] + return + + self.split() + offset = 1 + for i in range(self.num_max_split): + self.children[i].deserialize(serialized[offset:offset+serialized[offset]]) + offset += serialized[offset] + + def merge(self, other): + is_split = self.children is not None + is_split_other = other.children is not None + + if not is_split and not is_split_other: + self.data_reserve_count += other.data_reserve_count + return + + if not is_split and is_split_other: + self.split() + for i in range(self.num_max_split): + self.children[i].merge(other.children[i]) + + if is_split and not is_split_other: + assert other.data_reserve_count == 0 + + if is_split and is_split_other: + for i in range(self.num_max_split): + self.children[i].merge(other.children[i]) + + def _insert_find_child_node(self, p): + for i in range(self.num_max_split): + if not self.children[i].is_within(p): + continue + self.children[i].insert(p) + return + + def _idx2mask(self, idx): + return ((idx & np.array([1 << i for i in range(self.dim)], dtype=np.int32)) != 0).astype(np.int32) + + def _idx2coord(self, delta, low, idx): + mask = self._idx2mask(idx).astype(dtype=delta.dtype) + return (low + delta * mask).astype(mask.dtype) + + def __init__(self, mode, low, high, max_depth, max_filling): + self.root = NDtree.Node(mode, low, high, max_depth, max_filling, np.ones(low.shape[0], dtype=np.int32)) + + def get_filled_coords(self, height=None): + if height is None: + height = self.root.get_height() + dtype = np.int32 + if height > 32: + dtype = np.int64 + return self.root.get_filled_coords( + np.zeros(self.root.dim, dtype=dtype), + np.power(2 * np.ones(self.root.dim, dtype=dtype), height), + ) + + def get_coords_of(self, points, height=None): + if height is None: + height = self.root.get_height() + dtype = np.int32 + if height > 32: + dtype = np.int64 + coords = np.zeros((len(points), self.root.dim), dtype=dtype) + c_min = np.zeros(self.root.dim, dtype=dtype), + c_max = np.power(2 * np.ones(self.root.dim, dtype=dtype), height) + for i, point in enumerate(points): + coords[i, :] = self.root.get_coord_of(point, c_min, c_max) + return coords + + def find_min_depth_for_n_neighbors(self, n, points): + if points.shape[0] == 0: + return 0 + depths = np.ones(len(points)) * self.get_height() + for idx in range(points.shape[0]): + d = self.root.find_min_depth_for_n_neighbors(n, 0, points[idx, :]) + if d is not None: + depths[idx] = d + return np.min(depths) + + def propagate_up_reserve_counts(self): + return self.root.propagate_up_reserve_counts() + + def split(self): + return self.root.split() + + def insert(self, p): + return self.root.insert(p) + + def serialize(self): + return self.root.serialize() + + def deserialize(self, serialized): + return self.root.deserialize(serialized) + + def merge(self, other): + return self.root.merge(other.root) + + def get_height(self): + return self.root.get_height() + + def clear(self): + return self.root.clear() + +class HilbertDirect: + def __init__(self, dim, bits): + self.dim = dim + self.bits = bits + self.dtype = None + + if bits <= 32: + self.dtype = np.int32 + else: + self.dtype = np.int64 + + def index2coord(self, idx): + X = np.zeros(self.dim, dtype=self.dtype) + if self.bits == 0: + return X + + # flat index to array + pos = (self.bits * self.dim) - 1 + for b in range(self.bits): + for n in range(self.dim): + bit = (idx >> pos) & 1 + X[n] = X[n] | (bit << (self.bits - b - 1)) + pos = pos - 1 + + # gray decode + N = 2 << (self.bits-1) + tmp = X[self.dim-1] >> 1 + i = self.dim-1 + while i > 0: + X[i] = X[i] ^ X[i-1] + i = i - 1 + X[0] = X[0] ^ tmp + + # undo excess work + Q = 2 + while Q != N: + P = Q - 1 + i = self.dim - 1 + while i >= 0: + if X[i] & Q: + X[0] = X[0] ^ P + else: + tmp = (X[0] ^ X[i]) & P + X[0] = X[0] ^ tmp + X[i] = X[i] ^ tmp + i = i - 1 + Q = Q << 1 + + return X + + def coord2index(self, coord): + if self.bits == 0: + return 0 + X = deepcopy(coord) + M = 1 << (self.bits-1) + + # inverse undo + Q = M + while Q > 1: + P = Q-1 + for i in range(self.dim): + if X[i] & Q: + X[0] = X[0] ^ P + else: + tmp = (X[0] ^ X[i]) & P + X[0] = X[0] ^ tmp + X[i] = X[i] ^ tmp + Q = Q >> 1 + + # gray encode + for i in range(1, self.dim): + X[i] = X[i] ^ X[i-1] + tmp = 0 + Q = M + while Q > 1: + if X[self.dim-1] & Q: + tmp = tmp ^ Q-1 + Q = Q >> 1 + for i in range(self.dim): + X[i] = X[i] ^ tmp + + # conv arrays to flat index + result = 0 + pos = (self.bits * self.dim) - 1 + for b in range(self.bits): + for n in range(self.dim): + result = result | (((X[n] >> (self.bits - b - 1)) & 1) << pos) + pos = pos - 1 + + return result + +class Projector(ABC): + @abstractmethod + def __call__(self, data): + pass + + @abstractmethod + def initialize(self, data): + pass + +class STDProjector(Projector): + def __init__(self, target_dims: int, comm: MPI.Comm): + self.num_target_dims = target_dims + self.target_dims = np.zeros(target_dims, dtype=np.int32) + self.comm = comm + + def initialize(self, data): + assert data.ndim > 1 + std = np.zeros(data.shape[-1]) + if data.shape[0] > 0: std = np.std(data, axis=0) + stds = np.array(self.comm.allgather(std)) + stds = np.mean(stds, axis=0) + self.target_dims[:] = np.sort(np.argsort(stds)[::-1][0:self.num_target_dims]).astype(np.int32) + + def __call__(self, data): + d = data + if data.ndim == 1: + d = d[np.newaxis, :] + return d[:, self.target_dims] + +class IdentityProjector(Projector): + def __call__(self, data): + return data + + def initialize(self, data): + pass + +class InterleavedDomain: + """ + Handles n-dimensional data in an overlapping domain. + Will de- and re-compose the distributed data to allow for domain local operations. + """ + def __init__(self, config, comm: MPI.Comm): + self._config = config + self._comm = comm + self._size = comm.Get_size() + self._rank = comm.Get_rank() + + # decomp data + self._x_local = None # n_points x point_dim + self._x_query_local = None # m_points x point_dim + self._f_local = None # n_points x fun_dim + self._proj_x_local = None # n_points x proj_dim + self._proj_x_query_local = None # m_points x proj_dim + self._bound_low = None + self._bound_high = None + self._normalization = None + self._shift = None + self._max_depth = None + self._max_filling = 8 + self._coarsening_factor = 2 + self._n_neighbors = 50 + self._tree: Optional[NDtree] = None + self._query_tree: Optional[NDtree] = None + self._projector: Projector = IdentityProjector() + + self._query_rank_mapping = None + + def configure(self, domain_config): + self._max_filling = 8 if 'max_filling' not in domain_config else domain_config['max_filling'] + self._coarsening_factor = 2 if 'coarsening_factor' not in domain_config else domain_config['coarsening_factor'] + self._n_neighbors = 50 if 'n_neighbors' not in domain_config else domain_config['n_neighbors'] + if 'projection_type' not in domain_config: + self._projector = IdentityProjector() + return + + match domain_config['projection_type']: + case "std": + target_dims = 1 if 'projection_std_dims' not in domain_config else domain_config['projection_std_dims'] + self._projector = STDProjector(target_dims, self._comm) + case "identity": + self._projector = IdentityProjector() + + def set_local_data(self, x, x_, f): + self._x_local = x + self._x_query_local = x_ + self._f_local = f + + def decompose(self): + # if not parallel, no work to be done + if self._size == 1: + return self._x_local, self._x_query_local, self._f_local + + self._generate_trees() + return self._create_partitions() + + def get_depth_filling(self): + return self._max_depth, self._max_filling + + def reassemble(self, x_query, f_query): + # if not parallel, no work to be done + if self._size == 1: + return f_query + + # transfer data back to original rank + send_map = {r:[] for r in range(self._size)} + for i in range(x_query.shape[0]): + dst_rank = self._query_rank_mapping[tuple(x_query[i, :].tolist())] + data = [] + data.extend(x_query[i, :].tolist()) + data.extend(f_query[i, :].tolist()) + send_map[dst_rank].append(data) + local_data = self._communicate(x_query.shape[-1] + f_query.shape[-1], send_map) + local_data = np.array(local_data).reshape(-1, x_query.shape[-1] + f_query.shape[-1]) + + # sort data to match order of initial query input + idx_map = {} + for i in range(self._x_query_local.shape[0]): + idx_map[tuple(self._x_query_local[i, :].tolist())] = i + + result = np.zeros((self._x_query_local.shape[0], f_query.shape[-1])) + for d_idx in range(local_data.shape[0]): + idx = idx_map[tuple(local_data[d_idx, 0:x_query.shape[-1]].tolist())] + result[idx, :] = local_data[d_idx, x_query.shape[-1]:] + return result + + def _normalize_x(self): + x_loc_min = np.ones(self._x_local.shape[-1]) * np.inf + if self._x_local.shape[0] > 0: x_loc_min = np.min(self._x_local, axis=0) + xq_loc_min = np.ones(self._x_query_local.shape[-1]) * np.inf + if self._x_query_local.shape[0] > 0: xq_loc_min = np.min(self._x_query_local, axis=0) + local_min = np.minimum(x_loc_min, xq_loc_min) + glob_min = np.min(np.array(self._comm.allgather(local_min)), axis=0) + self._bound_low = glob_min + + x_loc_max = -np.ones(self._x_local.shape[-1]) * np.inf + if self._x_local.shape[0] > 0: x_loc_max = np.max(self._x_local, axis=0) + xq_loc_max = -np.ones(self._x_query_local.shape[-1]) * np.inf + if self._x_query_local.shape[0] > 0: xq_loc_max = np.max(self._x_query_local, axis=0) + local_max = np.maximum(x_loc_max, xq_loc_max) + glob_max = np.max(np.array(self._comm.allgather(local_max)), axis=0) + self._bound_high = glob_max + + delta = glob_max - glob_min + shift = glob_min + delta / 2.0 + + self._normalization = delta / 2.0 + self._shift = shift + self._x_local = (self._x_local - shift) / self._normalization[None, :] + self._x_query_local = (self._x_query_local - shift) / self._normalization[None, :] + def eval_cond(): + return not all([np.alltrue(self._x_local > -1.0), np.alltrue(self._x_local < 1.0), np.alltrue(self._x_query_local > -1.0), np.alltrue(self._x_query_local < 1.0)]) + glob_cond = self._comm.allgather(eval_cond()) + while any(glob_cond): + # undo prev norm + self._x_local = self._x_local * self._normalization[None, :] + self._x_query_local = self._x_query_local * self._normalization[None, :] + # retry with larger norm + self._normalization += 1e-10 + self._x_local = self._x_local / self._normalization[None, :] + self._x_query_local = self._x_query_local / self._normalization[None, :] + glob_cond = self._comm.allgather(eval_cond()) + + self._projector.initialize(self._x_local) + self._proj_x_local = self._projector(self._x_local) + self._proj_x_query_local = self._projector(self._x_query_local) + + def _generate_trees(self): + self._normalize_x() + + proj_dim = self._proj_x_local.shape[1] + low, high = -np.ones(proj_dim), np.ones(proj_dim) + # determine max required depth + # populate and query height + depth_tree = NDtree(NDtree.Mode.INDEX, low, high, 32, self._max_filling) + for n in range(self._proj_x_local.shape[0]): + depth_tree.insert(self._proj_x_local[n, :]) + for m in range(self._proj_x_query_local.shape[0]): + depth_tree.insert(self._proj_x_query_local[m, :]) + max_depth = np.maximum(self._comm.allreduce(depth_tree.get_height(), op=MPI.MAX) // self._coarsening_factor, self._coarsening_factor) + del depth_tree + self._max_depth = max_depth + + # populate discretization trees + tree = NDtree(NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling) + for n in range(self._proj_x_local.shape[0]): + tree.insert(self._proj_x_local[n, :]) + query_tree = NDtree(NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling) + for m in range(self._proj_x_query_local.shape[0]): + query_tree.insert(self._proj_x_query_local[m, :]) + + # merge into a global tree structure + def bcast_tree(t) -> NDtree: + serial = t.serialize() + serial_global = self._comm.allgather(serial) + res = NDtree(NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling) + for s in serial_global: + other_tree = NDtree(NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling) + other_tree.deserialize(s) + res.merge(other_tree) + return res + + self._tree = bcast_tree(tree) + self._query_tree = bcast_tree(query_tree) + + def _create_partitions(self): + self._tree.propagate_up_reserve_counts() + r_m_depth = self._tree.find_min_depth_for_n_neighbors(self._n_neighbors, self._proj_x_query_local) + r_m_depth = self._comm.allreduce(r_m_depth, op=MPI.MAX) + r_m_cells = np.power(2, r_m_depth) + grid_resolution = self._tree.get_height() + hMap = HilbertDirect(self._proj_x_local.shape[-1], grid_resolution) + + # index query points + query_coords = self._query_tree.get_filled_coords(grid_resolution) + query_mapping = {tuple(c.tolist()):hMap.coord2index(c) for c in query_coords} + query_mapping_inv = {} + for coord, idx in query_mapping.items(): + query_mapping_inv[idx] = coord + sorted_1d_query_indices = sorted(query_mapping.values()) + + # partition based on query points + target_point_per_rank = len(sorted_1d_query_indices) // self._size + partitions = {r:[-1, -1] for r in range(self._size)} + last_val = sorted_1d_query_indices[0] + start_idx = 0 + part_begin = 0 + part_idx = 0 + for i in range(1, len(sorted_1d_query_indices)): + if sorted_1d_query_indices[i] != last_val: + last_val = sorted_1d_query_indices[i] + start_idx = i + + if i - part_begin + 1 < target_point_per_rank: + continue + + # handle last partition + if part_idx == self._size - 1: + partitions[part_idx][0] = part_begin + partitions[part_idx][1] = len(sorted_1d_query_indices) - 1 + part_idx = part_idx + 1 + break + + # partition has minimum size, find nearest end of current cell + end_idx = i + for j in range(i+1, len(sorted_1d_query_indices)): + if sorted_1d_query_indices[j] != last_val: + end_idx = j - 1 + break + # start is closer + if i - start_idx < end_idx - i: + partitions[part_idx][0] = part_begin + partitions[part_idx][1] = start_idx - 1 + part_begin = start_idx + # end is closer + else: + partitions[part_idx][0] = part_begin + partitions[part_idx][1] = end_idx + part_begin = end_idx + 1 + + part_idx = part_idx + 1 + # last part was not used + if part_idx in partitions and partitions[part_idx][0] == 0 and partitions[part_idx][1] == 0: + partitions[part_idx][0] = part_begin + partitions[part_idx][1] = len(sorted_1d_query_indices) - 1 + + # assign surrounding src domain to rank local query points + src_domains = {r:[None, None] for r in range(self._size)} + for rank, p_range in partitions.items(): + if -1 == p_range[0] == p_range[1]: + continue + + # gather coords, find bounding box + local_indices = sorted_1d_query_indices[p_range[0]:p_range[1]+1] + local_coords = np.array([query_mapping_inv[idx] for idx in local_indices]) + bbox_low = np.min(local_coords, axis=0) + bbox_high = np.max(local_coords, axis=0) + # np.where for valid bbox ranges + low_mask = (bbox_low >= 2 * r_m_cells).astype(np.int32) + bbox_low = low_mask * (bbox_low - 2 * r_m_cells) + (1 - low_mask) * np.zeros_like(bbox_low) + + max_coord = np.power(2, grid_resolution) - 1 + high_mask = (bbox_high < (max_coord - 2 * r_m_cells)).astype(np.int32) + bbox_high = high_mask * (bbox_high + 2 * r_m_cells) + (1 - high_mask) * (np.ones_like(bbox_high) * max_coord) + src_domains[rank][0] = bbox_low + src_domains[rank][1] = bbox_high + + # figure out which query points need to be sent where + owned_query_coords = self._query_tree.get_coords_of(self._proj_x_query_local, grid_resolution) + owned_query_indices = [query_mapping[tuple(coord.tolist())] for coord in owned_query_coords] + send_map = {r:[] for r in range(self._size)} + for i in range(len(owned_query_indices)): + # find owning partition + found = False + for rank, rank_range in partitions.items(): + if -1 == rank_range[0] == rank_range[1]: + continue + + if ( + sorted_1d_query_indices[rank_range[0]] > owned_query_indices[i] or + sorted_1d_query_indices[rank_range[1]] < owned_query_indices[i] + ): + continue + + send_map[rank].append(self._x_query_local[i, :].tolist()) + found = True + break + + # part not found + if not found: + raise RuntimeError("Corresponding rank not found for query point") + + # transfer query points + x_query_part, inv_map = self._communicate(self._x_query_local.shape[-1], send_map, return_inverse=True) + x_query = np.array(x_query_part).reshape(-1, self._x_query_local.shape[-1]) + # invert query send map for later (to transfer back) + self._query_rank_mapping = {} + for rank, data in inv_map.items(): + data_ = np.array(data).reshape(-1, self._x_query_local.shape[-1]) + for p_idx in range(data_.shape[0]): + self._query_rank_mapping[tuple(data_[p_idx, :].tolist())] = rank + + # figure out which source points need to be sent where + send_map = {r:[] for r in range(self._size)} + source_coords = self._tree.get_filled_coords(grid_resolution) + source_mapping = {tuple(c.tolist()): hMap.coord2index(c) for c in source_coords} + owned_source_coords = self._tree.get_coords_of(self._proj_x_local, grid_resolution) + owned_source_indices = [source_mapping[tuple(coord.tolist())] for coord in owned_source_coords] + for i in range(len(owned_source_indices)): + for rank, rank_domain in src_domains.items(): + if rank_domain[0] is None or rank_domain[1] is None: + continue + + if ( + np.alltrue(rank_domain[0] <= source_coords[i]) and + np.alltrue(source_coords[i] <= rank_domain[1]) + ): + data = [] + data.extend(self._x_local[i, :].tolist()) + data.extend(self._f_local[i, :].tolist()) + send_map[rank].append(data) + + # transfer source points + xf_part = self._communicate(self._x_local.shape[-1] + self._f_local.shape[-1], send_map) + xf_part = np.array(xf_part).reshape(-1, self._x_local.shape[-1] + self._f_local.shape[-1]) + x = xf_part[:, 0:self._x_local.shape[-1]] + f = xf_part[:, self._x_local.shape[-1]:] + + return x, x_query, f + + def _communicate(self, entry_size, send_map, return_inverse=False): + send_counts = [len(send_map[r]) for r in range(self._size)] + send_counts[self._rank] = 0 # ignore local count + glob_send_counts = self._comm.allgather(send_counts) + + send_reqs = [] + for recv_rank, data in send_map.items(): + if recv_rank == self._rank: + continue + if len(data) == 0: + continue + for d_idx, entry in enumerate(data): + req = self._comm.isend(entry, dest=recv_rank, tag=create_tag(d_idx, self._rank, recv_rank)) + send_reqs.append(req) + + recv_reqs = [] + for send_rank in range(self._size): + if send_rank == self._rank: + continue + if glob_send_counts[send_rank][self._rank] == 0: + continue + for d_idx in range(glob_send_counts[send_rank][self._rank]): + req = self._comm.irecv(None, source=send_rank, tag=create_tag(d_idx, send_rank, self._rank)) + recv_reqs.append(tuple([send_rank, req])) + + MPI.Request.Waitall(send_reqs) + + result = [] + result.extend(send_map[self._rank]) + inv_map = {r:[] for r in range(self._size)} + inv_map[self._rank].extend(send_map[self._rank]) + + for source_rank, req in recv_reqs: + data = req.wait() + result.append(data) + if return_inverse: + inv_map[source_rank].append(data) + + if return_inverse: + return result, inv_map + else: + return result + +class RBF_PU: + """ + Interpolates f(x) for f: R^n -> R^m using partition of unity RBF interpolant. + + The approach here does not require a support radius as data is normalized. + """ + + def __init__(self, config, logger, comm: MPI.Comm, rank, size): + self._config = config + self._logger = logger + self._comm = comm + self._rank = rank + self._size = size + + self._domain = InterleavedDomain(config, comm) + self._use_pu = False + self._pu_overlap = 0.1 + self._pu_cluster_size = 50 + + # RBF data + self._phi = RBF_PU.basis_c6 + self._x = None + self._x_query = None + self._f = None + + def configure(self, interp_config): + self._domain.configure({} if 'domain_config' not in interp_config else interp_config['domain_config']) + self._use_pu = False if 'use_pu' not in interp_config else interp_config['use_pu'] + if self._use_pu: + self._pu_overlap = 0.1 if 'pu_overlap' not in interp_config else interp_config['pu_overlap'] + self._pu_cluster_size = 50 if 'pu_cluster_size' not in interp_config else interp_config['pu_cluster_size'] + if 'basis' not in interp_config: + return + match interp_config['basis']: + case 'c0': + self._phi = RBF_PU.basis_c0 + case 'c2': + self._phi = RBF_PU.basis_c2 + case 'c4': + self._phi = RBF_PU.basis_c4 + case 'c6': + self._phi = RBF_PU.basis_c6 + case 'gauss': + eps = 1.0 if 'gauss_eps' not in interp_config else interp_config['gauss_eps'] + self._phi = partial(RBF_PU.basis_gauss, eps=eps) + + def set_local_data(self, x, x_, f): + self._domain.set_local_data(x, x_, f) + + def interpolate(self): + self._x, self._x_query, self._f = self._domain.decompose() + + interp = self.compute_interpolant(self._x, self._f) + xq, fq = self.evaluate_interpolant(interp, self._x_query) + + fq_local = self._domain.reassemble(xq, fq) + + return fq_local + + # ================================ + # RBF + # ================================ + @property + def compute_interpolant(self): + if self._use_pu: + return self.compute_rbf_pu_interpolant + else: + return self.compute_rbf_interpolant + + @property + def evaluate_interpolant(self): + if self._use_pu: + return self.evaluate_rbf_pu_interpolant + else: + return self.evaluate_rbf_interpolant + + def _compute_cluster_centers(self, x): + assert self._use_pu + local_min, local_max = np.min(x, axis=0), np.max(x, axis=0) + d4 = (local_max - local_min) / 4 + + center = local_min + 2.0 * d4 + centers = np.zeros((2 * x.shape[-1] + 1, x.shape[-1])) + centers[-1, :] = center + for d in range(x.shape[-1]): + mask = np.zeros_like(d4) + mask[d] = 1 + + centers[2*d + 0, :] = center - mask * d4 + centers[2*d + 1, :] = center + mask * d4 + + return centers, local_min, local_max + + def compute_rbf_pu_interpolant(self, x, f): + # compute r_m + c_centers, local_min, local_max = self._compute_cluster_centers(x) + index_tree = NDtree(NDtree.Mode.INDEX, local_min, local_max, *self._domain.get_depth_filling()) + # TODO later + # determine clusters + # ignore empty clusters + # compute local RBF interpolant for remaining clusters + pass + + def compute_rbf_interpolant(self, x, f): + n_points = x.shape[0] + src_size = x.shape[-1] + dst_size = f.shape[-1] + + r = np.linalg.norm(x[None, :, :] - x[:, None, :], ord=2, axis=-1) + # compute lin and const term + b = np.zeros((src_size + 1, dst_size)) + p = np.zeros((n_points, src_size + 1)) + p[:, 0] = 1 + p[:, 1:] = x + for k in range(dst_size): + b[:, k] = np.linalg.lstsq(p, f[:, k], rcond=None)[0] + + a = self._phi(r) + # compute basis weights + w = np.zeros((dst_size, n_points)) + for k in range(dst_size): + w[k, :] = np.linalg.solve(a, f[:, k] - np.matmul(p, b[:, k])) + + return w, b, x, f + + def evaluate_rbf_pu_interpolant(self, interp, xq): + # eval xq for all cluster interpolants + # compute weights + # sum contributions + pass + + def evaluate_rbf_interpolant(self, interp, xq): + w, b, x, f = interp + + r = np.linalg.norm(xq[None, :, :] - x[:, None, :], ord=2, axis=-1) + contrib_basis = np.matmul(w[:, :], self._phi(r)) # f_k x eval_p + contrib_const = b[0, :] # f_k + # b: p_size+1 x f_k + # xq: eval_p x p_size + contrib_lin = np.matmul(xq[:, :], b[1:, :]).T # f_k x eval_p + + fq = (contrib_basis + contrib_const[:, None] + contrib_lin).T + return xq, fq + + # ================================ + # BASIS FUNCTIONS + # ================================ + @staticmethod + def basis_c0(r): + return np.maximum(0.0, np.power(1.0 - r, 2)) + + @staticmethod + def basis_c2(r): + return np.maximum(0.0, np.power(1.0 - r, 4)) * (4.0 * r + 1) + + @staticmethod + def basis_c4(r): + return np.maximum(0.0, np.power(1.0 - r, 6)) * (35.0 * np.power(r, 2) + 18.0 * r + 3.0) / 3.0 + + @staticmethod + def basis_c6(r): + return np.maximum(0.0, np.power(1.0 - r, 8)) * ( + 32.0 * np.power(r, 3) + 25.0 * np.power(r, 2) + 8.0 * r + 1.0) + + @staticmethod + def basis_gauss(r, eps): + return np.exp(-np.power(eps * r, 2.0)) diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index 64239a35..85dc4c6b 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -212,8 +212,11 @@ def solve(self) -> None: # Write a checkpoint if a simulation is just activated. # This checkpoint will be asynchronous to the checkpoints written at the start of the time window. if self._is_model_adaptivity_on: + active_sim_lids = ( + self._adaptivity_controller.get_active_sim_local_ids() + ) self._model_adaptivity_controller.update_states( - self._micro_sims, active_sim_gids + self._micro_sims, active_sim_lids ) for i in range(self._local_number_of_sims): if sim_states_cp[i] is None and self._micro_sims[i]: @@ -244,13 +247,13 @@ def solve(self) -> None: # Write a checkpoint if self._participant.requires_writing_checkpoint() or performed_lb: if self._is_model_adaptivity_on: - active_sim_gids = None + active_sim_lids = None if self._is_adaptivity_on: - active_sim_gids = ( + active_sim_lids = ( self._adaptivity_controller.get_active_sim_local_ids() ) self._model_adaptivity_controller.update_states( - self._micro_sims, active_sim_gids + self._micro_sims, active_sim_lids ) for i in range(self._local_number_of_sims): sim_states_cp[i] = ( @@ -308,13 +311,13 @@ def solve(self) -> None: self.state_setter(self._micro_sims[i], sim_states_cp[i]) if self._is_model_adaptivity_on: - active_sim_gids = None + active_sim_lids = None if self._is_adaptivity_on: - active_sim_gids = ( + active_sim_lids = ( self._adaptivity_controller.get_active_sim_local_ids() ) self._model_adaptivity_controller.write_back_states( - self._micro_sims, active_sim_gids + self._micro_sims, active_sim_lids ) first_iteration = False @@ -821,7 +824,7 @@ def initialize(self) -> None: initial_micro_data_list = ( self._adaptivity_controller.get_full_field_micro_output( - initial_micro_data_list + initial_data, initial_micro_data_list ) ) @@ -1092,7 +1095,7 @@ def _solve_micro_simulations_with_adaptivity( ) micro_sims_output = self._adaptivity_controller.get_full_field_micro_output( - micro_sims_output + micro_sims_input, micro_sims_output ) inactive_sim_lids = self._adaptivity_controller.get_inactive_sim_local_ids() @@ -1149,6 +1152,9 @@ def _solve_micro_simulations_with_model_adaptivity( ) self._model_adaptivity_controller.finalise_solve() + + for lid, sim in enumerate(self._micro_sims): + output[lid]["model_resolution"] = self._model_adaptivity_controller.get_sim_class_resolution(sim) return output def _get_solve_variant(self) -> Callable[[list, float], list]: From 883692aae9d6b32fcfbe03811dde06b5c022ed18 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Tue, 7 Apr 2026 22:45:50 +0200 Subject: [PATCH 02/14] fix format --- micro_manager/adaptivity/adaptivity.py | 87 +++-- micro_manager/adaptivity/global_adaptivity.py | 8 +- micro_manager/adaptivity/local_adaptivity.py | 12 +- micro_manager/config.py | 4 +- micro_manager/interpolation.py | 317 ++++++++++++------ micro_manager/micro_manager.py | 4 +- 6 files changed, 297 insertions(+), 135 deletions(-) diff --git a/micro_manager/adaptivity/adaptivity.py b/micro_manager/adaptivity/adaptivity.py index 98927414..ead339be 100644 --- a/micro_manager/adaptivity/adaptivity.py +++ b/micro_manager/adaptivity/adaptivity.py @@ -63,36 +63,45 @@ def __init__( self._mapping_configs = [] mappings = configurator.get_adaptivity_mapping_configs() for mapping in mappings: - src_fields = mapping['src_fields'] - dst_fields = mapping['dst_fields'] - n_neighbors = mapping['n_neighbors'] + src_fields = mapping["src_fields"] + dst_fields = mapping["dst_fields"] + n_neighbors = mapping["n_neighbors"] self._mappings.append((src_fields, dst_fields)) config = {} - if 'use_pu' in mapping['rbf_config']: - config['use_pu'] = mapping['rbf_config']['use_pu'] - if 'pu_overlap' in mapping['rbf_config']: - config['pu_overlap'] = mapping['rbf_config']['pu_overlap'] - config['pu_cluster_size'] = n_neighbors - if 'basis' in mapping['rbf_config']: - if 'type' in mapping['rbf_config']['basis']: - config['basis'] = mapping['rbf_config']['basis']['type'] - if config['basis'] == 'gauss' and 'eps' in mapping['rbf_config']['basis']: - config['gauss_eps'] = mapping['rbf_config']['basis']['eps'] + if "use_pu" in mapping["rbf_config"]: + config["use_pu"] = mapping["rbf_config"]["use_pu"] + if "pu_overlap" in mapping["rbf_config"]: + config["pu_overlap"] = mapping["rbf_config"]["pu_overlap"] + config["pu_cluster_size"] = n_neighbors + if "basis" in mapping["rbf_config"]: + if "type" in mapping["rbf_config"]["basis"]: + config["basis"] = mapping["rbf_config"]["basis"]["type"] + if ( + config["basis"] == "gauss" + and "eps" in mapping["rbf_config"]["basis"] + ): + config["gauss_eps"] = mapping["rbf_config"]["basis"]["eps"] dom_config = {} - dom_config['n_neighbors'] = n_neighbors - if 'max_filling' in mapping['domain_config']: - dom_config['max_filling'] = mapping['domain_config']['max_filling'] - if 'coarsening_factor' in mapping['domain_config']: - dom_config['coarsening_factor'] = mapping['domain_config']['coarsening_factor'] - if 'projection' in mapping['domain_config']: - if 'type' in mapping['domain_config']['projection']: - dom_config['projection_type'] = mapping['domain_config']['projection']['type'] - if 'target_dims' in mapping['domain_config']['projection']: - dom_config['projection_std_dims'] = mapping['domain_config']['projection']['target_dims'] - - config['domain_config'] = dom_config + dom_config["n_neighbors"] = n_neighbors + if "max_filling" in mapping["domain_config"]: + dom_config["max_filling"] = mapping["domain_config"]["max_filling"] + if "coarsening_factor" in mapping["domain_config"]: + dom_config["coarsening_factor"] = mapping["domain_config"][ + "coarsening_factor" + ] + if "projection" in mapping["domain_config"]: + if "type" in mapping["domain_config"]["projection"]: + dom_config["projection_type"] = mapping["domain_config"][ + "projection" + ]["type"] + if "target_dims" in mapping["domain_config"]["projection"]: + dom_config["projection_std_dims"] = mapping["domain_config"][ + "projection" + ]["target_dims"] + + config["domain_config"] = dom_config self._mapping_configs.append(config) # is_sim_active: 1D array having state (active or inactive) of each micro simulation @@ -256,9 +265,13 @@ def _interpolate_output(self, micro_input, micro_sims_output): inactive_lids = self.get_inactive_sim_local_ids() arg_sizes = {} for name, value in micro_input[active_lids[0]].items(): - arg_sizes[name] = 1 if type(value) != np.ndarray and type(value) != list else len(value) + arg_sizes[name] = ( + 1 if type(value) != np.ndarray and type(value) != list else len(value) + ) for name, value in micro_sims_output[active_lids[0]].items(): - arg_sizes[name] = 1 if type(value) != np.ndarray and type(value) != list else len(value) + arg_sizes[name] = ( + 1 if type(value) != np.ndarray and type(value) != list else len(value) + ) # compute interpolation n_points = len(active_lids) @@ -272,28 +285,38 @@ def _interpolate_output(self, micro_input, micro_sims_output): for idx, lid in enumerate(active_lids): offset = 0 for src_arg in src_args: - input_data[idx, offset:offset + arg_sizes[src_arg]] = micro_input[lid][src_arg] + input_data[idx, offset : offset + arg_sizes[src_arg]] = micro_input[ + lid + ][src_arg] offset += arg_sizes[src_arg] offset = 0 for dst_arg in dst_args: - output_data[idx, offset:offset + arg_sizes[dst_arg]] = micro_sims_output[lid][dst_arg] + output_data[ + idx, offset : offset + arg_sizes[dst_arg] + ] = micro_sims_output[lid][dst_arg] offset += arg_sizes[dst_arg] input_data_inactive = np.zeros((n_points_inactive, src_size)) for idx, lid in enumerate(inactive_lids): offset = 0 for src_arg in src_args: - input_data_inactive[idx, offset:offset + arg_sizes[src_arg]] = micro_input[lid][src_arg] + input_data_inactive[ + idx, offset : offset + arg_sizes[src_arg] + ] = micro_input[lid][src_arg] offset += arg_sizes[src_arg] # use interpolant self._interpolation.configure(self._mappings[m_idx]) - self._interpolation.set_local_data(input_data, input_data_inactive, output_data) + self._interpolation.set_local_data( + input_data, input_data_inactive, output_data + ) output_data_inactive = self._interpolation.interpolate() for idx, lid in enumerate(inactive_lids): offset = 0 for dst_arg in dst_args: - micro_sims_output[lid][dst_arg] = output_data_inactive[idx, offset:offset + arg_sizes[dst_arg]] + micro_sims_output[lid][dst_arg] = output_data_inactive[ + idx, offset : offset + arg_sizes[dst_arg] + ] offset += arg_sizes[dst_arg] def _get_similarity_measure( diff --git a/micro_manager/adaptivity/global_adaptivity.py b/micro_manager/adaptivity/global_adaptivity.py index bc85a53d..6fc48409 100644 --- a/micro_manager/adaptivity/global_adaptivity.py +++ b/micro_manager/adaptivity/global_adaptivity.py @@ -69,7 +69,9 @@ def __init__( self._global_ids = global_ids self._comm = comm - self._interpolation = RBF_PU(configurator, base_logger, comm, self._rank, self._comm.Get_size()) + self._interpolation = RBF_PU( + configurator, base_logger, comm, self._rank, self._comm.Get_size() + ) rank_of_sim = get_ranks_of_sims(global_ids, rank, comm, global_number_of_sims) self._is_sim_on_this_rank = [False] * global_number_of_sims # DECLARATION @@ -242,7 +244,9 @@ def get_inactive_sim_global_ids(self) -> np.ndarray: return np.array(inactive_sim_ids) - def get_full_field_micro_output(self, micro_input: list, micro_output: list) -> list: + def get_full_field_micro_output( + self, micro_input: list, micro_output: list + ) -> list: """ Get the full field micro output from active simulations to inactive simulations. diff --git a/micro_manager/adaptivity/local_adaptivity.py b/micro_manager/adaptivity/local_adaptivity.py index d1cf8425..b7d2f011 100644 --- a/micro_manager/adaptivity/local_adaptivity.py +++ b/micro_manager/adaptivity/local_adaptivity.py @@ -50,7 +50,13 @@ def __init__( configurator, num_sims, micro_problem_cls, model_manager, base_logger, rank ) self._comm = comm - self._interpolation = RBF_PU(configurator, base_logger, MPI.COMM_SELF, MPI.COMM_SELF.Get_rank(), MPI.COMM_SELF.Get_size()) + self._interpolation = RBF_PU( + configurator, + base_logger, + MPI.COMM_SELF, + MPI.COMM_SELF.Get_rank(), + MPI.COMM_SELF.Get_size(), + ) # similarity_dists: 2D array having similarity distances between each micro simulation pair # This matrix is modified in place via the function update_similarity_dists @@ -148,7 +154,9 @@ def get_inactive_sim_global_ids(self) -> np.ndarray: inactive_sim_ids = self.get_inactive_sim_local_ids() return inactive_sim_ids - def get_full_field_micro_output(self, micro_input: list, micro_output: list) -> list: + def get_full_field_micro_output( + self, micro_input: list, micro_output: list + ) -> list: """ Get the full field micro output from active simulations to inactive simulations. diff --git a/micro_manager/config.py b/micro_manager/config.py index 577633a1..3ed6491e 100644 --- a/micro_manager/config.py +++ b/micro_manager/config.py @@ -330,7 +330,9 @@ def read_json_micro_manager(self): self._logger.log_info_rank_zero("Adaptivity type: " + self._adaptivity_type) if self._adaptivity_type == "global": - self._adaptivity_mappings = self._data["simulation_params"]["adaptivity_settings"]["mappings"] + self._adaptivity_mappings = self._data["simulation_params"][ + "adaptivity_settings" + ]["mappings"] if self._data["simulation_params"]["adaptivity_settings"].get( "lazy_initialization" diff --git a/micro_manager/interpolation.py b/micro_manager/interpolation.py index 3b1269c3..b859f007 100644 --- a/micro_manager/interpolation.py +++ b/micro_manager/interpolation.py @@ -96,6 +96,7 @@ def interpolate(self, neighbors: np.ndarray, point: np.ndarray, values): return interpol_val / summed_weights + class NDtree: class Mode(Enum): DISCRETIZE = 0 @@ -127,7 +128,7 @@ def __init__( is_bound : np.ndarray Boolean indicating whether the node is on the boundary. """ - self._mode : NDtree.Mode = mode + self._mode: NDtree.Mode = mode self.low = low self.high = high self.max_depth = max_depth @@ -168,16 +169,24 @@ def propagate_up_reserve_counts(self): return self.data_reserve_count def find_min_depth_for_n_neighbors(self, n: int, depth: int, p): - if self.data_reserve_count < n: return None - if self.children is None: return None - if not self.is_within(p): return None - - tmp = [node.find_min_depth_for_n_neighbors(n, depth+1, p) for node in self.children] + if self.data_reserve_count < n: + return None + if self.children is None: + return None + if not self.is_within(p): + return None + + tmp = [ + node.find_min_depth_for_n_neighbors(n, depth + 1, p) + for node in self.children + ] depths = [] for d in tmp: - if d is None: continue + if d is None: + continue depths.append(d) - if len(depths) == 0: return depth + if len(depths) == 0: + return depth min_depth = min(depths) return min_depth @@ -198,7 +207,12 @@ def get_filled_coords(self, bin_low, bin_high): delta_bin_half = ((bin_high - bin_low) / 2).astype(bin_low.dtype) bin_low_i = bin_low + mask * delta_bin_half bin_high_i = bin_high - inv_mask * delta_bin_half - buffer.extend(self.children[i].get_filled_coords(bin_low_i.astype(bin_low.dtype), bin_high_i.astype(bin_low.dtype))) + buffer.extend( + self.children[i].get_filled_coords( + bin_low_i.astype(bin_low.dtype), + bin_high_i.astype(bin_low.dtype), + ) + ) return buffer def split(self): @@ -271,13 +285,10 @@ def get_coord_of(self, point, bin_low, bin_high): raise RuntimeError("Failed to locate cell of point") def is_within(self, point): - return ( - np.alltrue(point >= self.low) and - np.alltrue( - np.logical_or( - np.logical_and(self.is_bound, np.isclose(point, self.high, 1e-10)), - point < self.high - ) + return np.alltrue(point >= self.low) and np.alltrue( + np.logical_or( + np.logical_and(self.is_bound, np.isclose(point, self.high, 1e-10)), + point < self.high, ) ) @@ -310,7 +321,9 @@ def deserialize(self, serialized): self.split() offset = 1 for i in range(self.num_max_split): - self.children[i].deserialize(serialized[offset:offset+serialized[offset]]) + self.children[i].deserialize( + serialized[offset : offset + serialized[offset]] + ) offset += serialized[offset] def merge(self, other): @@ -341,14 +354,23 @@ def _insert_find_child_node(self, p): return def _idx2mask(self, idx): - return ((idx & np.array([1 << i for i in range(self.dim)], dtype=np.int32)) != 0).astype(np.int32) + return ( + (idx & np.array([1 << i for i in range(self.dim)], dtype=np.int32)) != 0 + ).astype(np.int32) def _idx2coord(self, delta, low, idx): mask = self._idx2mask(idx).astype(dtype=delta.dtype) return (low + delta * mask).astype(mask.dtype) def __init__(self, mode, low, high, max_depth, max_filling): - self.root = NDtree.Node(mode, low, high, max_depth, max_filling, np.ones(low.shape[0], dtype=np.int32)) + self.root = NDtree.Node( + mode, + low, + high, + max_depth, + max_filling, + np.ones(low.shape[0], dtype=np.int32), + ) def get_filled_coords(self, height=None): if height is None: @@ -368,7 +390,7 @@ def get_coords_of(self, points, height=None): if height > 32: dtype = np.int64 coords = np.zeros((len(points), self.root.dim), dtype=dtype) - c_min = np.zeros(self.root.dim, dtype=dtype), + c_min = (np.zeros(self.root.dim, dtype=dtype),) c_max = np.power(2 * np.ones(self.root.dim, dtype=dtype), height) for i, point in enumerate(points): coords[i, :] = self.root.get_coord_of(point, c_min, c_max) @@ -408,6 +430,7 @@ def get_height(self): def clear(self): return self.root.clear() + class HilbertDirect: def __init__(self, dim, bits): self.dim = dim @@ -433,11 +456,11 @@ def index2coord(self, idx): pos = pos - 1 # gray decode - N = 2 << (self.bits-1) - tmp = X[self.dim-1] >> 1 - i = self.dim-1 + N = 2 << (self.bits - 1) + tmp = X[self.dim - 1] >> 1 + i = self.dim - 1 while i > 0: - X[i] = X[i] ^ X[i-1] + X[i] = X[i] ^ X[i - 1] i = i - 1 X[0] = X[0] ^ tmp @@ -462,12 +485,12 @@ def coord2index(self, coord): if self.bits == 0: return 0 X = deepcopy(coord) - M = 1 << (self.bits-1) + M = 1 << (self.bits - 1) # inverse undo Q = M while Q > 1: - P = Q-1 + P = Q - 1 for i in range(self.dim): if X[i] & Q: X[0] = X[0] ^ P @@ -479,12 +502,12 @@ def coord2index(self, coord): # gray encode for i in range(1, self.dim): - X[i] = X[i] ^ X[i-1] + X[i] = X[i] ^ X[i - 1] tmp = 0 Q = M while Q > 1: - if X[self.dim-1] & Q: - tmp = tmp ^ Q-1 + if X[self.dim - 1] & Q: + tmp = tmp ^ Q - 1 Q = Q >> 1 for i in range(self.dim): X[i] = X[i] ^ tmp @@ -499,6 +522,7 @@ def coord2index(self, coord): return result + class Projector(ABC): @abstractmethod def __call__(self, data): @@ -508,6 +532,7 @@ def __call__(self, data): def initialize(self, data): pass + class STDProjector(Projector): def __init__(self, target_dims: int, comm: MPI.Comm): self.num_target_dims = target_dims @@ -517,10 +542,13 @@ def __init__(self, target_dims: int, comm: MPI.Comm): def initialize(self, data): assert data.ndim > 1 std = np.zeros(data.shape[-1]) - if data.shape[0] > 0: std = np.std(data, axis=0) + if data.shape[0] > 0: + std = np.std(data, axis=0) stds = np.array(self.comm.allgather(std)) stds = np.mean(stds, axis=0) - self.target_dims[:] = np.sort(np.argsort(stds)[::-1][0:self.num_target_dims]).astype(np.int32) + self.target_dims[:] = np.sort( + np.argsort(stds)[::-1][0 : self.num_target_dims] + ).astype(np.int32) def __call__(self, data): d = data @@ -528,6 +556,7 @@ def __call__(self, data): d = d[np.newaxis, :] return d[:, self.target_dims] + class IdentityProjector(Projector): def __call__(self, data): return data @@ -535,11 +564,13 @@ def __call__(self, data): def initialize(self, data): pass + class InterleavedDomain: """ Handles n-dimensional data in an overlapping domain. Will de- and re-compose the distributed data to allow for domain local operations. """ + def __init__(self, config, comm: MPI.Comm): self._config = config self._comm = comm @@ -567,16 +598,28 @@ def __init__(self, config, comm: MPI.Comm): self._query_rank_mapping = None def configure(self, domain_config): - self._max_filling = 8 if 'max_filling' not in domain_config else domain_config['max_filling'] - self._coarsening_factor = 2 if 'coarsening_factor' not in domain_config else domain_config['coarsening_factor'] - self._n_neighbors = 50 if 'n_neighbors' not in domain_config else domain_config['n_neighbors'] - if 'projection_type' not in domain_config: + self._max_filling = ( + 8 if "max_filling" not in domain_config else domain_config["max_filling"] + ) + self._coarsening_factor = ( + 2 + if "coarsening_factor" not in domain_config + else domain_config["coarsening_factor"] + ) + self._n_neighbors = ( + 50 if "n_neighbors" not in domain_config else domain_config["n_neighbors"] + ) + if "projection_type" not in domain_config: self._projector = IdentityProjector() return - match domain_config['projection_type']: + match domain_config["projection_type"]: case "std": - target_dims = 1 if 'projection_std_dims' not in domain_config else domain_config['projection_std_dims'] + target_dims = ( + 1 + if "projection_std_dims" not in domain_config + else domain_config["projection_std_dims"] + ) self._projector = STDProjector(target_dims, self._comm) case "identity": self._projector = IdentityProjector() @@ -603,7 +646,7 @@ def reassemble(self, x_query, f_query): return f_query # transfer data back to original rank - send_map = {r:[] for r in range(self._size)} + send_map = {r: [] for r in range(self._size)} for i in range(x_query.shape[0]): dst_rank = self._query_rank_mapping[tuple(x_query[i, :].tolist())] data = [] @@ -611,7 +654,9 @@ def reassemble(self, x_query, f_query): data.extend(f_query[i, :].tolist()) send_map[dst_rank].append(data) local_data = self._communicate(x_query.shape[-1] + f_query.shape[-1], send_map) - local_data = np.array(local_data).reshape(-1, x_query.shape[-1] + f_query.shape[-1]) + local_data = np.array(local_data).reshape( + -1, x_query.shape[-1] + f_query.shape[-1] + ) # sort data to match order of initial query input idx_map = {} @@ -620,23 +665,27 @@ def reassemble(self, x_query, f_query): result = np.zeros((self._x_query_local.shape[0], f_query.shape[-1])) for d_idx in range(local_data.shape[0]): - idx = idx_map[tuple(local_data[d_idx, 0:x_query.shape[-1]].tolist())] - result[idx, :] = local_data[d_idx, x_query.shape[-1]:] + idx = idx_map[tuple(local_data[d_idx, 0 : x_query.shape[-1]].tolist())] + result[idx, :] = local_data[d_idx, x_query.shape[-1] :] return result def _normalize_x(self): x_loc_min = np.ones(self._x_local.shape[-1]) * np.inf - if self._x_local.shape[0] > 0: x_loc_min = np.min(self._x_local, axis=0) + if self._x_local.shape[0] > 0: + x_loc_min = np.min(self._x_local, axis=0) xq_loc_min = np.ones(self._x_query_local.shape[-1]) * np.inf - if self._x_query_local.shape[0] > 0: xq_loc_min = np.min(self._x_query_local, axis=0) + if self._x_query_local.shape[0] > 0: + xq_loc_min = np.min(self._x_query_local, axis=0) local_min = np.minimum(x_loc_min, xq_loc_min) glob_min = np.min(np.array(self._comm.allgather(local_min)), axis=0) self._bound_low = glob_min x_loc_max = -np.ones(self._x_local.shape[-1]) * np.inf - if self._x_local.shape[0] > 0: x_loc_max = np.max(self._x_local, axis=0) + if self._x_local.shape[0] > 0: + x_loc_max = np.max(self._x_local, axis=0) xq_loc_max = -np.ones(self._x_query_local.shape[-1]) * np.inf - if self._x_query_local.shape[0] > 0: xq_loc_max = np.max(self._x_query_local, axis=0) + if self._x_query_local.shape[0] > 0: + xq_loc_max = np.max(self._x_query_local, axis=0) local_max = np.maximum(x_loc_max, xq_loc_max) glob_max = np.max(np.array(self._comm.allgather(local_max)), axis=0) self._bound_high = glob_max @@ -647,9 +696,20 @@ def _normalize_x(self): self._normalization = delta / 2.0 self._shift = shift self._x_local = (self._x_local - shift) / self._normalization[None, :] - self._x_query_local = (self._x_query_local - shift) / self._normalization[None, :] + self._x_query_local = (self._x_query_local - shift) / self._normalization[ + None, : + ] + def eval_cond(): - return not all([np.alltrue(self._x_local > -1.0), np.alltrue(self._x_local < 1.0), np.alltrue(self._x_query_local > -1.0), np.alltrue(self._x_query_local < 1.0)]) + return not all( + [ + np.alltrue(self._x_local > -1.0), + np.alltrue(self._x_local < 1.0), + np.alltrue(self._x_query_local > -1.0), + np.alltrue(self._x_query_local < 1.0), + ] + ) + glob_cond = self._comm.allgather(eval_cond()) while any(glob_cond): # undo prev norm @@ -677,7 +737,11 @@ def _generate_trees(self): depth_tree.insert(self._proj_x_local[n, :]) for m in range(self._proj_x_query_local.shape[0]): depth_tree.insert(self._proj_x_query_local[m, :]) - max_depth = np.maximum(self._comm.allreduce(depth_tree.get_height(), op=MPI.MAX) // self._coarsening_factor, self._coarsening_factor) + max_depth = np.maximum( + self._comm.allreduce(depth_tree.get_height(), op=MPI.MAX) + // self._coarsening_factor, + self._coarsening_factor, + ) del depth_tree self._max_depth = max_depth @@ -685,7 +749,9 @@ def _generate_trees(self): tree = NDtree(NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling) for n in range(self._proj_x_local.shape[0]): tree.insert(self._proj_x_local[n, :]) - query_tree = NDtree(NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling) + query_tree = NDtree( + NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling + ) for m in range(self._proj_x_query_local.shape[0]): query_tree.insert(self._proj_x_query_local[m, :]) @@ -693,9 +759,13 @@ def _generate_trees(self): def bcast_tree(t) -> NDtree: serial = t.serialize() serial_global = self._comm.allgather(serial) - res = NDtree(NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling) + res = NDtree( + NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling + ) for s in serial_global: - other_tree = NDtree(NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling) + other_tree = NDtree( + NDtree.Mode.DISCRETIZE, low, high, max_depth, self._max_filling + ) other_tree.deserialize(s) res.merge(other_tree) return res @@ -705,7 +775,9 @@ def bcast_tree(t) -> NDtree: def _create_partitions(self): self._tree.propagate_up_reserve_counts() - r_m_depth = self._tree.find_min_depth_for_n_neighbors(self._n_neighbors, self._proj_x_query_local) + r_m_depth = self._tree.find_min_depth_for_n_neighbors( + self._n_neighbors, self._proj_x_query_local + ) r_m_depth = self._comm.allreduce(r_m_depth, op=MPI.MAX) r_m_cells = np.power(2, r_m_depth) grid_resolution = self._tree.get_height() @@ -713,7 +785,7 @@ def _create_partitions(self): # index query points query_coords = self._query_tree.get_filled_coords(grid_resolution) - query_mapping = {tuple(c.tolist()):hMap.coord2index(c) for c in query_coords} + query_mapping = {tuple(c.tolist()): hMap.coord2index(c) for c in query_coords} query_mapping_inv = {} for coord, idx in query_mapping.items(): query_mapping_inv[idx] = coord @@ -721,7 +793,7 @@ def _create_partitions(self): # partition based on query points target_point_per_rank = len(sorted_1d_query_indices) // self._size - partitions = {r:[-1, -1] for r in range(self._size)} + partitions = {r: [-1, -1] for r in range(self._size)} last_val = sorted_1d_query_indices[0] start_idx = 0 part_begin = 0 @@ -743,7 +815,7 @@ def _create_partitions(self): # partition has minimum size, find nearest end of current cell end_idx = i - for j in range(i+1, len(sorted_1d_query_indices)): + for j in range(i + 1, len(sorted_1d_query_indices)): if sorted_1d_query_indices[j] != last_val: end_idx = j - 1 break @@ -760,35 +832,47 @@ def _create_partitions(self): part_idx = part_idx + 1 # last part was not used - if part_idx in partitions and partitions[part_idx][0] == 0 and partitions[part_idx][1] == 0: + if ( + part_idx in partitions + and partitions[part_idx][0] == 0 + and partitions[part_idx][1] == 0 + ): partitions[part_idx][0] = part_begin partitions[part_idx][1] = len(sorted_1d_query_indices) - 1 # assign surrounding src domain to rank local query points - src_domains = {r:[None, None] for r in range(self._size)} + src_domains = {r: [None, None] for r in range(self._size)} for rank, p_range in partitions.items(): if -1 == p_range[0] == p_range[1]: continue # gather coords, find bounding box - local_indices = sorted_1d_query_indices[p_range[0]:p_range[1]+1] + local_indices = sorted_1d_query_indices[p_range[0] : p_range[1] + 1] local_coords = np.array([query_mapping_inv[idx] for idx in local_indices]) bbox_low = np.min(local_coords, axis=0) bbox_high = np.max(local_coords, axis=0) # np.where for valid bbox ranges low_mask = (bbox_low >= 2 * r_m_cells).astype(np.int32) - bbox_low = low_mask * (bbox_low - 2 * r_m_cells) + (1 - low_mask) * np.zeros_like(bbox_low) + bbox_low = low_mask * (bbox_low - 2 * r_m_cells) + ( + 1 - low_mask + ) * np.zeros_like(bbox_low) max_coord = np.power(2, grid_resolution) - 1 high_mask = (bbox_high < (max_coord - 2 * r_m_cells)).astype(np.int32) - bbox_high = high_mask * (bbox_high + 2 * r_m_cells) + (1 - high_mask) * (np.ones_like(bbox_high) * max_coord) + bbox_high = high_mask * (bbox_high + 2 * r_m_cells) + (1 - high_mask) * ( + np.ones_like(bbox_high) * max_coord + ) src_domains[rank][0] = bbox_low src_domains[rank][1] = bbox_high # figure out which query points need to be sent where - owned_query_coords = self._query_tree.get_coords_of(self._proj_x_query_local, grid_resolution) - owned_query_indices = [query_mapping[tuple(coord.tolist())] for coord in owned_query_coords] - send_map = {r:[] for r in range(self._size)} + owned_query_coords = self._query_tree.get_coords_of( + self._proj_x_query_local, grid_resolution + ) + owned_query_indices = [ + query_mapping[tuple(coord.tolist())] for coord in owned_query_coords + ] + send_map = {r: [] for r in range(self._size)} for i in range(len(owned_query_indices)): # find owning partition found = False @@ -797,8 +881,8 @@ def _create_partitions(self): continue if ( - sorted_1d_query_indices[rank_range[0]] > owned_query_indices[i] or - sorted_1d_query_indices[rank_range[1]] < owned_query_indices[i] + sorted_1d_query_indices[rank_range[0]] > owned_query_indices[i] + or sorted_1d_query_indices[rank_range[1]] < owned_query_indices[i] ): continue @@ -811,7 +895,9 @@ def _create_partitions(self): raise RuntimeError("Corresponding rank not found for query point") # transfer query points - x_query_part, inv_map = self._communicate(self._x_query_local.shape[-1], send_map, return_inverse=True) + x_query_part, inv_map = self._communicate( + self._x_query_local.shape[-1], send_map, return_inverse=True + ) x_query = np.array(x_query_part).reshape(-1, self._x_query_local.shape[-1]) # invert query send map for later (to transfer back) self._query_rank_mapping = {} @@ -821,19 +907,22 @@ def _create_partitions(self): self._query_rank_mapping[tuple(data_[p_idx, :].tolist())] = rank # figure out which source points need to be sent where - send_map = {r:[] for r in range(self._size)} + send_map = {r: [] for r in range(self._size)} source_coords = self._tree.get_filled_coords(grid_resolution) source_mapping = {tuple(c.tolist()): hMap.coord2index(c) for c in source_coords} - owned_source_coords = self._tree.get_coords_of(self._proj_x_local, grid_resolution) - owned_source_indices = [source_mapping[tuple(coord.tolist())] for coord in owned_source_coords] + owned_source_coords = self._tree.get_coords_of( + self._proj_x_local, grid_resolution + ) + owned_source_indices = [ + source_mapping[tuple(coord.tolist())] for coord in owned_source_coords + ] for i in range(len(owned_source_indices)): for rank, rank_domain in src_domains.items(): if rank_domain[0] is None or rank_domain[1] is None: continue - if ( - np.alltrue(rank_domain[0] <= source_coords[i]) and - np.alltrue(source_coords[i] <= rank_domain[1]) + if np.alltrue(rank_domain[0] <= source_coords[i]) and np.alltrue( + source_coords[i] <= rank_domain[1] ): data = [] data.extend(self._x_local[i, :].tolist()) @@ -841,16 +930,20 @@ def _create_partitions(self): send_map[rank].append(data) # transfer source points - xf_part = self._communicate(self._x_local.shape[-1] + self._f_local.shape[-1], send_map) - xf_part = np.array(xf_part).reshape(-1, self._x_local.shape[-1] + self._f_local.shape[-1]) - x = xf_part[:, 0:self._x_local.shape[-1]] - f = xf_part[:, self._x_local.shape[-1]:] + xf_part = self._communicate( + self._x_local.shape[-1] + self._f_local.shape[-1], send_map + ) + xf_part = np.array(xf_part).reshape( + -1, self._x_local.shape[-1] + self._f_local.shape[-1] + ) + x = xf_part[:, 0 : self._x_local.shape[-1]] + f = xf_part[:, self._x_local.shape[-1] :] return x, x_query, f def _communicate(self, entry_size, send_map, return_inverse=False): send_counts = [len(send_map[r]) for r in range(self._size)] - send_counts[self._rank] = 0 # ignore local count + send_counts[self._rank] = 0 # ignore local count glob_send_counts = self._comm.allgather(send_counts) send_reqs = [] @@ -860,7 +953,9 @@ def _communicate(self, entry_size, send_map, return_inverse=False): if len(data) == 0: continue for d_idx, entry in enumerate(data): - req = self._comm.isend(entry, dest=recv_rank, tag=create_tag(d_idx, self._rank, recv_rank)) + req = self._comm.isend( + entry, dest=recv_rank, tag=create_tag(d_idx, self._rank, recv_rank) + ) send_reqs.append(req) recv_reqs = [] @@ -870,14 +965,16 @@ def _communicate(self, entry_size, send_map, return_inverse=False): if glob_send_counts[send_rank][self._rank] == 0: continue for d_idx in range(glob_send_counts[send_rank][self._rank]): - req = self._comm.irecv(None, source=send_rank, tag=create_tag(d_idx, send_rank, self._rank)) + req = self._comm.irecv( + None, source=send_rank, tag=create_tag(d_idx, send_rank, self._rank) + ) recv_reqs.append(tuple([send_rank, req])) MPI.Request.Waitall(send_reqs) result = [] result.extend(send_map[self._rank]) - inv_map = {r:[] for r in range(self._size)} + inv_map = {r: [] for r in range(self._size)} inv_map[self._rank].extend(send_map[self._rank]) for source_rank, req in recv_reqs: @@ -891,6 +988,7 @@ def _communicate(self, entry_size, send_map, return_inverse=False): else: return result + class RBF_PU: """ Interpolates f(x) for f: R^n -> R^m using partition of unity RBF interpolant. @@ -917,24 +1015,42 @@ def __init__(self, config, logger, comm: MPI.Comm, rank, size): self._f = None def configure(self, interp_config): - self._domain.configure({} if 'domain_config' not in interp_config else interp_config['domain_config']) - self._use_pu = False if 'use_pu' not in interp_config else interp_config['use_pu'] + self._domain.configure( + {} + if "domain_config" not in interp_config + else interp_config["domain_config"] + ) + self._use_pu = ( + False if "use_pu" not in interp_config else interp_config["use_pu"] + ) if self._use_pu: - self._pu_overlap = 0.1 if 'pu_overlap' not in interp_config else interp_config['pu_overlap'] - self._pu_cluster_size = 50 if 'pu_cluster_size' not in interp_config else interp_config['pu_cluster_size'] - if 'basis' not in interp_config: + self._pu_overlap = ( + 0.1 + if "pu_overlap" not in interp_config + else interp_config["pu_overlap"] + ) + self._pu_cluster_size = ( + 50 + if "pu_cluster_size" not in interp_config + else interp_config["pu_cluster_size"] + ) + if "basis" not in interp_config: return - match interp_config['basis']: - case 'c0': + match interp_config["basis"]: + case "c0": self._phi = RBF_PU.basis_c0 - case 'c2': + case "c2": self._phi = RBF_PU.basis_c2 - case 'c4': + case "c4": self._phi = RBF_PU.basis_c4 - case 'c6': + case "c6": self._phi = RBF_PU.basis_c6 - case 'gauss': - eps = 1.0 if 'gauss_eps' not in interp_config else interp_config['gauss_eps'] + case "gauss": + eps = ( + 1.0 + if "gauss_eps" not in interp_config + else interp_config["gauss_eps"] + ) self._phi = partial(RBF_PU.basis_gauss, eps=eps) def set_local_data(self, x, x_, f): @@ -979,15 +1095,17 @@ def _compute_cluster_centers(self, x): mask = np.zeros_like(d4) mask[d] = 1 - centers[2*d + 0, :] = center - mask * d4 - centers[2*d + 1, :] = center + mask * d4 + centers[2 * d + 0, :] = center - mask * d4 + centers[2 * d + 1, :] = center + mask * d4 return centers, local_min, local_max def compute_rbf_pu_interpolant(self, x, f): # compute r_m c_centers, local_min, local_max = self._compute_cluster_centers(x) - index_tree = NDtree(NDtree.Mode.INDEX, local_min, local_max, *self._domain.get_depth_filling()) + index_tree = NDtree( + NDtree.Mode.INDEX, local_min, local_max, *self._domain.get_depth_filling() + ) # TODO later # determine clusters # ignore empty clusters @@ -1048,12 +1166,17 @@ def basis_c2(r): @staticmethod def basis_c4(r): - return np.maximum(0.0, np.power(1.0 - r, 6)) * (35.0 * np.power(r, 2) + 18.0 * r + 3.0) / 3.0 + return ( + np.maximum(0.0, np.power(1.0 - r, 6)) + * (35.0 * np.power(r, 2) + 18.0 * r + 3.0) + / 3.0 + ) @staticmethod def basis_c6(r): return np.maximum(0.0, np.power(1.0 - r, 8)) * ( - 32.0 * np.power(r, 3) + 25.0 * np.power(r, 2) + 8.0 * r + 1.0) + 32.0 * np.power(r, 3) + 25.0 * np.power(r, 2) + 8.0 * r + 1.0 + ) @staticmethod def basis_gauss(r, eps): diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index 85dc4c6b..d500d2a0 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -1154,7 +1154,9 @@ def _solve_micro_simulations_with_model_adaptivity( self._model_adaptivity_controller.finalise_solve() for lid, sim in enumerate(self._micro_sims): - output[lid]["model_resolution"] = self._model_adaptivity_controller.get_sim_class_resolution(sim) + output[lid][ + "model_resolution" + ] = self._model_adaptivity_controller.get_sim_class_resolution(sim) return output def _get_solve_variant(self) -> Callable[[list, float], list]: From 70ca465db94dbe673a4543c259ae2602097361d1 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Wed, 8 Apr 2026 10:43:59 +0200 Subject: [PATCH 03/14] Add doc --- micro_manager/adaptivity/adaptivity.py | 122 +-- micro_manager/adaptivity/global_adaptivity.py | 2 + micro_manager/adaptivity/local_adaptivity.py | 2 + micro_manager/interpolation.py | 769 ++++++++++++++++-- 4 files changed, 785 insertions(+), 110 deletions(-) diff --git a/micro_manager/adaptivity/adaptivity.py b/micro_manager/adaptivity/adaptivity.py index ead339be..2b99e5d3 100644 --- a/micro_manager/adaptivity/adaptivity.py +++ b/micro_manager/adaptivity/adaptivity.py @@ -62,47 +62,7 @@ def __init__( self._mappings = [] self._mapping_configs = [] mappings = configurator.get_adaptivity_mapping_configs() - for mapping in mappings: - src_fields = mapping["src_fields"] - dst_fields = mapping["dst_fields"] - n_neighbors = mapping["n_neighbors"] - - self._mappings.append((src_fields, dst_fields)) - config = {} - if "use_pu" in mapping["rbf_config"]: - config["use_pu"] = mapping["rbf_config"]["use_pu"] - if "pu_overlap" in mapping["rbf_config"]: - config["pu_overlap"] = mapping["rbf_config"]["pu_overlap"] - config["pu_cluster_size"] = n_neighbors - if "basis" in mapping["rbf_config"]: - if "type" in mapping["rbf_config"]["basis"]: - config["basis"] = mapping["rbf_config"]["basis"]["type"] - if ( - config["basis"] == "gauss" - and "eps" in mapping["rbf_config"]["basis"] - ): - config["gauss_eps"] = mapping["rbf_config"]["basis"]["eps"] - - dom_config = {} - dom_config["n_neighbors"] = n_neighbors - if "max_filling" in mapping["domain_config"]: - dom_config["max_filling"] = mapping["domain_config"]["max_filling"] - if "coarsening_factor" in mapping["domain_config"]: - dom_config["coarsening_factor"] = mapping["domain_config"][ - "coarsening_factor" - ] - if "projection" in mapping["domain_config"]: - if "type" in mapping["domain_config"]["projection"]: - dom_config["projection_type"] = mapping["domain_config"][ - "projection" - ]["type"] - if "target_dims" in mapping["domain_config"]["projection"]: - dom_config["projection_std_dims"] = mapping["domain_config"][ - "projection" - ]["target_dims"] - - config["domain_config"] = dom_config - self._mapping_configs.append(config) + self._load_mappings(mappings) # is_sim_active: 1D array having state (active or inactive) of each micro simulation # Start adaptivity calculation with all sims active @@ -155,6 +115,61 @@ def __init__( self._metrics_logger.log_info("n|n active|n inactive|assoc ranks") + def _load_mappings(self, mappings: list) -> None: + """ + Translates the mapping information provided from the configuration file into a + interpolation method parseable structure. + + This will populate the self._mappings and self._mapping_configs buffers. + Called once during __init__. + + Parameters + ---------- + mappings : list + List of mappings as provided by the configuration file. + """ + for mapping in mappings: + src_fields = mapping["src_fields"] + dst_fields = mapping["dst_fields"] + n_neighbors = mapping["n_neighbors"] + + self._mappings.append((src_fields, dst_fields)) + config = {} + if "use_pu" in mapping["rbf_config"]: + config["use_pu"] = mapping["rbf_config"]["use_pu"] + if "pu_overlap" in mapping["rbf_config"]: + config["pu_overlap"] = mapping["rbf_config"]["pu_overlap"] + config["pu_cluster_size"] = n_neighbors + if "basis" in mapping["rbf_config"]: + if "type" in mapping["rbf_config"]["basis"]: + config["basis"] = mapping["rbf_config"]["basis"]["type"] + if ( + config["basis"] == "gauss" + and "eps" in mapping["rbf_config"]["basis"] + ): + config["gauss_eps"] = mapping["rbf_config"]["basis"]["eps"] + + dom_config = {} + dom_config["n_neighbors"] = n_neighbors + if "max_filling" in mapping["domain_config"]: + dom_config["max_filling"] = mapping["domain_config"]["max_filling"] + if "coarsening_factor" in mapping["domain_config"]: + dom_config["coarsening_factor"] = mapping["domain_config"][ + "coarsening_factor" + ] + if "projection" in mapping["domain_config"]: + if "type" in mapping["domain_config"]["projection"]: + dom_config["projection_type"] = mapping["domain_config"][ + "projection" + ]["type"] + if "target_dims" in mapping["domain_config"]["projection"]: + dom_config["projection_std_dims"] = mapping["domain_config"][ + "projection" + ]["target_dims"] + + config["domain_config"] = dom_config + self._mapping_configs.append(config) + def _update_similarity_dists(self, dt: float, data: dict) -> None: """ Calculate metric which determines if two micro simulations are similar enough to have one of them deactivated. @@ -245,16 +260,25 @@ def _check_for_deactivation(self, active_id: int, active_ids: list) -> bool: return True return False - def _interpolate_output(self, micro_input, micro_sims_output): - # now all outputs are init to representative + def _interpolate_output(self, micro_input, micro_sims_output) -> None: + """ + Interpolates the micro output based on the available inputs and outputs using the selected + interpolation method and desired mappings. + Will compute functions f1 ... fN described in the config. + fi: X -> Y, X and Y must be subsets of the coupled fields. + Every output field may only be used once as interpolation target, meaning there may not be + a function fi and fj with shared Yi and Yj. - # We need RBF interp here - # will treat function f1 ... fN described in the config - # fi: X -> Y, X and Y must be subsets of the coupled fields + This method will edit the output buffer, instead of returning a new buffer. - # mapping list[tuple[list[str], list[str]]] = list[mapping=tuple[src args, dst args]] - # will aggregate args - # every output may only be used once as interpolation target + Parameters + ---------- + micro_input : list + List of all local micro simulation inputs. + + micro_sims_output : list + List of all local micro simulation outputs. (current state) + """ targets = [] for _, target_args in self._mappings: targets.extend(target_args) diff --git a/micro_manager/adaptivity/global_adaptivity.py b/micro_manager/adaptivity/global_adaptivity.py index 6fc48409..29ad8f5a 100644 --- a/micro_manager/adaptivity/global_adaptivity.py +++ b/micro_manager/adaptivity/global_adaptivity.py @@ -252,6 +252,8 @@ def get_full_field_micro_output( Parameters ---------- + micro_input : list + List of dicts containing the input data for each simulation. micro_output : list List of dicts having individual output of each simulation. Only the active simulation outputs are entered. diff --git a/micro_manager/adaptivity/local_adaptivity.py b/micro_manager/adaptivity/local_adaptivity.py index b7d2f011..0272739d 100644 --- a/micro_manager/adaptivity/local_adaptivity.py +++ b/micro_manager/adaptivity/local_adaptivity.py @@ -162,6 +162,8 @@ def get_full_field_micro_output( Parameters ---------- + micro_input : list + List of dicts containing the input data for each simulation. micro_output : list List of dicts having individual output of each simulation. Only the active simulation outputs are entered. diff --git a/micro_manager/interpolation.py b/micro_manager/interpolation.py index b859f007..29a53548 100644 --- a/micro_manager/interpolation.py +++ b/micro_manager/interpolation.py @@ -2,13 +2,15 @@ from copy import deepcopy from enum import Enum from functools import partial -from typing import Optional +from typing import Optional, Tuple, Union import sys from mpi4py import MPI import numpy as np from sklearn.neighbors import NearestNeighbors +from micro_manager import Config +from micro_manager.tools.logging_wrapper import Logger from micro_manager.tools.p2p import create_tag # handle compat issue between np version 1 and 2 @@ -98,6 +100,11 @@ def interpolate(self, neighbors: np.ndarray, point: np.ndarray, values): class NDtree: + """ + This is a spatial data structure to store N-dimensional data points. Can be used either for discretization or + spatial indexing purposes. Is based on an octtree but ported to N dimensions. + """ + class Mode(Enum): DISCRETIZE = 0 INDEX = 1 @@ -117,14 +124,16 @@ def __init__( Parameters ---------- + mode : NDtree.Mode + The mode of operation. low : np.ndarray Lower bound of the node. high : np.ndarray Upper bound of the node. max_depth : int Remaining maximum depth of the node. - rtol : float - Maximum Error of points to node center + max_filling : int + Maximum number of points within the node, until split. is_bound : np.ndarray Boolean indicating whether the node is on the boundary. """ @@ -150,7 +159,10 @@ def num_max_split(self) -> int: def filling(self) -> int: return len(self.data) - def clear(self): + def clear(self) -> None: + """ + Clears all data, but preserves node structure. + """ self.data.clear() self.data_reserve_count = 0 @@ -159,7 +171,17 @@ def clear(self): for node in self.children: node.clear() - def propagate_up_reserve_counts(self): + def propagate_up_reserve_counts(self) -> int: + """ + Counts the reserve counts of child nodes and returns sum. + Used during discretization mode when all data points are in leaf nodes at max depth + to approximate the required depth to find N neighbours. + + Returns + ------- + reserve_count : int + sum of child node reserve counts. + """ if self.children is None: return self.data_reserve_count @@ -168,7 +190,27 @@ def propagate_up_reserve_counts(self): return self.data_reserve_count - def find_min_depth_for_n_neighbors(self, n: int, depth: int, p): + def find_min_depth_for_n_neighbors( + self, n: int, depth: int, p: np.ndarray + ) -> Optional[int]: + """ + Finds the minimum depth required to find N nearest neighbors for the given point. + Assumes propagate_up_reserve_counts was called. + + Parameters + ---------- + n : int + Number of nearest neighbors. + depth : int + Recursion depth. Start recursion with 0. + p : np.ndarray + Query point. + + Returns + ------- + min_depth : Optional[int] + None depth cannot be found, else depth. + """ if self.data_reserve_count < n: return None if self.children is None: @@ -191,7 +233,24 @@ def find_min_depth_for_n_neighbors(self, n: int, depth: int, p): min_depth = min(depths) return min_depth - def get_filled_coords(self, bin_low, bin_high): + def get_filled_coords( + self, bin_low: np.ndarray, bin_high: np.ndarray + ) -> list[np.ndarray]: + """ + Finds coordinates of all cells that have non-zero reserve counts. Assumes discretization mode is used. + + Parameters + ---------- + bin_low : np.ndarray + Lower bound of possible bins. + bin_high : np.ndarray + Upper bound of possible bins. + + Returns + ------- + coords : list[np.ndarray] + Coordinates of all cells that have non-zero reserve counts. + """ assert self._mode == NDtree.Mode.DISCRETIZE if self.children is None: @@ -216,6 +275,9 @@ def get_filled_coords(self, bin_low, bin_high): return buffer def split(self): + """ + Splits node if possible and transfers data points to child nodes. + """ if self.children is not None: return if self.max_depth == 0: @@ -238,7 +300,15 @@ def split(self): self._insert_find_child_node(p) self.data.clear() - def insert(self, p): + def insert(self, p: np.ndarray): + """ + Inserts data point into this node if possible, else into child node. + + Parameters + ---------- + p : np.ndarray + Data point to be inserted. + """ if self._mode == NDtree.Mode.INDEX: # first insert to sub nodes if available if self.children is not None: @@ -267,7 +337,26 @@ def insert(self, p): else: self._insert_find_child_node(p) - def get_coord_of(self, point, bin_low, bin_high): + def get_coord_of( + self, point: np.ndarray, bin_low: np.ndarray, bin_high: np.ndarray + ) -> np.ndarray: + """ + Finds the cell coordinate of given point. Assumes discretization mode is used. + + Parameters + ---------- + point : np.ndarray + Query point. + bin_low : np.ndarray + Lower bound of possible bins. + bin_high : np.ndarray + Upper bound of possible bins. + + Returns + ------- + coord : np.ndarray + Coordinate of given point. + """ assert self._mode == NDtree.Mode.DISCRETIZE if self.children is None: @@ -284,7 +373,20 @@ def get_coord_of(self, point, bin_low, bin_high): raise RuntimeError("Failed to locate cell of point") - def is_within(self, point): + def is_within(self, point: np.ndarray) -> bool: + """ + Checks whether given point is within the bounds of this node. + + Parameters + ---------- + point : np.ndarray + Query point. + + Returns + ------- + is_within : bool + True if point is within bounds of this node. + """ return np.alltrue(point >= self.low) and np.alltrue( np.logical_or( np.logical_and(self.is_bound, np.isclose(point, self.high, 1e-10)), @@ -292,14 +394,30 @@ def is_within(self, point): ) ) - def get_height(self): + def get_height(self) -> int: + """ + Returns height of this node. + + Returns + ------- + height : int + Height of this node. + """ if self.children is None: return 0 heights = [node.get_height() for node in self.children] return max(heights) + 1 - def serialize(self): + def serialize(self) -> list[int]: + """ + Serializes the tree in a run-length encoded format. First entry determines amount of owned entries. + + Returns + ------- + serialized_data : list[int] + Serialized tree. + """ if self.children is None: return [2, len(self.data)] @@ -310,7 +428,15 @@ def serialize(self): result.extend(c_result) return result - def deserialize(self, serialized): + def deserialize(self, serialized: list[int]) -> None: + """ + Deserializes tree from serialized data. + + Parameters + ---------- + serialized : list[int] + Serialized tree. + """ if self.children is not None or len(self.data) > 0: raise RuntimeError("Deserialize called on non empty tree.") @@ -326,7 +452,15 @@ def deserialize(self, serialized): ) offset += serialized[offset] - def merge(self, other): + def merge(self, other: "NDtree.Node") -> None: + """ + Merges the other node structure and reserve counts into this node. + + Parameters + ---------- + other : NDtree.Node + Other node structure. + """ is_split = self.children is not None is_split_other = other.children is not None @@ -346,23 +480,88 @@ def merge(self, other): for i in range(self.num_max_split): self.children[i].merge(other.children[i]) - def _insert_find_child_node(self, p): + def _insert_find_child_node(self, p: np.ndarray) -> None: + """ + Inserts the point into the correct child node. + + Parameters + ---------- + p : np.ndarray + Point to insert. + """ for i in range(self.num_max_split): if not self.children[i].is_within(p): continue self.children[i].insert(p) return - def _idx2mask(self, idx): + def _idx2mask(self, idx: int) -> np.ndarray: + """ + Converts the given index into its corresponding binary mask. + + Parameters + ---------- + idx : int + Index to convert. + + Returns + ------- + mask : np.ndarray + Binary mask. If bit i of idx is 1, then entry i of mask if 1. + """ return ( (idx & np.array([1 << i for i in range(self.dim)], dtype=np.int32)) != 0 ).astype(np.int32) - def _idx2coord(self, delta, low, idx): + def _idx2coord( + self, delta: np.ndarray, low: np.ndarray, idx: int + ) -> np.ndarray: + """ + Computes the new lower bound for the child node with the given index. + + Parameters + ---------- + delta : np.ndarray + New cell size + low : np.ndarray + Old lower bound. + idx : np.ndarray + Index of child node. + + Returns + ------- + coord : np.ndarray + New cell lower bound. + """ mask = self._idx2mask(idx).astype(dtype=delta.dtype) return (low + delta * mask).astype(mask.dtype) - def __init__(self, mode, low, high, max_depth, max_filling): + def __init__( + self, + mode: "NDtree.Mode", + low: np.ndarray, + high: np.ndarray, + max_depth: int, + max_filling: int, + ): + """ + Constructs a NDtree with the given parameters. + In discretize mode, all data points are inserted into nodes at max_depth. + In index mode, max_filling is used for insertion. When the threshold is met, nodes are split. + + Parameters + ---------- + mode : NDtree.Mode + Mode of operation. + low : np.ndarray + Lower bound of space. + high : np.ndarray + Upper bound of space. + max_depth : int + Maximum depth of the tree. + max_filling : int + Maximum filling of the tree. + """ self.root = NDtree.Node( mode, low, @@ -372,7 +571,20 @@ def __init__(self, mode, low, high, max_depth, max_filling): np.ones(low.shape[0], dtype=np.int32), ) - def get_filled_coords(self, height=None): + def get_filled_coords(self, height: Optional[int] = None) -> list[np.ndarray]: + """ + Finds coordinates of all cells that have non-zero reserve counts. Assumes discretization mode is used. + + Parameters + ---------- + height : Optional[int] + Height of the tree. If None, will be computed. + + Returns + ------- + coords : list[np.ndarray] + Coordinates of all cells that have non-zero reserve counts. + """ if height is None: height = self.root.get_height() dtype = np.int32 @@ -383,7 +595,24 @@ def get_filled_coords(self, height=None): np.power(2 * np.ones(self.root.dim, dtype=dtype), height), ) - def get_coords_of(self, points, height=None): + def get_coords_of( + self, points: np.ndarray, height: Optional[int] = None + ) -> np.ndarray: + """ + Finds the cell coordinate of all given point. Assumes discretization mode is used. + + Parameters + ---------- + points : np.ndarray + Points to find coordinates of. + height : Optional[int] + Height of the tree. If None, will be computed. + + Returns + ------- + coords : np.ndarray + Coordinates of all points. + """ if height is None: height = self.root.get_height() dtype = np.int32 @@ -396,7 +625,22 @@ def get_coords_of(self, points, height=None): coords[i, :] = self.root.get_coord_of(point, c_min, c_max) return coords - def find_min_depth_for_n_neighbors(self, n, points): + def find_min_depth_for_n_neighbors(self, n: int, points: np.ndarray) -> int: + """ + Finds the minimum depth of all given points to encounter n neighbors. + + Parameters + ---------- + n : int + Number of neighbors. + points : np.ndarray + Query points. + + Returns + ------- + depth : int + Minimum depth of all given points to encounter n neighbors. + """ if points.shape[0] == 0: return 0 depths = np.ones(len(points)) * self.get_height() @@ -406,33 +650,113 @@ def find_min_depth_for_n_neighbors(self, n, points): depths[idx] = d return np.min(depths) - def propagate_up_reserve_counts(self): + def propagate_up_reserve_counts(self) -> int: + """ + Counts the reserve counts of child nodes and returns sum. + Used during discretization mode when all data points are in leaf nodes at max depth to approximate the required depth to find N neighbours. + + Returns + ------- + reserve_counts : int + sum of child node reserve counts + """ return self.root.propagate_up_reserve_counts() def split(self): + """ + Splits node if possible and transfers data points to child nodes. + """ return self.root.split() - def insert(self, p): + def insert(self, p: np.ndarray) -> None: + """ + Inserts data point into this node if possible, else into child node. + + Parameters + ---------- + p : np.ndarray + Data point to be inserted. + """ return self.root.insert(p) - def serialize(self): + def serialize(self) -> list[int]: + """ + Serializes the tree in a run-length encoded format. First entry determines amount of owned entries. + + Returns + ------- + serialized_data : list[int] + Serialized tree. + """ return self.root.serialize() - def deserialize(self, serialized): + def deserialize(self, serialized: list[int]) -> None: + """ + Deserializes tree from serialized data. + + Parameters + ---------- + serialized : list[int] + Serialized tree. + """ return self.root.deserialize(serialized) - def merge(self, other): + def merge(self, other: "NDtree") -> None: + """ + Merges the other node structure and reserve counts into this node. + + Parameters + ---------- + other : NDtree + Other node structure. + """ return self.root.merge(other.root) - def get_height(self): + def get_height(self) -> int: + """ + Returns height of this node. + + Returns + ------- + height : int + Height of this node. + """ return self.root.get_height() - def clear(self): + def clear(self) -> None: + """ + Clears all data, but preserves tree structure. + """ return self.root.clear() class HilbertDirect: - def __init__(self, dim, bits): + """ + Provides a bijective mapping between an N-dimensional space and 1D space based on the algorithm provided in: + Programming the Hilbert curve by John Skilling (from the AIP Conf. Proc. 707, 381 (2004)) + + Example: 5 bits for each of n=3 coordinates. + 15-bit Hilbert integer = A B C D E F G H I J K L M N O is stored + as its Transpose ^ + X[0] = A D G J M X[2]| 7 + X[1] = B E H K N <-------> | /X[1] + X[2] = C F I L O axes |/ + high low 0------> X[0] + """ + + def __init__(self, dim: int, bits: int): + """ + Constructs mapping between N-dimensional space and 1D space. + Used n bits to encode coords along one dimension. + Therefore, 2**n - 1 is the max coord along one dimension. + + Parameters + ---------- + dim : int + Number of dimensions. + bits : int + Number of bits used per dimension. + """ self.dim = dim self.bits = bits self.dtype = None @@ -442,7 +766,20 @@ def __init__(self, dim, bits): else: self.dtype = np.int64 - def index2coord(self, idx): + def index2coord(self, idx: int) -> np.ndarray: + """ + Converts index to coordinate. + + Parameters + ---------- + idx : int + Index to convert. + + Returns + ------- + coords : np.ndarray + Coordinates of index. + """ X = np.zeros(self.dim, dtype=self.dtype) if self.bits == 0: return X @@ -481,7 +818,20 @@ def index2coord(self, idx): return X - def coord2index(self, coord): + def coord2index(self, coord: np.ndarray) -> int: + """ + Converts coordinate to index. + + Parameters + ---------- + coord : np.ndarray + Coordinate to convert. + + Returns + ------- + index : int + Index of coordinate. + """ if self.bits == 0: return 0 X = deepcopy(coord) @@ -524,22 +874,70 @@ def coord2index(self, coord): class Projector(ABC): + """ + Interface to project high-dimensional data into low-dimensional space. + """ + @abstractmethod - def __call__(self, data): + def __call__(self, data: np.ndarray) -> np.ndarray: + """ + Performs projection on high-dimensional data. + + Parameters + ---------- + data : np.ndarray + High-dimensional data. + + Returns + ------- + proj_data : np.ndarray + Projected data. + """ pass @abstractmethod - def initialize(self, data): + def initialize(self, data: np.ndarray) -> None: + """ + Initializes projection parameters based on data. + + Parameters + ---------- + data : np.ndarray + High-dimensional data. + """ pass class STDProjector(Projector): + """ + Projects high-dimensional data into low-dimensional space using the fields with the highest standard deviation. + """ + def __init__(self, target_dims: int, comm: MPI.Comm): + """ + Constructs STD projection. + + Parameters + ---------- + target_dims : int + Number of target dimensions. + comm : MPI.Comm + MPI communicator. + """ self.num_target_dims = target_dims self.target_dims = np.zeros(target_dims, dtype=np.int32) self.comm = comm - def initialize(self, data): + def initialize(self, data: np.ndarray) -> None: + """ + Initializes projection parameters based on data. + Computes target dimensions using provided data. + + Parameters + ---------- + data : np.ndarray + High-dimensional data. + """ assert data.ndim > 1 std = np.zeros(data.shape[-1]) if data.shape[0] > 0: @@ -550,7 +948,20 @@ def initialize(self, data): np.argsort(stds)[::-1][0 : self.num_target_dims] ).astype(np.int32) - def __call__(self, data): + def __call__(self, data: np.ndarray) -> np.ndarray: + """ + Performs projection on high-dimensional data. + + Parameters + ---------- + data : np.ndarray + High-dimensional data. + + Returns + ------- + proj_data : np.ndarray + Projected data. + """ d = data if data.ndim == 1: d = d[np.newaxis, :] @@ -558,10 +969,31 @@ def __call__(self, data): class IdentityProjector(Projector): - def __call__(self, data): + def __call__(self, data: np.ndarray) -> np.ndarray: + """ + Performs projection on high-dimensional data. (does nothing) + + Parameters + ---------- + data : np.ndarray + High-dimensional data. + + Returns + ------- + proj_data : np.ndarray + Projected data. + """ return data - def initialize(self, data): + def initialize(self, data: np.ndarray) -> None: + """ + Initializes projection parameters based on data. (does nothing) + + Parameters + ---------- + data : np.ndarray + High-dimensional data. + """ pass @@ -571,7 +1003,17 @@ class InterleavedDomain: Will de- and re-compose the distributed data to allow for domain local operations. """ - def __init__(self, config, comm: MPI.Comm): + def __init__(self, config: Config, comm: MPI.Comm): + """ + Constructs InterleavedDomain object. + + Parameters + ---------- + config : Config + Configuration object. + comm : MPI.Comm + MPI communicator. + """ self._config = config self._comm = comm self._size = comm.Get_size() @@ -597,7 +1039,15 @@ def __init__(self, config, comm: MPI.Comm): self._query_rank_mapping = None - def configure(self, domain_config): + def configure(self, domain_config: dict) -> None: + """ + Configures InterleavedDomain object to the provided settings. + + Parameters + ---------- + domain_config : dict + Target configuration. + """ self._max_filling = ( 8 if "max_filling" not in domain_config else domain_config["max_filling"] ) @@ -624,12 +1074,39 @@ def configure(self, domain_config): case "identity": self._projector = IdentityProjector() - def set_local_data(self, x, x_, f): + def set_local_data(self, x: np.ndarray, x_: np.ndarray, f: np.ndarray) -> None: + """ + Sets local data for interleaved domain. + + Parameters + ---------- + x : np.ndarray + Support points. + x_ : np.ndarray + Query points. + f : np.ndarray + Support point function values. + """ self._x_local = x self._x_query_local = x_ self._f_local = f - def decompose(self): + def decompose(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Decomposes the domain, by conceptually merging all support and query points across all rank + and splitting the query points, s.t. each rank will have approx the same amount of query points. + Support points alongside their function values are distributed to the respective ranks, that query points + are surrounded by sufficient support points. + + Returns + ------- + x : np.ndarray + Assigned support points. + x_ : np.ndarray + Assigned query points. + f : np.ndarray + Assigned support point function values. + """ # if not parallel, no work to be done if self._size == 1: return self._x_local, self._x_query_local, self._f_local @@ -637,10 +1114,35 @@ def decompose(self): self._generate_trees() return self._create_partitions() - def get_depth_filling(self): + def get_depth_filling(self) -> Tuple[int, int]: + """ + Gets the tree properties. + + Returns + ------- + max_depth : int + Maximum depth of the tree. + max_filling : int + Maximum filling of the tree. + """ return self._max_depth, self._max_filling - def reassemble(self, x_query, f_query): + def reassemble(self, x_query: np.ndarray, f_query: np.ndarray) -> np.ndarray: + """ + Reassembles the query point function values to match the configuration prior of decomposition. + + Parameters + ---------- + x_query : np.ndarray + Query points. + f_query : np.ndarray + Query point function values. + + Returns + ------- + reassembled : np.ndarray + Reassembled query point function values. + """ # if not parallel, no work to be done if self._size == 1: return f_query @@ -653,7 +1155,7 @@ def reassemble(self, x_query, f_query): data.extend(x_query[i, :].tolist()) data.extend(f_query[i, :].tolist()) send_map[dst_rank].append(data) - local_data = self._communicate(x_query.shape[-1] + f_query.shape[-1], send_map) + local_data = self._communicate(send_map) local_data = np.array(local_data).reshape( -1, x_query.shape[-1] + f_query.shape[-1] ) @@ -669,7 +1171,10 @@ def reassemble(self, x_query, f_query): result[idx, :] = local_data[d_idx, x_query.shape[-1] :] return result - def _normalize_x(self): + def _normalize_x(self) -> None: + """ + Normalizes support and query points to fit within -1 and 1. + """ x_loc_min = np.ones(self._x_local.shape[-1]) * np.inf if self._x_local.shape[0] > 0: x_loc_min = np.min(self._x_local, axis=0) @@ -725,7 +1230,11 @@ def eval_cond(): self._proj_x_local = self._projector(self._x_local) self._proj_x_query_local = self._projector(self._x_query_local) - def _generate_trees(self): + def _generate_trees(self) -> None: + """ + Generates domain decomposition trees, shares them across all ranks and constructs + a globally valid tree used for partitioning. + """ self._normalize_x() proj_dim = self._proj_x_local.shape[1] @@ -773,7 +1282,20 @@ def bcast_tree(t) -> NDtree: self._tree = bcast_tree(tree) self._query_tree = bcast_tree(query_tree) - def _create_partitions(self): + def _create_partitions(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Creates partitions based on a equidistant splitting of the + hilbert space indices of the domain decomposition trees. + + Returns + ------- + x : np.ndarray + Support points around new query points + x_ : np.ndarray + New query points + f : np.ndarray + Support point function values. + """ self._tree.propagate_up_reserve_counts() r_m_depth = self._tree.find_min_depth_for_n_neighbors( self._n_neighbors, self._proj_x_query_local @@ -895,9 +1417,7 @@ def _create_partitions(self): raise RuntimeError("Corresponding rank not found for query point") # transfer query points - x_query_part, inv_map = self._communicate( - self._x_query_local.shape[-1], send_map, return_inverse=True - ) + x_query_part, inv_map = self._communicate(send_map, return_inverse=True) x_query = np.array(x_query_part).reshape(-1, self._x_query_local.shape[-1]) # invert query send map for later (to transfer back) self._query_rank_mapping = {} @@ -930,9 +1450,7 @@ def _create_partitions(self): send_map[rank].append(data) # transfer source points - xf_part = self._communicate( - self._x_local.shape[-1] + self._f_local.shape[-1], send_map - ) + xf_part = self._communicate(send_map) xf_part = np.array(xf_part).reshape( -1, self._x_local.shape[-1] + self._f_local.shape[-1] ) @@ -941,7 +1459,25 @@ def _create_partitions(self): return x, x_query, f - def _communicate(self, entry_size, send_map, return_inverse=False): + def _communicate( + self, send_map: dict[int, list], return_inverse: bool = False + ) -> Union[list, Tuple[list, dict[int, list]]]: + """ + Transfers data between ranks a p2p communication according to the provided send_map. + + Parameters + ---------- + send_map : dict[int, list] + Mapping from destination rank to list of data to be transferred. + return_inverse : bool + Return inverse transfer or not. + + Returns + ------- + comm_result : Union[list, Tuple[list, dict[int, list]]] + If return_inverse is True, returns a list of all received data as well as a mapping from which rank + which data was sent. Otherwise, returns only the received data list. + """ send_counts = [len(send_map[r]) for r in range(self._size)] send_counts[self._rank] = 0 # ignore local count glob_send_counts = self._comm.allgather(send_counts) @@ -996,7 +1532,26 @@ class RBF_PU: The approach here does not require a support radius as data is normalized. """ - def __init__(self, config, logger, comm: MPI.Comm, rank, size): + def __init__( + self, config: Config, logger: Logger, comm: MPI.Comm, rank: int, size: int + ): + """ + Constructs the RBF_PU interpolator. + For rank local interpolation provide MPI.COMM_SELF as comm with according rank and size. + + Parameters + ---------- + config : Config + Configuration object. + logger : Logger + Logger object. + comm : MPI.Comm + MPI communicator. + rank : int + Rank within the provided MPI communicator. + size : int + Size within the provided MPI communicator. + """ self._config = config self._logger = logger self._comm = comm @@ -1014,7 +1569,15 @@ def __init__(self, config, logger, comm: MPI.Comm, rank, size): self._x_query = None self._f = None - def configure(self, interp_config): + def configure(self, interp_config: dict) -> None: + """ + Configures the interpolator to the provided parameters. + + Parameters + ---------- + interp_config : dict + Interpolator configuration. + """ self._domain.configure( {} if "domain_config" not in interp_config @@ -1053,10 +1616,30 @@ def configure(self, interp_config): ) self._phi = partial(RBF_PU.basis_gauss, eps=eps) - def set_local_data(self, x, x_, f): + def set_local_data(self, x: np.ndarray, x_: np.ndarray, f: np.ndarray) -> None: + """ + Sets local data for interleaved domain. + + Parameters + ---------- + x : np.ndarray + Support points. + x_ : np.ndarray + Query points. + f : np.ndarray + Support point function values. + """ self._domain.set_local_data(x, x_, f) - def interpolate(self): + def interpolate(self) -> np.ndarray: + """ + Interpolates the function values at the set query points. + + Returns + ------- + interp_result : np.ndarray + Interpolated function values. + """ self._x, self._x_query, self._f = self._domain.decompose() interp = self.compute_interpolant(self._x, self._f) @@ -1083,7 +1666,27 @@ def evaluate_interpolant(self): else: return self.evaluate_rbf_interpolant - def _compute_cluster_centers(self, x): + def _compute_cluster_centers( + self, x: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Creates cluster centers based on the provided support points. + 2 cluster centers per dimension and one in the middle. + + Parameters + ---------- + x : np.ndarray + Support points. + + Returns + ------- + cluster_centers : np.ndarray + Cluster centers. + local_min : np.ndarray + Minimum of local points. + local_max : np.ndarray + Maximum of local points. + """ assert self._use_pu local_min, local_max = np.min(x, axis=0), np.max(x, axis=0) d4 = (local_max - local_min) / 4 @@ -1112,7 +1715,30 @@ def compute_rbf_pu_interpolant(self, x, f): # compute local RBF interpolant for remaining clusters pass - def compute_rbf_interpolant(self, x, f): + def compute_rbf_interpolant( + self, x: np.ndarray, f: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Constructs an interpolant based on the provided support points and function values. + + Parameters + ---------- + x : np.ndarray + Support points. + f : np.ndarray + Support point function values. + + Returns + ------- + interp_weights_high: np.ndarray + Interpolant weights, higher order. + interp_weights_low: np.ndarray + Interpolant weights, lower order. + x : np.ndarray + Support points. + f : np.ndarray + Support point function values. + """ n_points = x.shape[0] src_size = x.shape[-1] dst_size = f.shape[-1] @@ -1140,7 +1766,28 @@ def evaluate_rbf_pu_interpolant(self, interp, xq): # sum contributions pass - def evaluate_rbf_interpolant(self, interp, xq): + def evaluate_rbf_interpolant( + self, + interp: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], + xq: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Interpolates the function values at the set query points. + + Parameters + ---------- + interp : Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] + Interpolation model as computed by compute_rbf_interpolant + xq : np.ndarray + Query points. + + Returns + ------- + xq : np.ndarray + Query points. + fq : np.ndarray + Query point function values. + """ w, b, x, f = interp r = np.linalg.norm(xq[None, :, :] - x[:, None, :], ord=2, axis=-1) From 32fabc6e81087fec26fb51a71d65687da10bbf07 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Wed, 8 Apr 2026 10:49:34 +0200 Subject: [PATCH 04/14] add changelog # Conflicts: # CHANGELOG.md # Conflicts: # CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b37288fb..4c0a62a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## latest +- Added RBF interpolation, currently used within adaptivity for output interpolation [#242](https://github.com/precice/micro-manager/pull/242) - Fixed model adaptivity convergence at resolution boundaries to prevent infinite loops for out-of-range switching requests [#252](https://github.com/precice/micro-manager/pull/252) - Add function `set_global_id` to the dummies and the example in the integration test [#247](https://github.com/precice/micro-manager/pull/247) From e22827ded7f2dfd59973d88e5e92988046633902 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Wed, 8 Apr 2026 11:44:20 +0200 Subject: [PATCH 05/14] try fix import error --- micro_manager/interpolation.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/micro_manager/interpolation.py b/micro_manager/interpolation.py index 29a53548..ad622814 100644 --- a/micro_manager/interpolation.py +++ b/micro_manager/interpolation.py @@ -7,7 +7,22 @@ from mpi4py import MPI import numpy as np -from sklearn.neighbors import NearestNeighbors + +try: + from sklearn.neighbors import NearestNeighbors +except ImportError: + + class Dummy: + def __init__(self, *args, **kwargs): + pass + + def __call__(self, *args, **kwargs): + return self + + def __getattr__(self, item): + return self + + NearestNeighbors = Dummy from micro_manager import Config from micro_manager.tools.logging_wrapper import Logger @@ -53,6 +68,8 @@ def get_nearest_neighbor_indices( ) ) k = len(coords) + if NearestNeighbors.__name__ != "NearestNeighbors": + raise RuntimeError("scipy was not imported.") neighbors = NearestNeighbors(n_neighbors=k).fit(coords) neighbor_indices = neighbors.kneighbors( From e324c6d24e39fbda5f16ff359e793ef2d2ea7730 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 24 Apr 2026 12:12:43 +0200 Subject: [PATCH 06/14] make mappings optional --- micro_manager/config.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/micro_manager/config.py b/micro_manager/config.py index 3ed6491e..cb8489cd 100644 --- a/micro_manager/config.py +++ b/micro_manager/config.py @@ -329,10 +329,14 @@ def read_json_micro_manager(self): self._logger.log_info_rank_zero("Adaptivity type: " + self._adaptivity_type) - if self._adaptivity_type == "global": + try: self._adaptivity_mappings = self._data["simulation_params"][ "adaptivity_settings" ]["mappings"] + except BaseException: + self._logger.log_info_rank_zero( + "Adaptivity will not interpolate outputs, only use representatives." + ) if self._data["simulation_params"]["adaptivity_settings"].get( "lazy_initialization" From 2a9a51c8934bb04cab97efd947463617d7845b17 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 24 Apr 2026 19:22:41 +0200 Subject: [PATCH 07/14] add test a fixes --- .github/workflows/check-coverage.yml | 1 + .../run-adaptivity-tests-parallel.yml | 8 + micro_manager/interpolation.py | 17 +- tests/unit/test_interpolation.py | 493 +++++++++++++++++- 4 files changed, 512 insertions(+), 7 deletions(-) diff --git a/.github/workflows/check-coverage.yml b/.github/workflows/check-coverage.yml index 56bdf83d..a4fe8a11 100644 --- a/.github/workflows/check-coverage.yml +++ b/.github/workflows/check-coverage.yml @@ -99,6 +99,7 @@ jobs: . ../../.venv/bin/activate mpirun -n 2 --allow-run-as-root -x PYTHONPATH=. python3 -m coverage run --parallel-mode --source=micro_manager -m unittest test_adaptivity_parallel mpirun -n 2 --allow-run-as-root -x PYTHONPATH=. python3 -m coverage run --parallel-mode --source=micro_manager -m unittest test_load_balancing + mpirun -n 2 --allow-run-as-root -x PYTHONPATH=. python3 -m coverage run --parallel-mode --source=micro_manager -m unittest test_interpolation - name: Combine coverage data working-directory: micro-manager/tests/unit diff --git a/.github/workflows/run-adaptivity-tests-parallel.yml b/.github/workflows/run-adaptivity-tests-parallel.yml index ad9d9407..240d6a28 100644 --- a/.github/workflows/run-adaptivity-tests-parallel.yml +++ b/.github/workflows/run-adaptivity-tests-parallel.yml @@ -109,3 +109,11 @@ jobs: . .venv/bin/activate cd tests/unit mpiexec -n 4 --oversubscribe --allow-run-as-root python3 -m unittest test_load_balancing.py + + - name: Run interpolation unit tests with 2 ranks + timeout-minutes: 3 + working-directory: micro-manager + run: | + . .venv/bin/activate + cd tests/unit + mpiexec -n 2 --allow-run-as-root python3 -m unittest test_interpolation.py diff --git a/micro_manager/interpolation.py b/micro_manager/interpolation.py index ad622814..f836d839 100644 --- a/micro_manager/interpolation.py +++ b/micro_manager/interpolation.py @@ -230,10 +230,10 @@ def find_min_depth_for_n_neighbors( """ if self.data_reserve_count < n: return None - if self.children is None: - return None if not self.is_within(p): return None + if self.children is None: + return depth tmp = [ node.find_min_depth_for_n_neighbors(n, depth + 1, p) @@ -478,6 +478,7 @@ def merge(self, other: "NDtree.Node") -> None: other : NDtree.Node Other node structure. """ + assert self._mode == NDtree.Mode.DISCRETIZE is_split = self.children is not None is_split_other = other.children is not None @@ -1104,9 +1105,13 @@ def set_local_data(self, x: np.ndarray, x_: np.ndarray, f: np.ndarray) -> None: f : np.ndarray Support point function values. """ - self._x_local = x - self._x_query_local = x_ - self._f_local = f + def dim_extend(a): + if a.ndim == 1: + return a.reshape(-1, 1) + return a + self._x_local = dim_extend(x) + self._x_query_local = dim_extend(x_) + self._f_local = dim_extend(f) def decompose(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ @@ -1243,7 +1248,7 @@ def eval_cond(): self._x_query_local = self._x_query_local / self._normalization[None, :] glob_cond = self._comm.allgather(eval_cond()) - self._projector.initialize(self._x_local) + self._projector.initialize(self._x_local * self._normalization[None, :]) self._proj_x_local = self._projector(self._x_local) self._proj_x_query_local = self._projector(self._x_query_local) diff --git a/tests/unit/test_interpolation.py b/tests/unit/test_interpolation.py index 47062bba..20d80197 100644 --- a/tests/unit/test_interpolation.py +++ b/tests/unit/test_interpolation.py @@ -1,7 +1,18 @@ import numpy as np +import unittest from unittest import TestCase from unittest.mock import MagicMock -from micro_manager.interpolation import Interpolation +from micro_manager.interpolation import ( + Interpolation, + NDtree, + HilbertDirect, + Projector, + STDProjector, + IdentityProjector, + InterleavedDomain, + RBF_PU +) +from mpi4py import MPI class TestInterpolation(TestCase): @@ -74,3 +85,483 @@ def test_nearest_neighbor_k_larger_than_coords(self): indices = interpolation.get_nearest_neighbor_indices(coords, inter_point, k) self.assertEqual(len(indices), 2) mock_logger.log_info.assert_called_once() + + +class TestNDtree(TestCase): + def test_node_properties(self): + for mode in [NDtree.Mode.DISCRETIZE, NDtree.Mode.INDEX]: + node = NDtree.Node(mode, -np.ones(2), np.ones(2), 2, 2, np.ones(2)) + self.assertEqual(node.dim, 2) + self.assertEqual(node.num_max_split, 4) + self.assertEqual(node.filling, 0) + node.data.append(0) + self.assertEqual(node.filling, 1) + + def test_node_clear(self): + for mode in [NDtree.Mode.DISCRETIZE, NDtree.Mode.INDEX]: + node = NDtree.Node(mode, -np.ones(2), np.ones(2), 2, 2, np.ones(2)) + node.data.append(0) + node.data_reserve_count = 1 + child = NDtree.Node(mode, -np.ones(2), np.ones(2), 1, 2, np.ones(2)) + child.data.append(0) + child.data_reserve_count = 1 + node.children = [child] + + node.clear() + self.assertEqual(len(node.data), 0) + self.assertEqual(node.data_reserve_count, 0) + self.assertEqual(len(node.data), 0) + self.assertEqual(child.data_reserve_count, 0) + + def test_node_propagate_reserve_count(self): + for mode in [NDtree.Mode.DISCRETIZE, NDtree.Mode.INDEX]: + node = NDtree.Node(mode, -np.ones(2), np.ones(2), 2, 2, np.ones(2)) + node.data_reserve_count = 1 + self.assertEqual(node.propagate_up_reserve_counts(), 1) + child = NDtree.Node(mode, -np.ones(2), np.ones(2), 1, 2, np.ones(2)) + child.data_reserve_count = 1 + node.children = [child, child] + self.assertEqual(node.propagate_up_reserve_counts(), 3) + self.assertEqual(node.data_reserve_count, 3) + + def test_node_find_min_depth(self): + for mode in [NDtree.Mode.DISCRETIZE, NDtree.Mode.INDEX]: + node = NDtree.Node(mode, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + node.children = [ + NDtree.Node(mode, np.array([-1, -1]), np.array([0, 0]), 1, 4, np.array([0, 0])), + NDtree.Node(mode, np.array([ 0, -1]), np.array([1, 0]), 1, 4, np.array([1, 0])), + NDtree.Node(mode, np.array([-1, 0]), np.array([0, 1]), 1, 4, np.array([0, 1])), + NDtree.Node(mode, np.array([ 0, 0]), np.array([1, 1]), 1, 4, np.array([1, 1])), + ] + node.children[0].children = [ + NDtree.Node(mode, np.array([ -1, -1]), np.array([-0.5, -0.5]), 0, 4, np.zeros(2)), + NDtree.Node(mode, np.array([-0.5, -1]), np.array([ 0, -0.5]), 0, 4, np.zeros(2)), + NDtree.Node(mode, np.array([ -1, -0.5]), np.array([-0.5, 0]), 0, 4, np.zeros(2)), + NDtree.Node(mode, np.array([-0.5, -0.5]), np.array([ 0, 0]), 0, 4, np.zeros(2)), + ] + node.children[1].children = [ + NDtree.Node(mode, np.array([ 0, -1]), np.array([ 0.5, -0.5]), 0, 4, np.array([0, 0])), + NDtree.Node(mode, np.array([ 0.5, -1]), np.array([ 1, -0.5]), 0, 4, np.array([1, 0])), + NDtree.Node(mode, np.array([ 0, -0.5]), np.array([ 0.5, 0]), 0, 4, np.array([0, 0])), + NDtree.Node(mode, np.array([ 0.5, -0.5]), np.array([ 1, 0]), 0, 4, np.array([1, 0])), + ] + node.children[0].children[0].data_reserve_count = 1 + node.children[0].children[1].data_reserve_count = 2 + node.children[0].children[2].data_reserve_count = 2 + node.children[0].children[3].data_reserve_count = 1 + node.children[1].children[0].data_reserve_count = 3 + node.children[1].children[1].data_reserve_count = 3 + node.children[1].children[2].data_reserve_count = 4 + node.children[1].children[3].data_reserve_count = 4 + node.propagate_up_reserve_counts() + + self.assertEqual(node.find_min_depth_for_n_neighbors(3, 0, np.array([-1, -1])), 1) + self.assertEqual(node.find_min_depth_for_n_neighbors(3, 0, np.array([1, -1])), 2) + + def test_node_filled_coords(self): + node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node.insert(np.array([-0.5, -0.5])) + node.insert(np.array([0.5, 0.5])) + node.insert(np.array([0.5, 0.5])) + for c in node.children: + c.data_reserve_count = len(c.data) + + coords = node.get_filled_coords(np.array([0, 0]), np.array([2, 2])) + true_targets = np.array([[0, 0], [1, 1], [1, 1]]) + for i in range(3): + self.assertTrue(np.all(true_targets[i] == coords[i])) + + def test_node_split(self): + # DISC + node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node.split() + self.assertTrue(node.children is not None) + node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node.insert(np.array([-0.5, -0.5])) + c_list = node.children + node.split() + self.assertTrue(c_list == node.children) + + # IND + node = NDtree.Node(NDtree.Mode.INDEX, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node.insert(np.array([-0.5, -0.5])) + node.insert(np.array([-0.5, -0.5])) + node.insert(np.array([ 0.5, 0.5])) + node.split() + self.assertEqual(len(node.children[0].data), 2) + self.assertEqual(len(node.children[3].data), 1) + + def test_node_insert(self): + # DISC + node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + node.insert(np.array([-1, -1])) + node.insert(np.array([1, 1])) + self.assertEqual(len(node.children[0].children[0].data), 1) + self.assertEqual(len(node.children[3].children[3].data), 1) + + # IND + node = NDtree.Node(NDtree.Mode.INDEX, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node.insert(np.array([-0.5, -0.5])) + node.insert(np.array([-0.5, -0.5])) + node.insert(np.array([0.5, 0.5])) + node.insert(np.array([0.5, 0.5])) + self.assertEqual(len(node.data), 4) + node.insert(np.array([0.5, 0.5])) + self.assertTrue(node.children is not None) + self.assertEqual(len(node.children[0].data), 2) + self.assertEqual(len(node.children[3].data), 3) + + def test_node_get_coord(self): + node = NDtree.Node(NDtree.Mode.INDEX, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + self.assertRaises(AssertionError, lambda: node.get_coord_of(0, 0, 0)) + + node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node.split() + self.assertTrue(np.all(node.get_coord_of(np.array([-0.5, -0.5]), np.array([0, 0]), np.array([2, 2])) == np.array([0, 0]))) + self.assertTrue(np.all(node.get_coord_of(np.array([ 0.5, -0.5]), np.array([0, 0]), np.array([2, 2])) == np.array([1, 0]))) + self.assertTrue(np.all(node.get_coord_of(np.array([-0.5, 0.5]), np.array([0, 0]), np.array([2, 2])) == np.array([0, 1]))) + self.assertTrue(np.all(node.get_coord_of(np.array([ 0.5, 0.5]), np.array([0, 0]), np.array([2, 2])) == np.array([1, 1]))) + + def test_node_within(self): + for mode in [NDtree.Mode.DISCRETIZE, NDtree.Mode.INDEX]: + node = NDtree.Node(mode, -np.ones(2), np.ones(2), 0, 4, np.ones(2)) + self.assertTrue(node.is_within(np.array([-1, -1]))) + self.assertTrue(node.is_within(np.array([ 1, -1]))) + self.assertTrue(node.is_within(np.array([-1, 1]))) + self.assertTrue(node.is_within(np.array([ 1, 1]))) + + node = NDtree.Node(mode, -np.ones(2), np.ones(2), 0, 4, np.array([1, 0])) + self.assertTrue(node.is_within(np.array([-1, -1]))) + self.assertTrue(node.is_within(np.array([ 1, -1]))) + self.assertFalse(node.is_within(np.array([-1, 1]))) + self.assertFalse(node.is_within(np.array([ 1, 1]))) + + node = NDtree.Node(mode, -np.ones(2), np.ones(2), 0, 4, np.array([0, 1])) + self.assertTrue(node.is_within(np.array([-1, -1]))) + self.assertFalse(node.is_within(np.array([ 1, -1]))) + self.assertTrue(node.is_within(np.array([-1, 1]))) + self.assertFalse(node.is_within(np.array([1, 1]))) + + node = NDtree.Node(mode, -np.ones(2), np.ones(2), 0, 4, np.array([0, 0])) + self.assertTrue(node.is_within(np.array([-1, -1]))) + self.assertFalse(node.is_within(np.array([1, -1]))) + self.assertFalse(node.is_within(np.array([-1, 1]))) + self.assertFalse(node.is_within(np.array([1, 1]))) + self.assertTrue(node.is_within(np.array([0, 0]))) + + def test_node_height(self): + # DISC + node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + self.assertEqual(node.get_height(), 0) + node.insert(np.array([-0.5, -0.5])) + self.assertEqual(node.get_height(), 2) + + # IND + node = NDtree.Node(NDtree.Mode.INDEX, -np.ones(2), np.ones(2), 2, 2, np.ones(2)) + self.assertEqual(node.get_height(), 0) + node.insert(np.array([-1, -1])) + self.assertEqual(node.get_height(), 0) + node.insert(np.array([-1, -1])) + node.insert(np.array([1, 1])) + self.assertEqual(node.get_height(), 1) + node.insert(np.array([-1, -1])) + self.assertEqual(node.get_height(), 2) + + def test_node_serialize(self): + for mode in [NDtree.Mode.DISCRETIZE, NDtree.Mode.INDEX]: + node = NDtree.Node(mode, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + self.assertListEqual(node.serialize(), [2, 0]) + child = NDtree.Node(mode, -np.ones(2), np.ones(2), 0, 4, np.ones(2)) + node.children = [child, child, child, child] + self.assertEqual(node.serialize(), [9, 2, 0, 2, 0, 2, 0, 2, 0]) + + def test_node_deserialize(self): + node = NDtree.Node(NDtree.Mode.INDEX, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node.deserialize([9, 2, 1, 2, 2, 2, 3, 2, 4]) + self.assertTrue(node.children is not None) + self.assertEqual(node.children[0].data_reserve_count, 1) + self.assertEqual(node.children[1].data_reserve_count, 2) + self.assertEqual(node.children[2].data_reserve_count, 3) + self.assertEqual(node.children[3].data_reserve_count, 4) + + def test_node_merge(self): + t1 = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + t1.split() + t1.children[0].split() + t1.children[0].children[0].data_reserve_count = 2 + t1.children[0].children[1].data_reserve_count = 2 + t1.children[0].children[2].data_reserve_count = 2 + t1.children[0].children[3].data_reserve_count = 2 + t1_total = t1.propagate_up_reserve_counts() + + t2 = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + t2.split() + t2.children[3].split() + t2.children[3].children[0].data_reserve_count = 3 + t2.children[3].children[1].data_reserve_count = 3 + t2.children[3].children[2].data_reserve_count = 3 + t2.children[3].children[3].data_reserve_count = 3 + t2_total = t2.propagate_up_reserve_counts() + + expected_total = t1_total + t2_total + t = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + t.merge(t1) + t.merge(t2) + self.assertTrue(t.children[0].children is not None) + self.assertTrue(t.children[3].children is not None) + total = t.propagate_up_reserve_counts() + self.assertEqual(total, expected_total) + + def test_filled_coords(self): + t = NDtree(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4) + t.root.deserialize([9, 2, 1, 2, 0, 2, 0, 2, 2]) + + coords = t.get_filled_coords() + true_targets = np.array([[0, 0], [1, 1], [1, 1]]) + for i in range(3): + self.assertTrue(np.all(true_targets[i] == coords[i])) + + def test_coords_of(self): + t = NDtree(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4) + t.root.deserialize([9, 2, 0, 2, 0, 2, 0, 2, 0]) + + coords = t.get_coords_of(np.array([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]])) + true_targets = np.array([[0, 0], [1, 0], [0, 1], [1, 1]]) + for i in range(4): + self.assertTrue(np.all(true_targets[i] == coords[i])) + + def test_min_depth(self): + t = NDtree(NDtree.Mode.INDEX, -np.ones(2), np.ones(2), 2, 4) + t.root.deserialize([16, 2, 1, 2, 3, 2, 1, 9, 2, 1, 2, 1, 2, 3, 2, 3]) + t.propagate_up_reserve_counts() + self.assertEqual(t.find_min_depth_for_n_neighbors(4, np.array([[0.5, 0.5]])), 1) + self.assertEqual(t.find_min_depth_for_n_neighbors(3, np.array([[0.5, 0.5]])), 2) + + def test_disc_insert(self): + """ + Test if point is inserted at max depth + """ + tree = NDtree(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 3, 4) + tree.insert(-np.ones(2)) + self.assertTrue(len(tree.root.children[0].children[0].children[0].data) > 0) + self.assertTrue(np.all(tree.root.children[0].children[0].children[0].data[0] == -np.ones(2))) + + +class TestHilberDirect(TestCase): + def test_i2c(self): + h = HilbertDirect(2, 3) + n_per_dim = np.power(2, 3) + n_max = np.power(n_per_dim, 2) + c_low = np.zeros(2) + c_high = np.ones(2) * n_per_dim - 1 + + for i in range(n_max): + c = h.index2coord(i) + self.assertTrue(np.all(c >= c_low) and np.all(c <= c_high)) + + def test_c2i(self): + h = HilbertDirect(2, 3) + n_per_dim = np.power(2, 3) + n_max = np.power(n_per_dim, 2) - 1 + + for y in range(n_per_dim): + for x in range(n_per_dim): + i = h.coord2index(np.array([x, y])) + self.assertTrue(0 <= i <= n_max) + + def test_unique(self): + h = HilbertDirect(2, 3) + n_per_dim = np.power(2, 3) + n_max = np.power(n_per_dim, 2) + + indices = [] + for y in range(n_per_dim): + for x in range(n_per_dim): + indices.append(h.coord2index(np.array([x, y]))) + u = np.unique(np.array(indices)) + self.assertEqual(len(u), n_max) + + +class TestProjector(TestCase): + def test_std_proj(self): + proj : Projector = STDProjector(1, MPI.COMM_SELF) + data = np.array([ + [0.1, 0], + [0.2, 5], + [0.1, 10], + [0.3, 20], + ]) + proj.initialize(data) + self.assertEqual(proj.target_dims[0], 1) + self.assertListEqual(data[:, 1].tolist(), proj(data).flatten().tolist()) + + def test_id_proj(self): + proj : Projector = IdentityProjector() + data = np.array([ + [0.1, 0], + [0.2, 5], + [0.1, 10], + [0.3, 20], + ]) + proj.initialize(data) + self.assertTrue(np.all(data == proj(data))) + + +def f_ana(x): + return 1 + 2 * x[:, 0] + 1 * x[:, 1] + 0.1 * x[:, 2] + +rbf_config = { + "domain_config": { + "max_filling": 8, + "coarsening_factor": 2, + "n_neighbors": 10, + "projection_type": "std", + "projection_std_dims": 2, + }, + "use_pu": False, + "basis": "c6" +} +# we have 2 clusters, centered around -1/-1/0 and 1/1/0 +ordered_global_x = np.array([ + [-1.5, -1.0, -0.1], [-0.5, -1.0, -0.1], [-1.0, -1.5, -0.1], [-1.0, -0.5, -0.1], [-1.0, -1.0, -0.1], + [-1.5, -1.0, 0.1], [-0.5, -1.0, 0.1], [-1.0, -1.5, 0.1], [-1.0, -0.5, 0.1], [-1.0, -1.0, 0.1], + [ 1.5, 1.0, -0.1], [ 0.5, 1.0, -0.1], [ 1.0, 1.5, -0.1], [ 1.0, 0.5, -0.1], [ 1.0, 1.0, -0.1], + [ 1.5, 1.0, 0.1], [ 0.5, 1.0, 0.1], [ 1.0, 1.5, 0.1], [ 1.0, 0.5, 0.1], [ 1.0, 1.0, 0.1], +]) +ordered_global_f = f_ana(ordered_global_x) +reordering = np.array([ + 6, 17, 7, 12, 5, 16, 4, 1, 0, 8, 9, 13, 18, 14, 19, 15, 3, 11, 2, 10 +]) +ordered_global_xq = np.array([ + [-1.25, -1.25, 0.0], [-0.75, -1.25, 0.0], [-0.75, -0.75, 0.0], [-1.25, -0.75, 0.0], + [ 1.25, 1.25, 0.0], [ 0.75, 1.25, 0.0], [ 0.75, 0.75, 0.0], [ 1.25, 0.75, 0.0], +]) +reordering_q = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + + +class TestInterleavedDomain(TestCase): + def setUp(self): + self._comm = MPI.COMM_WORLD + self._rank = self._comm.Get_rank() + self._size = self._comm.Get_size() + config = MagicMock() + self._domain = InterleavedDomain(config, self._comm) + self._domain.configure(rbf_config["domain_config"]) + self._ordered_global_x = ordered_global_x + self._ordered_global_f = ordered_global_f + self._reordering = reordering + self._ordered_global_xq = ordered_global_xq + self._reordering_q = reordering_q + + + @unittest.skipUnless( + MPI.COMM_WORLD.Get_size() == 2, "This test only works with 2 ranks." + ) + def test_p2p_comm(self): + send_map = {0: [], 1: []} + if self._rank == 0: + send_map[0].extend([0, 1, 2]) + send_map[1].extend([3, 4, 5]) + else: + send_map[0].extend([6, 7, 8]) + send_map[1].extend([9, 10, 11]) + + local_result, inv_map = self._domain._communicate(send_map, True) + + if self._rank == 0: + self.assertListEqual(sorted(local_result), [0, 1, 2, 6, 7, 8]) + self.assertListEqual(sorted(inv_map[0]), [0, 1, 2]) + self.assertListEqual(sorted(inv_map[1]), [6, 7, 8]) + else: + self.assertListEqual(sorted(local_result), [3, 4, 5, 9, 10, 11]) + self.assertListEqual(sorted(inv_map[0]), [3, 4, 5]) + self.assertListEqual(sorted(inv_map[1]), [9, 10, 11]) + + @unittest.skipUnless( + MPI.COMM_WORLD.Get_size() == 2, "This test only works with 2 ranks." + ) + def test_normalize(self): + self._domain.set_local_data( + self._ordered_global_x[self._reordering][10*self._rank:10*self._rank+10], + self._ordered_global_xq[self._reordering_q][4*self._rank:4*self._rank+4], + self._ordered_global_f[self._reordering][10*self._rank:10*self._rank+10], + ) + + self._domain._normalize_x() + + self.assertTrue(np.all(self._domain._x_local >= -1) and np.all(self._domain._x_local <= 1)) + self.assertTrue(np.all(self._domain._x_query_local >= -1) and np.all(self._domain._x_query_local <= 1)) + self.assertTrue(np.all(self._domain._projector.target_dims == np.array([0, 1]))) + self.assertEqual(self._domain._proj_x_local.ndim, 2) + self.assertEqual(self._domain._proj_x_query_local.ndim, 2) + + @unittest.skipUnless( + MPI.COMM_WORLD.Get_size() == 2, "This test only works with 2 ranks." + ) + def test_gen_trees(self): + self._domain.set_local_data( + self._ordered_global_x[self._reordering][10 * self._rank:10 * self._rank + 10], + self._ordered_global_xq[self._reordering_q][4 * self._rank:4 * self._rank + 4], + self._ordered_global_f[self._reordering][10 * self._rank:10 * self._rank + 10], + ) + + self._domain._generate_trees() + self._domain._tree.propagate_up_reserve_counts() + c_list = self._domain._tree.root.children + self.assertTrue( + c_list[1].data_reserve_count == 0 and + c_list[2].data_reserve_count == 0 and + c_list[0].data_reserve_count != 0 and + c_list[3].data_reserve_count != 0 + ) + self.assertEqual(self._domain._tree.root.data_reserve_count, 20) + + @unittest.skipUnless( + MPI.COMM_WORLD.Get_size() == 2, "This test only works with 2 ranks." + ) + def test_create_partitions(self): + self._domain.set_local_data( + self._ordered_global_x[self._reordering][10 * self._rank:10 * self._rank + 10], + self._ordered_global_xq[self._reordering_q][4 * self._rank:4 * self._rank + 4], + self._ordered_global_f[self._reordering][10 * self._rank:10 * self._rank + 10], + ) + self._domain._generate_trees() + x, xq, f = self._domain._create_partitions() + + expected_xq = self._ordered_global_xq[4 * self._rank:4 * self._rank + 4] / self._domain._normalization[None, :] + expected_xq_set = set() + for i in range(len(expected_xq)): + expected_xq_set.add(tuple(expected_xq[i].tolist())) + for i in range(len(xq)): + self.assertTrue(tuple(xq[i].tolist()) in expected_xq_set) + + +class TestRBF(TestCase): + def setUp(self): + self._comm = MPI.COMM_WORLD + self._rank = self._comm.Get_rank() + self._size = self._comm.Get_size() + self._rbf = RBF_PU( + MagicMock(), + MagicMock(), + self._comm, + self._rank, + self._size + ) + self._rbf.configure(rbf_config) + + @unittest.skipUnless( + MPI.COMM_WORLD.Get_size() == 2, "This test only works with 2 ranks." + ) + def test_interpolation(self): + xq = ordered_global_xq[reordering_q][4 * self._rank:4 * self._rank + 4] + self._rbf.set_local_data( + ordered_global_x[reordering][10 * self._rank:10 * self._rank + 10], + xq, + ordered_global_f[reordering][10 * self._rank:10 * self._rank + 10], + ) + + fq = self._rbf.interpolate() + fq_ana = f_ana(xq).reshape(-1, 1) + norms = np.linalg.norm(fq_ana - fq, ord=2, axis=-1) + self.assertTrue(np.allclose(norms, 0, rtol=1e-8)) From d802061932b36e7c0f071627421804d5501fc1fd Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 24 Apr 2026 19:23:16 +0200 Subject: [PATCH 08/14] fix format --- micro_manager/interpolation.py | 2 + tests/unit/test_interpolation.py | 346 ++++++++++++++++++++++--------- 2 files changed, 249 insertions(+), 99 deletions(-) diff --git a/micro_manager/interpolation.py b/micro_manager/interpolation.py index f836d839..bc1f9912 100644 --- a/micro_manager/interpolation.py +++ b/micro_manager/interpolation.py @@ -1105,10 +1105,12 @@ def set_local_data(self, x: np.ndarray, x_: np.ndarray, f: np.ndarray) -> None: f : np.ndarray Support point function values. """ + def dim_extend(a): if a.ndim == 1: return a.reshape(-1, 1) return a + self._x_local = dim_extend(x) self._x_query_local = dim_extend(x_) self._f_local = dim_extend(f) diff --git a/tests/unit/test_interpolation.py b/tests/unit/test_interpolation.py index 20d80197..8a1b4ef1 100644 --- a/tests/unit/test_interpolation.py +++ b/tests/unit/test_interpolation.py @@ -10,7 +10,7 @@ STDProjector, IdentityProjector, InterleavedDomain, - RBF_PU + RBF_PU, ) from mpi4py import MPI @@ -128,22 +128,66 @@ def test_node_find_min_depth(self): for mode in [NDtree.Mode.DISCRETIZE, NDtree.Mode.INDEX]: node = NDtree.Node(mode, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) node.children = [ - NDtree.Node(mode, np.array([-1, -1]), np.array([0, 0]), 1, 4, np.array([0, 0])), - NDtree.Node(mode, np.array([ 0, -1]), np.array([1, 0]), 1, 4, np.array([1, 0])), - NDtree.Node(mode, np.array([-1, 0]), np.array([0, 1]), 1, 4, np.array([0, 1])), - NDtree.Node(mode, np.array([ 0, 0]), np.array([1, 1]), 1, 4, np.array([1, 1])), + NDtree.Node( + mode, np.array([-1, -1]), np.array([0, 0]), 1, 4, np.array([0, 0]) + ), + NDtree.Node( + mode, np.array([0, -1]), np.array([1, 0]), 1, 4, np.array([1, 0]) + ), + NDtree.Node( + mode, np.array([-1, 0]), np.array([0, 1]), 1, 4, np.array([0, 1]) + ), + NDtree.Node( + mode, np.array([0, 0]), np.array([1, 1]), 1, 4, np.array([1, 1]) + ), ] node.children[0].children = [ - NDtree.Node(mode, np.array([ -1, -1]), np.array([-0.5, -0.5]), 0, 4, np.zeros(2)), - NDtree.Node(mode, np.array([-0.5, -1]), np.array([ 0, -0.5]), 0, 4, np.zeros(2)), - NDtree.Node(mode, np.array([ -1, -0.5]), np.array([-0.5, 0]), 0, 4, np.zeros(2)), - NDtree.Node(mode, np.array([-0.5, -0.5]), np.array([ 0, 0]), 0, 4, np.zeros(2)), + NDtree.Node( + mode, np.array([-1, -1]), np.array([-0.5, -0.5]), 0, 4, np.zeros(2) + ), + NDtree.Node( + mode, np.array([-0.5, -1]), np.array([0, -0.5]), 0, 4, np.zeros(2) + ), + NDtree.Node( + mode, np.array([-1, -0.5]), np.array([-0.5, 0]), 0, 4, np.zeros(2) + ), + NDtree.Node( + mode, np.array([-0.5, -0.5]), np.array([0, 0]), 0, 4, np.zeros(2) + ), ] node.children[1].children = [ - NDtree.Node(mode, np.array([ 0, -1]), np.array([ 0.5, -0.5]), 0, 4, np.array([0, 0])), - NDtree.Node(mode, np.array([ 0.5, -1]), np.array([ 1, -0.5]), 0, 4, np.array([1, 0])), - NDtree.Node(mode, np.array([ 0, -0.5]), np.array([ 0.5, 0]), 0, 4, np.array([0, 0])), - NDtree.Node(mode, np.array([ 0.5, -0.5]), np.array([ 1, 0]), 0, 4, np.array([1, 0])), + NDtree.Node( + mode, + np.array([0, -1]), + np.array([0.5, -0.5]), + 0, + 4, + np.array([0, 0]), + ), + NDtree.Node( + mode, + np.array([0.5, -1]), + np.array([1, -0.5]), + 0, + 4, + np.array([1, 0]), + ), + NDtree.Node( + mode, + np.array([0, -0.5]), + np.array([0.5, 0]), + 0, + 4, + np.array([0, 0]), + ), + NDtree.Node( + mode, + np.array([0.5, -0.5]), + np.array([1, 0]), + 0, + 4, + np.array([1, 0]), + ), ] node.children[0].children[0].data_reserve_count = 1 node.children[0].children[1].data_reserve_count = 2 @@ -155,11 +199,17 @@ def test_node_find_min_depth(self): node.children[1].children[3].data_reserve_count = 4 node.propagate_up_reserve_counts() - self.assertEqual(node.find_min_depth_for_n_neighbors(3, 0, np.array([-1, -1])), 1) - self.assertEqual(node.find_min_depth_for_n_neighbors(3, 0, np.array([1, -1])), 2) + self.assertEqual( + node.find_min_depth_for_n_neighbors(3, 0, np.array([-1, -1])), 1 + ) + self.assertEqual( + node.find_min_depth_for_n_neighbors(3, 0, np.array([1, -1])), 2 + ) def test_node_filled_coords(self): - node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node = NDtree.Node( + NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2) + ) node.insert(np.array([-0.5, -0.5])) node.insert(np.array([0.5, 0.5])) node.insert(np.array([0.5, 0.5])) @@ -173,10 +223,14 @@ def test_node_filled_coords(self): def test_node_split(self): # DISC - node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node = NDtree.Node( + NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2) + ) node.split() self.assertTrue(node.children is not None) - node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node = NDtree.Node( + NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2) + ) node.insert(np.array([-0.5, -0.5])) c_list = node.children node.split() @@ -186,14 +240,16 @@ def test_node_split(self): node = NDtree.Node(NDtree.Mode.INDEX, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) node.insert(np.array([-0.5, -0.5])) node.insert(np.array([-0.5, -0.5])) - node.insert(np.array([ 0.5, 0.5])) + node.insert(np.array([0.5, 0.5])) node.split() self.assertEqual(len(node.children[0].data), 2) self.assertEqual(len(node.children[3].data), 1) def test_node_insert(self): # DISC - node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + node = NDtree.Node( + NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2) + ) node.insert(np.array([-1, -1])) node.insert(np.array([1, 1])) self.assertEqual(len(node.children[0].children[0].data), 1) @@ -215,32 +271,62 @@ def test_node_get_coord(self): node = NDtree.Node(NDtree.Mode.INDEX, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) self.assertRaises(AssertionError, lambda: node.get_coord_of(0, 0, 0)) - node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2)) + node = NDtree.Node( + NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4, np.ones(2) + ) node.split() - self.assertTrue(np.all(node.get_coord_of(np.array([-0.5, -0.5]), np.array([0, 0]), np.array([2, 2])) == np.array([0, 0]))) - self.assertTrue(np.all(node.get_coord_of(np.array([ 0.5, -0.5]), np.array([0, 0]), np.array([2, 2])) == np.array([1, 0]))) - self.assertTrue(np.all(node.get_coord_of(np.array([-0.5, 0.5]), np.array([0, 0]), np.array([2, 2])) == np.array([0, 1]))) - self.assertTrue(np.all(node.get_coord_of(np.array([ 0.5, 0.5]), np.array([0, 0]), np.array([2, 2])) == np.array([1, 1]))) + self.assertTrue( + np.all( + node.get_coord_of( + np.array([-0.5, -0.5]), np.array([0, 0]), np.array([2, 2]) + ) + == np.array([0, 0]) + ) + ) + self.assertTrue( + np.all( + node.get_coord_of( + np.array([0.5, -0.5]), np.array([0, 0]), np.array([2, 2]) + ) + == np.array([1, 0]) + ) + ) + self.assertTrue( + np.all( + node.get_coord_of( + np.array([-0.5, 0.5]), np.array([0, 0]), np.array([2, 2]) + ) + == np.array([0, 1]) + ) + ) + self.assertTrue( + np.all( + node.get_coord_of( + np.array([0.5, 0.5]), np.array([0, 0]), np.array([2, 2]) + ) + == np.array([1, 1]) + ) + ) def test_node_within(self): for mode in [NDtree.Mode.DISCRETIZE, NDtree.Mode.INDEX]: node = NDtree.Node(mode, -np.ones(2), np.ones(2), 0, 4, np.ones(2)) self.assertTrue(node.is_within(np.array([-1, -1]))) - self.assertTrue(node.is_within(np.array([ 1, -1]))) - self.assertTrue(node.is_within(np.array([-1, 1]))) - self.assertTrue(node.is_within(np.array([ 1, 1]))) + self.assertTrue(node.is_within(np.array([1, -1]))) + self.assertTrue(node.is_within(np.array([-1, 1]))) + self.assertTrue(node.is_within(np.array([1, 1]))) node = NDtree.Node(mode, -np.ones(2), np.ones(2), 0, 4, np.array([1, 0])) self.assertTrue(node.is_within(np.array([-1, -1]))) - self.assertTrue(node.is_within(np.array([ 1, -1]))) + self.assertTrue(node.is_within(np.array([1, -1]))) self.assertFalse(node.is_within(np.array([-1, 1]))) - self.assertFalse(node.is_within(np.array([ 1, 1]))) + self.assertFalse(node.is_within(np.array([1, 1]))) node = NDtree.Node(mode, -np.ones(2), np.ones(2), 0, 4, np.array([0, 1])) self.assertTrue(node.is_within(np.array([-1, -1]))) - self.assertFalse(node.is_within(np.array([ 1, -1]))) + self.assertFalse(node.is_within(np.array([1, -1]))) self.assertTrue(node.is_within(np.array([-1, 1]))) - self.assertFalse(node.is_within(np.array([1, 1]))) + self.assertFalse(node.is_within(np.array([1, 1]))) node = NDtree.Node(mode, -np.ones(2), np.ones(2), 0, 4, np.array([0, 0])) self.assertTrue(node.is_within(np.array([-1, -1]))) @@ -251,7 +337,9 @@ def test_node_within(self): def test_node_height(self): # DISC - node = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + node = NDtree.Node( + NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2) + ) self.assertEqual(node.get_height(), 0) node.insert(np.array([-0.5, -0.5])) self.assertEqual(node.get_height(), 2) @@ -285,7 +373,9 @@ def test_node_deserialize(self): self.assertEqual(node.children[3].data_reserve_count, 4) def test_node_merge(self): - t1 = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + t1 = NDtree.Node( + NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2) + ) t1.split() t1.children[0].split() t1.children[0].children[0].data_reserve_count = 2 @@ -294,7 +384,9 @@ def test_node_merge(self): t1.children[0].children[3].data_reserve_count = 2 t1_total = t1.propagate_up_reserve_counts() - t2 = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + t2 = NDtree.Node( + NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2) + ) t2.split() t2.children[3].split() t2.children[3].children[0].data_reserve_count = 3 @@ -304,7 +396,9 @@ def test_node_merge(self): t2_total = t2.propagate_up_reserve_counts() expected_total = t1_total + t2_total - t = NDtree.Node(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2)) + t = NDtree.Node( + NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 2, 4, np.ones(2) + ) t.merge(t1) t.merge(t2) self.assertTrue(t.children[0].children is not None) @@ -325,7 +419,9 @@ def test_coords_of(self): t = NDtree(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 1, 4) t.root.deserialize([9, 2, 0, 2, 0, 2, 0, 2, 0]) - coords = t.get_coords_of(np.array([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]])) + coords = t.get_coords_of( + np.array([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]]) + ) true_targets = np.array([[0, 0], [1, 0], [0, 1], [1, 1]]) for i in range(4): self.assertTrue(np.all(true_targets[i] == coords[i])) @@ -344,7 +440,9 @@ def test_disc_insert(self): tree = NDtree(NDtree.Mode.DISCRETIZE, -np.ones(2), np.ones(2), 3, 4) tree.insert(-np.ones(2)) self.assertTrue(len(tree.root.children[0].children[0].children[0].data) > 0) - self.assertTrue(np.all(tree.root.children[0].children[0].children[0].data[0] == -np.ones(2))) + self.assertTrue( + np.all(tree.root.children[0].children[0].children[0].data[0] == -np.ones(2)) + ) class TestHilberDirect(TestCase): @@ -384,25 +482,29 @@ def test_unique(self): class TestProjector(TestCase): def test_std_proj(self): - proj : Projector = STDProjector(1, MPI.COMM_SELF) - data = np.array([ - [0.1, 0], - [0.2, 5], - [0.1, 10], - [0.3, 20], - ]) + proj: Projector = STDProjector(1, MPI.COMM_SELF) + data = np.array( + [ + [0.1, 0], + [0.2, 5], + [0.1, 10], + [0.3, 20], + ] + ) proj.initialize(data) self.assertEqual(proj.target_dims[0], 1) self.assertListEqual(data[:, 1].tolist(), proj(data).flatten().tolist()) def test_id_proj(self): - proj : Projector = IdentityProjector() - data = np.array([ - [0.1, 0], - [0.2, 5], - [0.1, 10], - [0.3, 20], - ]) + proj: Projector = IdentityProjector() + data = np.array( + [ + [0.1, 0], + [0.2, 5], + [0.1, 10], + [0.3, 20], + ] + ) proj.initialize(data) self.assertTrue(np.all(data == proj(data))) @@ -410,32 +512,59 @@ def test_id_proj(self): def f_ana(x): return 1 + 2 * x[:, 0] + 1 * x[:, 1] + 0.1 * x[:, 2] + rbf_config = { "domain_config": { - "max_filling": 8, - "coarsening_factor": 2, - "n_neighbors": 10, - "projection_type": "std", - "projection_std_dims": 2, - }, + "max_filling": 8, + "coarsening_factor": 2, + "n_neighbors": 10, + "projection_type": "std", + "projection_std_dims": 2, + }, "use_pu": False, - "basis": "c6" + "basis": "c6", } # we have 2 clusters, centered around -1/-1/0 and 1/1/0 -ordered_global_x = np.array([ - [-1.5, -1.0, -0.1], [-0.5, -1.0, -0.1], [-1.0, -1.5, -0.1], [-1.0, -0.5, -0.1], [-1.0, -1.0, -0.1], - [-1.5, -1.0, 0.1], [-0.5, -1.0, 0.1], [-1.0, -1.5, 0.1], [-1.0, -0.5, 0.1], [-1.0, -1.0, 0.1], - [ 1.5, 1.0, -0.1], [ 0.5, 1.0, -0.1], [ 1.0, 1.5, -0.1], [ 1.0, 0.5, -0.1], [ 1.0, 1.0, -0.1], - [ 1.5, 1.0, 0.1], [ 0.5, 1.0, 0.1], [ 1.0, 1.5, 0.1], [ 1.0, 0.5, 0.1], [ 1.0, 1.0, 0.1], -]) +ordered_global_x = np.array( + [ + [-1.5, -1.0, -0.1], + [-0.5, -1.0, -0.1], + [-1.0, -1.5, -0.1], + [-1.0, -0.5, -0.1], + [-1.0, -1.0, -0.1], + [-1.5, -1.0, 0.1], + [-0.5, -1.0, 0.1], + [-1.0, -1.5, 0.1], + [-1.0, -0.5, 0.1], + [-1.0, -1.0, 0.1], + [1.5, 1.0, -0.1], + [0.5, 1.0, -0.1], + [1.0, 1.5, -0.1], + [1.0, 0.5, -0.1], + [1.0, 1.0, -0.1], + [1.5, 1.0, 0.1], + [0.5, 1.0, 0.1], + [1.0, 1.5, 0.1], + [1.0, 0.5, 0.1], + [1.0, 1.0, 0.1], + ] +) ordered_global_f = f_ana(ordered_global_x) -reordering = np.array([ - 6, 17, 7, 12, 5, 16, 4, 1, 0, 8, 9, 13, 18, 14, 19, 15, 3, 11, 2, 10 -]) -ordered_global_xq = np.array([ - [-1.25, -1.25, 0.0], [-0.75, -1.25, 0.0], [-0.75, -0.75, 0.0], [-1.25, -0.75, 0.0], - [ 1.25, 1.25, 0.0], [ 0.75, 1.25, 0.0], [ 0.75, 0.75, 0.0], [ 1.25, 0.75, 0.0], -]) +reordering = np.array( + [6, 17, 7, 12, 5, 16, 4, 1, 0, 8, 9, 13, 18, 14, 19, 15, 3, 11, 2, 10] +) +ordered_global_xq = np.array( + [ + [-1.25, -1.25, 0.0], + [-0.75, -1.25, 0.0], + [-0.75, -0.75, 0.0], + [-1.25, -0.75, 0.0], + [1.25, 1.25, 0.0], + [0.75, 1.25, 0.0], + [0.75, 0.75, 0.0], + [1.25, 0.75, 0.0], + ] +) reordering_q = np.array([0, 2, 4, 6, 1, 3, 5, 7]) @@ -453,7 +582,6 @@ def setUp(self): self._ordered_global_xq = ordered_global_xq self._reordering_q = reordering_q - @unittest.skipUnless( MPI.COMM_WORLD.Get_size() == 2, "This test only works with 2 ranks." ) @@ -482,15 +610,26 @@ def test_p2p_comm(self): ) def test_normalize(self): self._domain.set_local_data( - self._ordered_global_x[self._reordering][10*self._rank:10*self._rank+10], - self._ordered_global_xq[self._reordering_q][4*self._rank:4*self._rank+4], - self._ordered_global_f[self._reordering][10*self._rank:10*self._rank+10], + self._ordered_global_x[self._reordering][ + 10 * self._rank : 10 * self._rank + 10 + ], + self._ordered_global_xq[self._reordering_q][ + 4 * self._rank : 4 * self._rank + 4 + ], + self._ordered_global_f[self._reordering][ + 10 * self._rank : 10 * self._rank + 10 + ], ) self._domain._normalize_x() - self.assertTrue(np.all(self._domain._x_local >= -1) and np.all(self._domain._x_local <= 1)) - self.assertTrue(np.all(self._domain._x_query_local >= -1) and np.all(self._domain._x_query_local <= 1)) + self.assertTrue( + np.all(self._domain._x_local >= -1) and np.all(self._domain._x_local <= 1) + ) + self.assertTrue( + np.all(self._domain._x_query_local >= -1) + and np.all(self._domain._x_query_local <= 1) + ) self.assertTrue(np.all(self._domain._projector.target_dims == np.array([0, 1]))) self.assertEqual(self._domain._proj_x_local.ndim, 2) self.assertEqual(self._domain._proj_x_query_local.ndim, 2) @@ -500,19 +639,25 @@ def test_normalize(self): ) def test_gen_trees(self): self._domain.set_local_data( - self._ordered_global_x[self._reordering][10 * self._rank:10 * self._rank + 10], - self._ordered_global_xq[self._reordering_q][4 * self._rank:4 * self._rank + 4], - self._ordered_global_f[self._reordering][10 * self._rank:10 * self._rank + 10], + self._ordered_global_x[self._reordering][ + 10 * self._rank : 10 * self._rank + 10 + ], + self._ordered_global_xq[self._reordering_q][ + 4 * self._rank : 4 * self._rank + 4 + ], + self._ordered_global_f[self._reordering][ + 10 * self._rank : 10 * self._rank + 10 + ], ) self._domain._generate_trees() self._domain._tree.propagate_up_reserve_counts() c_list = self._domain._tree.root.children self.assertTrue( - c_list[1].data_reserve_count == 0 and - c_list[2].data_reserve_count == 0 and - c_list[0].data_reserve_count != 0 and - c_list[3].data_reserve_count != 0 + c_list[1].data_reserve_count == 0 + and c_list[2].data_reserve_count == 0 + and c_list[0].data_reserve_count != 0 + and c_list[3].data_reserve_count != 0 ) self.assertEqual(self._domain._tree.root.data_reserve_count, 20) @@ -521,14 +666,23 @@ def test_gen_trees(self): ) def test_create_partitions(self): self._domain.set_local_data( - self._ordered_global_x[self._reordering][10 * self._rank:10 * self._rank + 10], - self._ordered_global_xq[self._reordering_q][4 * self._rank:4 * self._rank + 4], - self._ordered_global_f[self._reordering][10 * self._rank:10 * self._rank + 10], + self._ordered_global_x[self._reordering][ + 10 * self._rank : 10 * self._rank + 10 + ], + self._ordered_global_xq[self._reordering_q][ + 4 * self._rank : 4 * self._rank + 4 + ], + self._ordered_global_f[self._reordering][ + 10 * self._rank : 10 * self._rank + 10 + ], ) self._domain._generate_trees() x, xq, f = self._domain._create_partitions() - expected_xq = self._ordered_global_xq[4 * self._rank:4 * self._rank + 4] / self._domain._normalization[None, :] + expected_xq = ( + self._ordered_global_xq[4 * self._rank : 4 * self._rank + 4] + / self._domain._normalization[None, :] + ) expected_xq_set = set() for i in range(len(expected_xq)): expected_xq_set.add(tuple(expected_xq[i].tolist())) @@ -541,24 +695,18 @@ def setUp(self): self._comm = MPI.COMM_WORLD self._rank = self._comm.Get_rank() self._size = self._comm.Get_size() - self._rbf = RBF_PU( - MagicMock(), - MagicMock(), - self._comm, - self._rank, - self._size - ) + self._rbf = RBF_PU(MagicMock(), MagicMock(), self._comm, self._rank, self._size) self._rbf.configure(rbf_config) @unittest.skipUnless( MPI.COMM_WORLD.Get_size() == 2, "This test only works with 2 ranks." ) def test_interpolation(self): - xq = ordered_global_xq[reordering_q][4 * self._rank:4 * self._rank + 4] + xq = ordered_global_xq[reordering_q][4 * self._rank : 4 * self._rank + 4] self._rbf.set_local_data( - ordered_global_x[reordering][10 * self._rank:10 * self._rank + 10], + ordered_global_x[reordering][10 * self._rank : 10 * self._rank + 10], xq, - ordered_global_f[reordering][10 * self._rank:10 * self._rank + 10], + ordered_global_f[reordering][10 * self._rank : 10 * self._rank + 10], ) fq = self._rbf.interpolate() From 99b0c277c197c1052108f744164dd307fb446d1e Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 24 Apr 2026 19:34:44 +0200 Subject: [PATCH 09/14] small fixes --- .github/workflows/run-adaptivity-tests-parallel.yml | 2 ++ micro_manager/adaptivity/adaptivity.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/run-adaptivity-tests-parallel.yml b/.github/workflows/run-adaptivity-tests-parallel.yml index 240d6a28..34ab09a3 100644 --- a/.github/workflows/run-adaptivity-tests-parallel.yml +++ b/.github/workflows/run-adaptivity-tests-parallel.yml @@ -115,5 +115,7 @@ jobs: working-directory: micro-manager run: | . .venv/bin/activate + pip install .[sklearn] + pip uninstall -y pyprecice cd tests/unit mpiexec -n 2 --allow-run-as-root python3 -m unittest test_interpolation.py diff --git a/micro_manager/adaptivity/adaptivity.py b/micro_manager/adaptivity/adaptivity.py index 2b99e5d3..267b6cc5 100644 --- a/micro_manager/adaptivity/adaptivity.py +++ b/micro_manager/adaptivity/adaptivity.py @@ -288,11 +288,11 @@ def _interpolate_output(self, micro_input, micro_sims_output) -> None: active_lids = self.get_active_sim_local_ids() inactive_lids = self.get_inactive_sim_local_ids() arg_sizes = {} - for name, value in micro_input[active_lids[0]].items(): + for name, value in micro_input[-1].items(): arg_sizes[name] = ( 1 if type(value) != np.ndarray and type(value) != list else len(value) ) - for name, value in micro_sims_output[active_lids[0]].items(): + for name, value in micro_sims_output[-1].items(): arg_sizes[name] = ( 1 if type(value) != np.ndarray and type(value) != list else len(value) ) From 64bc462eb8e71ab8fb329d928c7263d4cd951ed4 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 24 Apr 2026 21:09:30 +0200 Subject: [PATCH 10/14] add doc --- docs/configuration.md | 54 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index a7940cea..bff05904 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -89,6 +89,60 @@ To turn on adaptivity, set `"adaptivity": true` in `simulation_params`. Then und | `similarity_measure` | Similarity measure to be used for adaptivity. Can be either `L1`, `L2`, `L1rel` or `L2rel`. By default, `L1` is used. The `rel` variants calculate the respective relative norms. This parameter is *optional*. | `L2rel` | | `lazy_initialization` | Set to `true` to lazily create and initialize micro simulations. If selected, micro simulation objects are created only when the micro simulation is activated for the first time. | `false` | | `load_balancing` | Set to `true` to dynamically balance simulations for parallel runs. See [load balancing settings](#load-balancing) below. | `false` | +| `mappings` | Optional interpolation of results. Set to list of mapping configurations. See below for further details. | `[]` | + +Adaptivity can optionally interpolate results using RBF interpolation. For any subset of `write_data_names` fields, a function +can be defined from `read_data_names` to `write_data_names`. When using multiple functions, their interpolation target, i.e., fields +of `write_data_names` must be mutually disjunct. Mappings can be defined as: +```json +"mappings": [ + { + "src_fields": ["input1", "input2"], + "dst_fields": ["output1", "output2"], + "n_neighbors": 50, + "rbf_config": { + "use_pu": false, + "pu_overlap": 0.1, + "basis": { + "type": "c6" + } + }, + "domain_config": { + "max_filling": 8, + "coarsening_factor": 2, + "projection": { + "type": "std", + "target_dims": 3 + } + } + }, + {...} +] +``` + +| Parameter | Description | Default | +|-----------------|-----------------------------------------------------------------------|---------| +| `src_fields` | List of entries from `read_data_names` | `None` | +| `dst_fields` | List of entries from `write_data_names` | `None` | +| `n_neighbours` | Interpolation parameter. Determines minimum amount of support points. | `50` | +| `rbf_config` | RBF interpolation configuration. | `None` | +| `domain_config` | Function source domain description. | | + +Currently, only RBF interpolation is supported. However, the configuration options for PU-RBF interpolation already exist. +A selection of different basis function is available: `c0`, `c2`, `c4`, `c6`. +The domain must be described/further configured as input data is shared across rank and must be redistribute for interpolation. +Towards this, spatial discretization techniques are used. For better performance, data can be projected to a lower dimensional space +using the fields with the highest standard deviation. + +| Parameter | Description | Default | +|---------------------|---------------------------------------------------------------------------|------------| +| `use_pu` | Enables PU-RBF. (currently not supported) | `False` | +| `pu_overlap` | Controlls overlap radius for PU decomposition. | `0.1` | +| `basis` | RBF basis function: `c0`, `c2`, `c4`, `c6` | `None` | +| `max_filling` | Tunes maximum filling of tree nodes used during decomposition. | `8` | +| `coarsening_factor` | Adjusts the fidelity of the discretized domain. Only integer values >= 1. | `2` | +| `projection` | Either `std` or `identity`. | `identity` | +| `target_dims` | Only if `std` is used. Denotes the target dimension after projection. | `None` | Example of adaptivity configuration is From 7619fd2b75afb86587638b95bb5855184b66f7d9 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Sat, 25 Apr 2026 10:04:46 +0200 Subject: [PATCH 11/14] add small eps to lb timings for inactive sims, fix doc --- docs/configuration.md | 1 + micro_manager/micro_manager.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index bff05904..3ef85525 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -94,6 +94,7 @@ To turn on adaptivity, set `"adaptivity": true` in `simulation_params`. Then und Adaptivity can optionally interpolate results using RBF interpolation. For any subset of `write_data_names` fields, a function can be defined from `read_data_names` to `write_data_names`. When using multiple functions, their interpolation target, i.e., fields of `write_data_names` must be mutually disjunct. Mappings can be defined as: + ```json "mappings": [ { diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index d500d2a0..ccc0783a 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -1102,11 +1102,17 @@ def _solve_micro_simulations_with_adaptivity( # Resolve micro sim output data for inactive simulations for inactive_lid in inactive_sim_lids: + self.load_balancing.pre_sim_solve( + self._global_ids_of_local_sims[inactive_lid] + ) micro_sims_output[inactive_lid]["Active-State"] = 0 gid = self._global_ids_of_local_sims[inactive_lid] micro_sims_output[inactive_lid][ "Active-Steps" ] = self._micro_sims_active_steps[gid] + self.load_balancing.post_sim_solve( + self._global_ids_of_local_sims[inactive_lid] + ) # Collect micro sim output for adaptivity calculation for i in range(self._local_number_of_sims): From 2ad6c7aabacc15f66f1d5871f0314b1212880c11 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Wed, 29 Apr 2026 16:48:06 +0200 Subject: [PATCH 12/14] add some fixes for edge cases --- micro_manager/adaptivity/adaptivity.py | 5 ++++ micro_manager/adaptivity/global_adaptivity.py | 9 ++++++ micro_manager/interpolation.py | 28 +++++++++++++------ micro_manager/micro_manager.py | 7 +++-- micro_manager/micro_simulation.py | 16 +++++++++-- 5 files changed, 50 insertions(+), 15 deletions(-) diff --git a/micro_manager/adaptivity/adaptivity.py b/micro_manager/adaptivity/adaptivity.py index 267b6cc5..bfeb7d47 100644 --- a/micro_manager/adaptivity/adaptivity.py +++ b/micro_manager/adaptivity/adaptivity.py @@ -59,6 +59,7 @@ def __init__( self._max_similarity_dist = 0.0 self._interpolation = None + self._interp_min = -1 self._mappings = [] self._mapping_configs = [] mappings = configurator.get_adaptivity_mapping_configs() @@ -132,6 +133,10 @@ def _load_mappings(self, mappings: list) -> None: src_fields = mapping["src_fields"] dst_fields = mapping["dst_fields"] n_neighbors = mapping["n_neighbors"] + if self._interp_min == -1: + self._interp_min = n_neighbors + else: + self._interp_min = min(n_neighbors, self._interp_min) self._mappings.append((src_fields, dst_fields)) config = {} diff --git a/micro_manager/adaptivity/global_adaptivity.py b/micro_manager/adaptivity/global_adaptivity.py index 29ad8f5a..69ffdc60 100644 --- a/micro_manager/adaptivity/global_adaptivity.py +++ b/micro_manager/adaptivity/global_adaptivity.py @@ -267,7 +267,16 @@ def get_full_field_micro_output( ) micro_sims_output = deepcopy(micro_output) + num_active = np.sum(self._is_sim_active) + if num_active == self._is_sim_active.shape[0]: + self._precice_participant.stop_last_profiling_section() + return micro_sims_output + self._communicate_micro_output(micro_sims_output) + if num_active <= self._interp_min: + self._precice_participant.stop_last_profiling_section() + return micro_sims_output + self._interpolate_output(micro_input, micro_sims_output) self._precice_participant.stop_last_profiling_section() diff --git a/micro_manager/interpolation.py b/micro_manager/interpolation.py index bc1f9912..dec965d0 100644 --- a/micro_manager/interpolation.py +++ b/micro_manager/interpolation.py @@ -1336,29 +1336,38 @@ def _create_partitions(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: for coord, idx in query_mapping.items(): query_mapping_inv[idx] = coord sorted_1d_query_indices = sorted(query_mapping.values()) + if len(sorted_1d_query_indices) == 0: + return ( + np.zeros((0, self._x_local.shape[-1])), + np.zeros((0, self._x_query_local.shape[-1])), + np.zeros((0, self._f_local.shape[-1])), + ) # partition based on query points - target_point_per_rank = len(sorted_1d_query_indices) // self._size + target_point_per_rank = max( + (len(sorted_1d_query_indices) + 1) // self._size, 16 + ) partitions = {r: [-1, -1] for r in range(self._size)} last_val = sorted_1d_query_indices[0] start_idx = 0 part_begin = 0 part_idx = 0 - for i in range(1, len(sorted_1d_query_indices)): + for i in range(0, len(sorted_1d_query_indices)): if sorted_1d_query_indices[i] != last_val: last_val = sorted_1d_query_indices[i] start_idx = i - if i - part_begin + 1 < target_point_per_rank: - continue - # handle last partition if part_idx == self._size - 1: partitions[part_idx][0] = part_begin partitions[part_idx][1] = len(sorted_1d_query_indices) - 1 part_idx = part_idx + 1 + part_begin = len(sorted_1d_query_indices) break + if i - part_begin + 1 < target_point_per_rank: + continue + # partition has minimum size, find nearest end of current cell end_idx = i for j in range(i + 1, len(sorted_1d_query_indices)): @@ -1380,8 +1389,9 @@ def _create_partitions(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: # last part was not used if ( part_idx in partitions - and partitions[part_idx][0] == 0 - and partitions[part_idx][1] == 0 + and partitions[part_idx][0] == -1 + and partitions[part_idx][1] == -1 + and part_begin < len(sorted_1d_query_indices) ): partitions[part_idx][0] = part_begin partitions[part_idx][1] = len(sorted_1d_query_indices) - 1 @@ -1389,7 +1399,7 @@ def _create_partitions(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: # assign surrounding src domain to rank local query points src_domains = {r: [None, None] for r in range(self._size)} for rank, p_range in partitions.items(): - if -1 == p_range[0] == p_range[1]: + if -1 == p_range[0] and p_range[0] == p_range[1]: continue # gather coords, find bounding box @@ -1423,7 +1433,7 @@ def _create_partitions(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: # find owning partition found = False for rank, rank_range in partitions.items(): - if -1 == rank_range[0] == rank_range[1]: + if -1 == rank_range[0] and rank_range[0] == rank_range[1]: continue if ( diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index ccc0783a..ecd7aee7 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -1160,9 +1160,10 @@ def _solve_micro_simulations_with_model_adaptivity( self._model_adaptivity_controller.finalise_solve() for lid, sim in enumerate(self._micro_sims): - output[lid][ - "model_resolution" - ] = self._model_adaptivity_controller.get_sim_class_resolution(sim) + res = -1 + if sim is not None: + self._model_adaptivity_controller.get_sim_class_resolution(sim) + output[lid]["model_resolution"] = res return output def _get_solve_variant(self) -> Callable[[list, float], list]: diff --git a/micro_manager/micro_simulation.py b/micro_manager/micro_simulation.py index c784e39d..8ca1beae 100644 --- a/micro_manager/micro_simulation.py +++ b/micro_manager/micro_simulation.py @@ -555,9 +555,19 @@ def __getattr__(self, name): # Only add initialize override if the wrapped class actually has it, # so that requires_initialize() returns True for those classes. if has_initialize: - class_body += """ -def initialize(self, *args, **kwargs): - return self._wrapped.initialize(*args, **kwargs) + argspec = inspect.getfullargspec(cls.initialize) + # build args + init_args = f"{', '.join(argspec.args)}" + params = f"{', '.join(argspec.args[1::])}" + if argspec.varargs is not None: + init_args += f", *args" + params += f", args" + if argspec.varkw is not None: + init_args += f", **kwargs" + params += f", kwargs" + class_body += f""" +def initialize({init_args}): + return self._wrapped.initialize({params}) """ # Only add output override if the wrapped class actually has it, From 528d1085b6cfcf40f07a18ec3fbd01608bb3668d Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Wed, 29 Apr 2026 16:58:07 +0200 Subject: [PATCH 13/14] fix broken test --- tests/unit/test_interpolation.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_interpolation.py b/tests/unit/test_interpolation.py index 8a1b4ef1..e30713e7 100644 --- a/tests/unit/test_interpolation.py +++ b/tests/unit/test_interpolation.py @@ -679,15 +679,13 @@ def test_create_partitions(self): self._domain._generate_trees() x, xq, f = self._domain._create_partitions() - expected_xq = ( - self._ordered_global_xq[4 * self._rank : 4 * self._rank + 4] - / self._domain._normalization[None, :] - ) - expected_xq_set = set() - for i in range(len(expected_xq)): - expected_xq_set.add(tuple(expected_xq[i].tolist())) - for i in range(len(xq)): - self.assertTrue(tuple(xq[i].tolist()) in expected_xq_set) + if self._rank == 0: + expected_xq = self._ordered_global_xq / self._domain._normalization[None, :] + expected_xq_set = set() + for i in range(len(expected_xq)): + expected_xq_set.add(tuple(expected_xq[i].tolist())) + for i in range(len(xq)): + self.assertTrue(tuple(xq[i].tolist()) in expected_xq_set) class TestRBF(TestCase): From 6ad6c81067477e4fc7cbbe0d4455801dce9e5409 Mon Sep 17 00:00:00 2001 From: Alex Hocks Date: Fri, 15 May 2026 15:36:34 +0200 Subject: [PATCH 14/14] fix test --- micro_manager/micro_manager.py | 2 +- tests/unit/test_model_adaptivity.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/micro_manager/micro_manager.py b/micro_manager/micro_manager.py index ecd7aee7..25110e4c 100644 --- a/micro_manager/micro_manager.py +++ b/micro_manager/micro_manager.py @@ -1162,7 +1162,7 @@ def _solve_micro_simulations_with_model_adaptivity( for lid, sim in enumerate(self._micro_sims): res = -1 if sim is not None: - self._model_adaptivity_controller.get_sim_class_resolution(sim) + res = self._model_adaptivity_controller.get_sim_class_resolution(sim) output[lid]["model_resolution"] = res return output diff --git a/tests/unit/test_model_adaptivity.py b/tests/unit/test_model_adaptivity.py index 181cefd6..7abb3933 100644 --- a/tests/unit/test_model_adaptivity.py +++ b/tests/unit/test_model_adaptivity.py @@ -174,7 +174,7 @@ def solve_variant(micro_sims_input, dt, computed_outputs): ) self.assertEqual(manager._micro_sims[0].name, "coarse") - self.assertEqual(result, [{"result": 2}]) + self.assertEqual(result, [{"result": 2, "model_resolution": 1}]) self.assertTrue(controller._converged) def test_manager_loop_exits_on_invalid_switch_request(self): @@ -214,4 +214,4 @@ def solve_variant(micro_sims_input, dt, computed_outputs): self.assertEqual(len(solve_calls), 1) self.assertEqual(solve_calls[0]["computed_outputs"], {}) - self.assertEqual(result, [{"result": 1.0}]) + self.assertEqual(result, [{"result": 1.0, "model_resolution": 1}])