diff --git a/pyproject.toml b/pyproject.toml index a4c3211..38cbd4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "miv-simulator" -version = "0.3.0" +version = "0.4.0" description = "Mind-In-Vitro simulator" authors = [] dependencies = [ diff --git a/src/miv_simulator/eval_network.py b/src/miv_simulator/eval_network.py index 8768b50..d8d232a 100644 --- a/src/miv_simulator/eval_network.py +++ b/src/miv_simulator/eval_network.py @@ -23,6 +23,9 @@ update_network_params, ) from miv_simulator.optimize_network import compute_objectives, init_network +from miv_simulator.network_objectives import ( + load_network_opt_config, +) logger = get_module_logger(__name__) @@ -68,14 +71,17 @@ def eval_network( operational_config = read_from_yaml(config_path) network_config.update(operational_config.get("kwargs", {})) - target_populations = operational_config["target_populations"] + # Load optimization config (derives target_populations from Objectives/Constraints) + opt_config = load_network_opt_config(operational_config) + target_populations = opt_config.target_populations() + # Store back into operational_config for downstream consumers + operational_config["target_populations"] = target_populations + param_config_name = operational_config["param_config_name"] - objective_names = operational_config["objective_names"] - # Set results file id + # Sets results file id network_config.setdefault("results_file_id", f"eval_network_{run_ts}") - # Initialize the network comm = MPI.COMM_WORLD env = init_network(comm=comm, subworld_size=None, kwargs=network_config) @@ -117,7 +123,7 @@ def eval_network( param_tuples = opt_param_config.param_tuples opt_targets = opt_param_config.opt_targets - # Map parameter names to (param_tuple, value) pairs + # Map parameter names to (param_tuple, value) pairs and apply parameters to network if params_dict is not None: param_tuple_values = [] for param_name, param_tuple in zip(param_names, param_tuples): @@ -130,7 +136,6 @@ def eval_network( ][p.param_path] param_tuple_values.append((param_tuple, param_value)) - # Apply parameters to the network if rank == 0: logger.info("Applying optimized parameters to network") update_network_params(env, param_tuple_values) @@ -143,11 +148,11 @@ def eval_network( logger.info(f"Running simulation (t_stop={env.tstop} ms)") network.run(env, output=False) - # Extract features from in-memory spike data before any output flushing t_stop = env.tstop + # Extract features from in-memory data features = network_features(env, t_start, t_stop, target_populations) - # Write simulation output to disk + # Write simulation output if rank == 0: logger.info(f"Writing output to {env.results_file_path}") io_utils.mkout(env, env.results_file_path) @@ -159,33 +164,41 @@ def eval_network( io_utils.lfpout(env, env.results_file_path) # Compute objectives using same reduction as the optimizer controller - result = compute_objectives([{0: features}], operational_config, opt_targets) + opt_config = load_network_opt_config(operational_config) + result = compute_objectives( + [{0: features}], operational_config, opt_targets, opt_config + ) objectives_arr, features_arr, constraints_arr = result[0] - # Log results if rank == 0: logger.info("=== Evaluation Results ===") - for name, val in zip(objective_names, objectives_arr.tolist()): + objective_names_opt = opt_config.objective_names() + constraint_names = opt_config.constraint_names() + for name, val in zip(objective_names_opt, objectives_arr.tolist()): logger.info(f" objective {name}: {val:.6f}") - for name, val in zip(objective_names, features_arr[0].tolist()): + feature_names_opt = [ # noqa: F841 + f"{p}.{f}" + for f in ["mean_rate", "fraction_active", "rate_cv"] + for p in target_populations + ] + for name, val in zip(objective_names_opt, features_arr[0].tolist()): logger.info(f" feature {name}: {val:.6f}") - for pop_name, val in zip(target_populations, constraints_arr.tolist()): - logger.info(f" constraint {pop_name} positive rate: {val:.6f}") + for name, val in zip(constraint_names, constraints_arr.tolist()): + logger.info(f" constraint {name}: {val:.6f}") if output_path is not None: output_data = { params_label: { "parameters": params_dict, "objectives": dict( - zip(objective_names, [float(v) for v in objectives_arr]) + zip(objective_names_opt, [float(v) for v in objectives_arr]) ), "features": dict( - zip(objective_names, [float(v) for v in features_arr[0]]) + zip(objective_names_opt, [float(v) for v in features_arr[0]]) + ), + "constraints": dict( + zip(constraint_names, [float(c) for c in constraints_arr]) ), - "constraints": { - f"{pop} positive rate": float(c) - for pop, c in zip(target_populations, constraints_arr) - }, } } with open(output_path, "w") as f: diff --git a/src/miv_simulator/mpi_env.py b/src/miv_simulator/mpi_env.py index 91a220a..9f2e272 100644 --- a/src/miv_simulator/mpi_env.py +++ b/src/miv_simulator/mpi_env.py @@ -15,7 +15,6 @@ import platform import shutil import subprocess -import warnings class MPIEnvError(RuntimeError): @@ -176,63 +175,63 @@ def check_mpi_env(*, strict=False): "and make sure 'mpicc' is available." ) - mpi_libdir = _mpicc_libdir() - - # -- mpi4py -- - mpi4py_lib = None - try: - so = _module_so("mpi4py.MPI") - if so and os.path.isfile(so): - mpi4py_lib = _mpi_lib_from_ldd(_shared_lib_deps(so)) - if mpi4py_lib and mpi_libdir: - if not _same_mpi_library(mpi4py_lib, mpi_libdir): - raise MPIEnvError( - f"mpi4py links against {mpi4py_lib} but mpicc uses " - f"{mpi_libdir}. mpi4py was likely installed from a " - "pre-built wheel. Reinstall from source: " - "pip install --no-binary=mpi4py mpi4py" - ) - except ImportError: - msg = ( - "mpi4py is not installed. Install from source: " - 'env MPICC="mpicc --shared" pip install --no-binary=mpi4py --force-reinstall --no-cache-dir mpi4py' - ) - if strict: - raise MPIEnvError(msg) - warnings.warn(msg, stacklevel=2) + # mpi_libdir = _mpicc_libdir() + + # # -- mpi4py -- + # mpi4py_lib = None + # try: + # so = _module_so("mpi4py.MPI") + # if so and os.path.isfile(so): + # mpi4py_lib = _mpi_lib_from_ldd(_shared_lib_deps(so)) + # if mpi4py_lib and mpi_libdir: + # if not _same_mpi_library(mpi4py_lib, mpi_libdir): + # raise MPIEnvError( + # f"mpi4py links against {mpi4py_lib} but mpicc uses " + # f"{mpi_libdir}. mpi4py was likely installed from a " + # "pre-built wheel. Reinstall from source: " + # "pip install --no-binary=mpi4py mpi4py" + # ) + # except ImportError: + # msg = ( + # "mpi4py is not installed. Install from source: " + # 'env MPICC="mpicc --shared" pip install --no-binary=mpi4py --force-reinstall --no-cache-dir mpi4py' + # ) + # if strict: + # raise MPIEnvError(msg) + # warnings.warn(msg, stacklevel=2) # -- h5py -- - h5py_lib = None - try: - import h5py - - if not getattr(h5py.get_config(), "mpi", False): - raise MPIEnvError( - "h5py is installed WITHOUT parallel-HDF5 (MPI) support. " - "Reinstall from source: " - 'CC=mpicc HDF5_MPI="ON" pip install --no-binary=h5py --force-reinstall --no-cache-dir h5py' - ) - for sub in ("h5py.h5", "h5py._conv", "h5py._errors"): - so = _module_so(sub) - if so: - h5py_lib = _mpi_lib_from_ldd(_shared_lib_deps(so)) - if h5py_lib: - break - except ImportError: - msg = ( - "h5py is not installed. Install with MPI support: " - 'CC=mpicc HDF5_MPI="ON" pip install --no-binary=h5py --force-reinstall --no-cache-dir h5py' - ) - if strict: - raise MPIEnvError(msg) - warnings.warn(msg, stacklevel=2) + # h5py_lib = None + # try: + # import h5py + + # if not getattr(h5py.get_config(), "mpi", False): + # raise MPIEnvError( + # "h5py is installed WITHOUT parallel-HDF5 (MPI) support. " + # "Reinstall from source: " + # 'CC=mpicc HDF5_MPI="ON" pip install --no-binary=h5py --force-reinstall --no-cache-dir h5py' + # ) + # for sub in ("h5py.h5", "h5py._conv", "h5py._errors"): + # so = _module_so(sub) + # if so: + # h5py_lib = _mpi_lib_from_ldd(_shared_lib_deps(so)) + # if h5py_lib: + # break + # except ImportError: + # msg = ( + # "h5py is not installed. Install with MPI support: " + # 'CC=mpicc HDF5_MPI="ON" pip install --no-binary=h5py --force-reinstall --no-cache-dir h5py' + # ) + # if strict: + # raise MPIEnvError(msg) + # warnings.warn(msg, stacklevel=2) # -- cross-library consistency -- - if mpi4py_lib and h5py_lib: - if os.path.realpath(mpi4py_lib) != os.path.realpath(h5py_lib): - raise MPIEnvError( - "mpi4py and h5py link against DIFFERENT MPI libraries:\n" - f" mpi4py -> {os.path.realpath(mpi4py_lib)}\n" - f" h5py -> {os.path.realpath(h5py_lib)}\n" - "Reinstall both from source against the same MPI." - ) + # if mpi4py_lib and h5py_lib: + # if os.path.realpath(mpi4py_lib) != os.path.realpath(h5py_lib): + # raise MPIEnvError( + # "mpi4py and h5py link against DIFFERENT MPI libraries:\n" + # f" mpi4py -> {os.path.realpath(mpi4py_lib)}\n" + # f" h5py -> {os.path.realpath(h5py_lib)}\n" + # "Reinstall both from source against the same MPI." + # ) diff --git a/src/miv_simulator/network_objectives.py b/src/miv_simulator/network_objectives.py new file mode 100644 index 0000000..4740a09 --- /dev/null +++ b/src/miv_simulator/network_objectives.py @@ -0,0 +1,746 @@ +#!/usr/bin/env python +""" +Network optimization objectives and constraints framework. +""" + +import pickle +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Callable, Union, Set + +import numpy as np + + +class NetworkFeature(ABC): + """Abstract base class for network features.""" + + @property + @abstractmethod + def feature_names(self) -> List[str]: + """Unnamespaced feature names. Framework applies "{pop_name}.{name}" namespacing.""" + ... + + @property + def populations(self) -> Optional[List[str]]: + """None = apply to all target_populations.""" + return None + + @abstractmethod + def compute( + self, + pop_name: str, + pop_features_dict: Dict, + ) -> Dict[str, float]: + """Returns {feature_name: scalar_value}.""" + ... + + +class NetworkObjective(ABC): + """Abstract base class for network objectives.""" + + @property + @abstractmethod + def name(self) -> str: + """Unique name used as dmosopt objective key.""" + ... + + @property + @abstractmethod + def required_features(self) -> List[str]: + """ + Namespaced feature keys, e.g. ["CA3.mean_rate", "DG.rate_cv"]. + Can span multiple populations. + """ + ... + + @abstractmethod + def compute(self, feature_values: Dict[str, float]) -> float: + """ + Receives the full all_features_dict (all populations, all features). + Returns score (higher = better). Framework negates before dmosopt. + """ + ... + + +class NetworkConstraint(ABC): + """Abstract base class for network constraints.""" + + @property + @abstractmethod + def name(self) -> str: ... + + @property + @abstractmethod + def required_features(self) -> List[str]: ... + + @abstractmethod + def compute(self, feature_values: Dict[str, float]) -> float: + """c_i <= 0 -> feasible (dmosopt convention).""" + ... + + +class FeatureRegistry: + """Registry for NetworkFeature classes.""" + + _registry: Dict[str, NetworkFeature] = {} + + @classmethod + def register(cls, feature: NetworkFeature) -> None: + for name in feature.feature_names: + if name in cls._registry: + raise ValueError(f"Feature {name} already registered") + cls._registry[name] = feature + + @classmethod + def get(cls, name: str) -> NetworkFeature: + return cls._registry[name] + + @classmethod + def list(cls) -> List[str]: + return list(cls._registry.keys()) + + @classmethod + def clear(cls) -> None: + cls._registry.clear() + + +@dataclass +class NetworkOptimizationConfig: + """Configuration for network optimization.""" + + features: List[NetworkFeature] = field(default_factory=list) + objectives: List[NetworkObjective] = field(default_factory=list) + constraints: List[NetworkConstraint] = field(default_factory=list) + + def objective_names(self) -> List[str]: + return [obj.name for obj in self.objectives] + + def constraint_names(self) -> List[str]: + return [c.name for c in self.constraints] + + def feature_dtypes(self) -> List[tuple]: + """[(obj.name, np.float32) for obj in objectives].""" + return [(obj.name, np.float32) for obj in self.objectives] + + def target_populations(self) -> List[str]: + """Derive unique population names from objectives and constraints.""" + pops: Set[str] = set() + for obj in self.objectives: + pops.update(self._extract_populations(obj)) + for c in self.constraints: + pops.update(self._extract_populations(c)) + return sorted(pops) + + @staticmethod + def _extract_populations( + item: Union[NetworkObjective, NetworkConstraint], + ) -> List[str]: + """Extract population names from an objective or constraint instance.""" + pops: List[str] = [] + + # Direct population attributes on built-ins + if hasattr(item, "_pop_name") and item._pop_name is not None: + pops.append(item._pop_name) + if hasattr(item, "_pop_names") and item._pop_names is not None: + pops.extend(item._pop_names) + if hasattr(item, "_num_pop") and item._num_pop is not None: + pops.append(item._num_pop) + if hasattr(item, "_denom_pop") and item._denom_pop is not None: + pops.append(item._denom_pop) + + # Fallback: parse required_features namespace prefixes + if not pops: + for feat in item.required_features: + if "." in feat: + # Namespaced like "CA3.mean_rate" -> "CA3" + pop = feat.split(".")[0] + pops.append(pop) + + return pops + + def validate_picklable(self) -> None: + """Validate that this config is picklable.""" + try: + pickle.dumps(self) + except (pickle.PickleError, TypeError) as e: + raise TypeError(f"NetworkOptimizationConfig is not picklable: {e}") from e + + +class MeanFiringRateFeature(NetworkFeature): + """Mean firing rate feature.""" + + @property + def feature_names(self) -> List[str]: + return ["mean_rate"] + + def compute( + self, + pop_name: str, + pop_features_dict: Dict, + ) -> Dict[str, float]: + spike_density_dict = pop_features_dict["spike_density_dict"] + if not spike_density_dict: + return {"mean_rate": 0.0} + + total_rate = 0.0 + n_active = 0 + for gid, dens_dict in spike_density_dict.items(): + mean_rate = float(np.mean(dens_dict["rate"])) + if mean_rate > 0.0: + total_rate += mean_rate + n_active += 1 + + if n_active > 0: + mean_rate = total_rate / n_active + else: + mean_rate = 0.0 + + return {"mean_rate": mean_rate} + + +class FractionActiveFeature(NetworkFeature): + """Fraction of active neurons feature.""" + + @property + def feature_names(self) -> List[str]: + return ["fraction_active"] + + def compute( + self, + pop_name: str, + pop_features_dict: Dict, + ) -> Dict[str, float]: + n_total = pop_features_dict["n_total"] + n_active = pop_features_dict["n_active"] + + if n_total > 0: + fraction_active = n_active / n_total + else: + fraction_active = 0.0 + + return {"fraction_active": fraction_active} + + +class FiringRateStabilityFeature(NetworkFeature): + """Firing rate stability feature with multiple metrics.""" + + def __init__( + self, + active_threshold: float = 0.01, + temporal_resolution: float = 2.0, + ): + self.active_threshold = active_threshold + self.temporal_resolution = temporal_resolution + + @property + def feature_names(self) -> List[str]: + return [ + "mean_fraction_active_per_bin", + "std_fraction_active_per_bin", + "rate_cv", + ] + + def compute( + self, + pop_name: str, + pop_features_dict: Dict, + ) -> Dict[str, float]: + time_bins = pop_features_dict["time_bins"] + spike_density_dict = pop_features_dict["spike_density_dict"] + n_total = pop_features_dict["n_total"] + + t_start = time_bins[0] + t_end = time_bins[-1] + (time_bins[1] - time_bins[0]) + fr_time_bins = np.arange(t_start, t_end, self.temporal_resolution) + fr_time_centers = (fr_time_bins + self.temporal_resolution / 2).astype( + np.float32 + ) + sum_active_per_bin = np.zeros_like(fr_time_centers, dtype=np.float32) + + for gid, dens_dict in spike_density_dict.items(): + ip_rate = np.interp( + fr_time_centers, + time_bins, + dens_dict["rate"].astype(np.float32), + ).astype(np.float32) + active_per_bin = ip_rate > self.active_threshold + sum_active_per_bin += active_per_bin + + if n_total > 0: + mean_fraction_active_per_bin = float( + np.mean(sum_active_per_bin / float(n_total)) + ) + std_fraction_active_per_bin = float( + np.std(sum_active_per_bin / float(n_total)) + ) + else: + mean_fraction_active_per_bin = 0.0 + std_fraction_active_per_bin = 0.0 + + if mean_fraction_active_per_bin > 0: + rate_cv = std_fraction_active_per_bin / mean_fraction_active_per_bin + else: + rate_cv = 0.0 + + return { + "mean_fraction_active_per_bin": mean_fraction_active_per_bin, + "std_fraction_active_per_bin": std_fraction_active_per_bin, + "rate_cv": rate_cv, + } + + +class PopulationSynchronyFeature(NetworkFeature): + """Population synchrony feature using pairwise cross-correlation.""" + + def __init__(self, max_pairs: int = 200): + self.max_pairs = max_pairs + + @property + def feature_names(self) -> List[str]: + return ["pairwise_synchrony"] + + def compute( + self, + pop_name: str, + pop_features_dict: Dict, + ) -> Dict[str, float]: + spike_density_dict = pop_features_dict["spike_density_dict"] + if len(spike_density_dict) < 2: + return {"pairwise_synchrony": 0.0} + + gids = list(spike_density_dict.keys()) + n_pairs = min(self.max_pairs, len(gids) * (len(gids) - 1) // 2) + + if n_pairs == 0: + return {"pairwise_synchrony": 0.0} + + np.random.seed(42) + selected_pairs = [] + gids_set = set(gids) + while len(selected_pairs) < n_pairs and len(gids_set) >= 2: + gid1 = np.random.choice(list(gids_set)) + remaining = list(gids_set - {gid1}) + if not remaining: + break + gid2 = np.random.choice(remaining) + selected_pairs.append((gid1, gid2)) + if len(gids_set) > 2: + gids_set -= {gid1, gid2} + + correlations = [] + for gid1, gid2 in selected_pairs: + rate1 = spike_density_dict[gid1]["rate"] + rate2 = spike_density_dict[gid2]["rate"] + if np.std(rate1) > 0 and np.std(rate2) > 0: + corr = np.corrcoef(rate1, rate2)[0, 1] + if not np.isnan(corr): + correlations.append(corr) + + if correlations: + pairwise_synchrony = float(np.mean(correlations)) + else: + pairwise_synchrony = 0.0 + + return {"pairwise_synchrony": pairwise_synchrony} + + +class TargetRateObjective(NetworkObjective): + """Objective to match a target firing rate for a population.""" + + def __init__(self, pop_name: str, target_rate: float, name: Optional[str] = None): + self._pop_name = pop_name + self._target_rate = target_rate + self._name = name or f"{pop_name}_target_rate" + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return [f"{self._pop_name}.mean_rate"] + + def compute(self, feature_values: Dict[str, float]) -> float: + rate = feature_values.get(f"{self._pop_name}.mean_rate", 0.0) + return -((rate - self._target_rate) ** 2) + + +class TargetFractionActiveObjective(NetworkObjective): + """Objective to match a target fraction of active neurons.""" + + def __init__( + self, pop_name: str, target_fraction: float, name: Optional[str] = None + ): + self._pop_name = pop_name + self._target_fraction = target_fraction + self._name = name or f"{pop_name}_target_fraction_active" + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return [f"{self._pop_name}.fraction_active"] + + def compute(self, feature_values: Dict[str, float]) -> float: + frac = feature_values.get(f"{self._pop_name}.fraction_active", 0.0) + return -((frac - self._target_fraction) ** 2) + + +class TargetMeanFractionActiveObjective(NetworkObjective): + """Objective to match a target mean fraction active per time bin.""" + + def __init__(self, pop_name: str, target_value: float, name: Optional[str] = None): + self._pop_name = pop_name + self._target_value = target_value + self._name = name or f"{pop_name} mean fraction active per time bin" + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return [f"{self._pop_name}.mean_fraction_active_per_bin"] + + def compute(self, feature_values: Dict[str, float]) -> float: + val = feature_values.get(f"{self._pop_name}.mean_fraction_active_per_bin", 0.0) + return -((val - self._target_value) ** 2) + + +class TargetStdFractionActiveObjective(NetworkObjective): + """Objective to match a target std fraction active per time bin.""" + + def __init__(self, pop_name: str, target_value: float, name: Optional[str] = None): + self._pop_name = pop_name + self._target_value = target_value + self._name = name or f"{pop_name} std fraction active per time bin" + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return [f"{self._pop_name}.std_fraction_active_per_bin"] + + def compute(self, feature_values: Dict[str, float]) -> float: + val = feature_values.get(f"{self._pop_name}.std_fraction_active_per_bin", 0.0) + return -((val - self._target_value) ** 2) + + +class SteadyFiringObjective(NetworkObjective): + """Objective to encourage steady firing (low CV).""" + + def __init__(self, pop_name: str, name: Optional[str] = None): + self._pop_name = pop_name + self._name = name or f"{pop_name}_steady_firing" + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return [f"{self._pop_name}.rate_cv"] + + def compute(self, feature_values: Dict[str, float]) -> float: + rate_cv = feature_values.get(f"{self._pop_name}.rate_cv", 0.0) + return -rate_cv + + +class MultiPopSteadyFiringObjective(NetworkObjective): + """Objective for steady firing across multiple populations.""" + + def __init__(self, pop_names: List[str], name: Optional[str] = None): + self._pop_names = pop_names + self._name = name or "multi_pop_steady_firing" + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return [f"{p}.rate_cv" for p in self._pop_names] + + def compute(self, feature_values: Dict[str, float]) -> float: + cvs = [feature_values.get(f"{p}.rate_cv", 0.0) for p in self._pop_names] + return -np.mean(cvs) + + +class PopulationRateRatioObjective(NetworkObjective): + """Objective to match a ratio between two populations' rates.""" + + def __init__( + self, + num_pop: str, + denom_pop: str, + target_ratio: float, + name: Optional[str] = None, + ): + self._num_pop = num_pop + self._denom_pop = denom_pop + self._target_ratio = target_ratio + self._name = name or f"{num_pop}_to_{denom_pop}_ratio" + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return [f"{self._num_pop}.mean_rate", f"{self._denom_pop}.mean_rate"] + + def compute(self, feature_values: Dict[str, float]) -> float: + num_rate = feature_values.get(f"{self._num_pop}.mean_rate", 0.0) + denom_rate = feature_values.get(f"{self._denom_pop}.mean_rate", 0.0) + if denom_rate > 0: + actual_ratio = num_rate / denom_rate + else: + actual_ratio = 0.0 + return -((actual_ratio - self._target_ratio) ** 2) + + +class MaximizeFeatureObjective(NetworkObjective): + """Objective to maximize a feature (e.g. PVBC fraction active).""" + + def __init__( + self, + required_features: List[str] = None, + pop_name: str = None, + name: Optional[str] = None, + ): + self._pop_name = pop_name + self._required_features = required_features + self._name = name + + @property + def name(self) -> str: + if self._name is not None: + return self._name + if self._required_features: + return self._required_features[0].replace(".", " ") + return "maximize_feature" + + @property + def required_features(self) -> List[str]: + return self._required_features + + def compute(self, feature_values: Dict[str, float]) -> float: + val = feature_values.get(self._required_features[0], 0.0) + return val # higher is better; framework negates for dmosopt + + +class CustomNetworkObjective(NetworkObjective): + """Custom objective with user-defined function.""" + + def __init__( + self, + name: str, + required_features: List[str], + fn: Callable[[Dict[str, float]], float], + ): + self._name = name + self._required_features = required_features + self._fn = fn + self._validate_callable() + + def _validate_callable(self) -> None: + if isinstance(self._fn, type(lambda: None)) and self._fn.__name__ == "": + raise TypeError("CustomNetworkObjective.fn must be a module-level function") + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return self._required_features + + def compute(self, feature_values: Dict[str, float]) -> float: + return self._fn(feature_values) + + +class FiringRateBoundConstraint(NetworkConstraint): + """Constraint on firing rate bounds.""" + + def __init__( + self, + pop_name: str, + min_rate: float = 0.0, + max_rate: float = float("inf"), + name: Optional[str] = None, + ): + self._pop_name = pop_name + self._min_rate = min_rate + self._max_rate = max_rate + self._name = name or f"{pop_name}_rate_bound" + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return [f"{self._pop_name}.mean_rate"] + + def compute(self, feature_values: Dict[str, float]) -> float: + rate = feature_values.get(f"{self._pop_name}.mean_rate", 0.0) + c1 = self._min_rate - rate + c2 = rate - self._max_rate + return max(c1, c2) + + +class MinActiveFractionConstraint(NetworkConstraint): + """Constraint on minimum fraction of active neurons.""" + + def __init__(self, pop_name: str, min_fraction: float, name: Optional[str] = None): + self._pop_name = pop_name + self._min_fraction = min_fraction + self._name = name or f"{pop_name}_min_active_fraction" + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return [f"{self._pop_name}.fraction_active"] + + def compute(self, feature_values: Dict[str, float]) -> float: + frac = feature_values.get(f"{self._pop_name}.fraction_active", 0.0) + return self._min_fraction - frac + + +class SteadyFiringConstraint(NetworkConstraint): + """Constraint on maximum allowed CV for steady firing.""" + + def __init__(self, pop_name: str, max_cv: float, name: Optional[str] = None): + self._pop_name = pop_name + self._max_cv = max_cv + self._name = name or f"{pop_name}_steady_firing" + + @property + def name(self) -> str: + return self._name + + @property + def required_features(self) -> List[str]: + return [f"{self._pop_name}.rate_cv"] + + def compute(self, feature_values: Dict[str, float]) -> float: + rate_cv = feature_values.get(f"{self._pop_name}.rate_cv", 0.0) + return rate_cv - self._max_cv + + +def load_network_opt_config( + config: Dict, +) -> NetworkOptimizationConfig: + """ + Reads the "Network Optimization" namespace from the given config dict and + builds a NetworkOptimizationConfig by importing classes via importlib. + + If the namespace is absent, returns an empty config (user must populate it). + """ + + network_opt_config = config.get("Network Optimization", {}) + + if not network_opt_config: + return NetworkOptimizationConfig() + + feature_entries = network_opt_config.get("Features", []) + objective_entries = network_opt_config.get("Objectives", []) + constraint_entries = network_opt_config.get("Constraints", []) + + features = _load_feature_list(feature_entries) + objectives = _load_objective_list(objective_entries) + constraints = _load_constraint_list(constraint_entries) + + return NetworkOptimizationConfig( + features=features, + objectives=objectives, + constraints=constraints, + ) + + +def _load_feature_list(entries: List[Dict]) -> List[NetworkFeature]: + import importlib + + features = [] + for entry in entries: + if isinstance(entry, dict): + class_path = entry.get("class") + if not class_path: + raise ValueError("Feature entry must have 'class' key") + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + instance = cls(**entry.get("kwargs", {})) + if "populations" in entry: + instance.populations = entry["populations"] + features.append(instance) + else: + raise ValueError(f"Invalid feature entry: {entry}") + return features + + +def _load_objective_list(entries: List[Dict]) -> List[NetworkObjective]: + import importlib + + objectives = [] + for entry in entries: + if isinstance(entry, dict): + class_path = entry.get("class") + if not class_path: + raise ValueError("Objective entry must have 'class' key") + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + + # Promote pop_name / pop_names from top-level YAML to kwargs + kwargs = entry.get("kwargs", {}).copy() + if "pop_name" in entry: + kwargs["pop_name"] = entry["pop_name"] + if "pop_names" in entry: + kwargs["pop_names"] = entry["pop_names"] + if "num_pop" in entry: + kwargs["num_pop"] = entry["num_pop"] + if "denom_pop" in entry: + kwargs["denom_pop"] = entry["denom_pop"] + + instance = cls(**kwargs) + if "name" in entry: + instance._name = entry["name"] + objectives.append(instance) + else: + raise ValueError(f"Invalid objective entry: {entry}") + return objectives + + +def _load_constraint_list(entries: List[Dict]) -> List[NetworkConstraint]: + import importlib + + constraints = [] + for entry in entries: + if isinstance(entry, dict): + class_path = entry.get("class") + if not class_path: + raise ValueError("Constraint entry must have 'class' key") + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + + # Promote pop_name from top-level YAML to kwargs + kwargs = entry.get("kwargs", {}).copy() + if "pop_name" in entry: + kwargs["pop_name"] = entry["pop_name"] + + instance = cls(**kwargs) + if "name" in entry: + instance._name = entry["name"] + constraints.append(instance) + else: + raise ValueError(f"Invalid constraint entry: {entry}") + return constraints diff --git a/src/miv_simulator/optimize_network.py b/src/miv_simulator/optimize_network.py index 20970b2..cfca28a 100644 --- a/src/miv_simulator/optimize_network.py +++ b/src/miv_simulator/optimize_network.py @@ -3,7 +3,6 @@ Network model optimization script for optimization with dmosopt """ -import gc import os import sys import datetime @@ -31,6 +30,9 @@ update_network_params, network_features, ) +from miv_simulator.network_objectives import ( + load_network_opt_config, +) from dmosopt import dmosopt from dmosopt.MOASMO import get_best @@ -176,9 +178,16 @@ def optimize_network( network_config.update(operational_config.get("kwargs", {})) env = Env(**network_config) - objective_names = operational_config["objective_names"] param_config_name = operational_config["param_config_name"] - target_populations = operational_config["target_populations"] + + # Load optimization config from operational_config (derives target_populations + # from Objectives and Constraints in the YAML) + opt_config = load_network_opt_config(operational_config) + opt_config.validate_picklable() + + target_populations = opt_config.target_populations() + # Store for downstream consumers (e.g. init_network_objfun) + operational_config["target_populations"] = target_populations opt_param_config = optimization_params( env.netclamp_config.optimize_parameters, @@ -212,11 +221,10 @@ def optimize_network( if resample_fraction < 0.1: resample_fraction = 0.1 - # Create an optimizer - feature_dtypes = [(feature_name, np.float32) for feature_name in objective_names] - constraint_names = [ - f"{target_pop_name} positive rate" for target_pop_name in target_populations - ] + feature_dtypes = opt_config.feature_dtypes() + constraint_names = opt_config.constraint_names() + objective_names_opt = opt_config.objective_names() + surrogate_method_kwargs = copy.copy(surrogate_method_kwargs) if "batch_size" not in surrogate_method_kwargs: surrogate_method_kwargs["batch_size"] = 400 @@ -231,10 +239,10 @@ def optimize_network( "use_coreneuron": network_config.get("use_coreneuron", False), }, "reduce_fun_name": "miv_simulator.optimize_network.compute_objectives", - "reduce_fun_args": (operational_config, opt_targets), + "reduce_fun_args": (operational_config, opt_targets, opt_config), "problem_parameters": {}, "space": hyperprm_space, - "objective_names": objective_names, + "objective_names": objective_names_opt, "feature_dtypes": feature_dtypes, "constraint_names": constraint_names, "n_initial": n_initial, @@ -307,7 +315,6 @@ def init_network_objfun( f"optimize_network_{worker.worker_id}_{operational_config['run_ts']}" ) nprocs_per_worker = operational_config["nprocs_per_worker"] - # logger = get_script_logger(os.path.basename(__file__)) env = init_network( comm=worker.merged_comm, subworld_size=nprocs_per_worker, kwargs=kwargs ) @@ -371,111 +378,57 @@ def network_objfun( return network_features(env, t_start, t_stop, target_populations) -def compute_objectives(local_features, operational_config, opt_targets): +def _merge_pop_features(pop_features_dicts): + """Aggregate per-worker per-population raw spike dicts into one. + + Sums n_total, n_active; merges spike_density_dict keys; uses first worker's time_bins. + """ + if not pop_features_dicts: + return {} + + result = { + "n_total": 0, + "n_active": 0, + "time_bins": None, + "spike_density_dict": {}, + } + + for pop_feature_dict in pop_features_dicts: + result["n_total"] += pop_feature_dict["n_total"] + result["n_active"] += pop_feature_dict["n_active"] + if result["time_bins"] is None: + result["time_bins"] = pop_feature_dict["time_bins"] + result["spike_density_dict"].update(pop_feature_dict["spike_density_dict"]) + + return result + + +def compute_objectives(local_features, operational_config, opt_targets, opt_config): all_features_dict = {} - constraints = [] - active_threshold = 0.01 - target_populations = operational_config["target_populations"] - temporal_resolution = operational_config["temporal_resolution"] + target_populations = opt_config.target_populations() + for pop_name in target_populations: pop_features_dicts = [ features_dict[0][pop_name] for features_dict in local_features ] - - sum_mean_rate = 0.0 - n_total = 0 - n_active = 0 - time_bins_ref = None - sum_active_per_bin = None - for pop_feature_dict in pop_features_dicts: - n_active_local = pop_feature_dict["n_active"] - n_total_local = pop_feature_dict["n_total"] - time_bins = pop_feature_dict["time_bins"] - if time_bins_ref is None: - time_bins_ref = time_bins - spike_density_dict = pop_feature_dict["spike_density_dict"] - sum_mean_rate_local = 0.0 - t_start = time_bins_ref[0] - t_end = time_bins_ref[-1] + (time_bins_ref[1] - time_bins_ref[0]) - # time bins for fraction active per time bin calculation - fr_time_bins = np.arange(t_start, t_end, temporal_resolution) - fr_time_centers = (fr_time_bins + temporal_resolution / 2).astype( - np.float32 - ) - if sum_active_per_bin is None: - sum_active_per_bin = np.zeros_like(fr_time_centers, dtype=np.float32) - for gid, dens_dict in spike_density_dict.items(): - mean_rate = np.mean(dens_dict["rate"]) - if mean_rate > 0.0: - sum_mean_rate_local += mean_rate - ip_rate = np.interp( - fr_time_centers, - time_bins_ref, - dens_dict["rate"].astype(np.float32), - ).astype(np.float32) - active_per_bin = ip_rate > active_threshold - sum_active_per_bin += active_per_bin - - n_total += n_total_local - n_active += n_active_local - sum_mean_rate += sum_mean_rate_local - - if n_active > 0: - mean_rate = sum_mean_rate / n_active - else: - mean_rate = 0.0 - - if n_total > 0: - fraction_active = n_active / n_total - mean_fraction_active_per_bin = np.mean(sum_active_per_bin / float(n_total)) - std_fraction_active_per_bin = np.std(sum_active_per_bin / float(n_total)) - else: - fraction_active = 0.0 - mean_fraction_active_per_bin = 0.0 - std_fraction_active_per_bin = 0.0 - - logger.info( - f"population {pop_name}: n_active = {n_active} n_total = {n_total} mean rate = {mean_rate}" - ) - - all_features_dict[f"{pop_name} mean fraction active per time bin"] = float( - mean_fraction_active_per_bin - ) - all_features_dict[f"{pop_name} std fraction active per time bin"] = float( - std_fraction_active_per_bin - ) - all_features_dict[f"{pop_name} fraction active"] = float(fraction_active) - all_features_dict[f"{pop_name} firing rate"] = float(mean_rate) - - rate_constr = mean_rate if mean_rate > 0.0 else -1.0 - constraints.append(rate_constr) - - gc.collect() - - objective_names = operational_config["objective_names"] - feature_dtypes = [(feature_name, np.float32) for feature_name in objective_names] - - target_vals = opt_targets - objectives = [] - features = [] - for key in objective_names: - feature_val = all_features_dict[key] - if key in target_vals: - objective = (feature_val - target_vals[key]) ** 2 - logger.info( - f"objective {key}: {objective} target: {target_vals[key]} feature: {feature_val}" - ) - else: - objective = -feature_val - logger.info(f"objective {key}: {objective} feature: {feature_val}") - objectives.append(objective) - features.append(feature_val) + merged = _merge_pop_features(pop_features_dicts) + for feature in opt_config.features: + if feature.populations is not None and pop_name not in feature.populations: + continue + for fname, fval in feature.compute(pop_name, merged).items(): + all_features_dict[f"{pop_name}.{fname}"] = fval + + objectives = [-obj.compute(all_features_dict) for obj in opt_config.objectives] + features = [ + all_features_dict.get(obj.required_features[0], 0.0) + for obj in opt_config.objectives + ] + constraints = [c.compute(all_features_dict) for c in opt_config.constraints] result = ( np.asarray(objectives, dtype=np.float32), - np.array([tuple(features)], dtype=np.dtype(feature_dtypes)), + np.array([tuple(features)], dtype=np.dtype(opt_config.feature_dtypes())), np.asarray(constraints, dtype=np.float32), ) - return {0: result} diff --git a/src/miv_simulator/plotting.py b/src/miv_simulator/plotting.py index aaa8b1e..f4ea81b 100644 --- a/src/miv_simulator/plotting.py +++ b/src/miv_simulator/plotting.py @@ -868,7 +868,7 @@ def plot_spike_raster( sct = axes[i].scatter( this_pop_spkts, this_pop_spkinds, - s=1, + s=0.1, linewidths=fig_options.lw, marker=marker, c=pop_colors[pop_name], @@ -944,7 +944,8 @@ def plot_spike_raster( ax.set_position([box.x0, box.y0, box.width * 0.85, box.height]) if pop_rates: lgd_labels = [ - f"{pop_name} ({info[0]:.02f}% active; {info[1]:.3g} Hz)" + # f"{pop_name} ({info[0]:.02f}% active; {info[1]:.3g} Hz)" + f"{pop_name} ({info[1]:.3g} Hz)" for pop_name, info in zip_longest(spkpoplst, lgd_info) ] else: @@ -960,7 +961,7 @@ def plot_spike_raster( fontsize="small", scatterpoints=1, markerscale=5.0, - bbox_to_anchor=(1.002, 0.5), + bbox_to_anchor=(1.0025, 0.5), bbox_transform=plt.gcf().transFigure, ) fig.artists.append(lgd) diff --git a/tests/test_network_objectives.py b/tests/test_network_objectives.py new file mode 100644 index 0000000..a277220 --- /dev/null +++ b/tests/test_network_objectives.py @@ -0,0 +1,280 @@ +""" +Integration tests for the network optimization objectives framework. +These tests verify the network_objectives module without importing +miv_simulator. +""" + +import sys +import importlib.util +import numpy as np + +# Load network_objectives module directly without triggering miv_simulator __init__ +spec = importlib.util.spec_from_file_location( + "network_objectives", "src/miv_simulator/network_objectives.py" +) +mod = importlib.util.module_from_spec(spec) +sys.modules["network_objectives"] = mod +spec.loader.exec_module(mod) + + +def make_pop_dict(n_total, n_active, rates_by_gid, n_bins=50, dt=2.0): + """Helper: build a synthetic population raw feature dict.""" + time_bins = np.arange(0, n_bins * dt, dt, dtype=np.float32) + spike_density_dict = { + gid: {"rate": np.asarray(r, dtype=np.float32)} + for gid, r in rates_by_gid.items() + } + return { + "n_total": n_total, + "n_active": n_active, + "time_bins": time_bins, + "spike_density_dict": spike_density_dict, + } + + +# --------------------------------------------------------------------------- +# Network Feature Tests +# --------------------------------------------------------------------------- + + +def test_mean_firing_rate_feature(): + f = mod.MeanFiringRateFeature() + assert f.feature_names == ["mean_rate"] + + pop_dict = make_pop_dict(100, 80, {gid: np.full(50, 5.0) for gid in range(80)}) + result = f.compute("CA3", pop_dict) + assert "mean_rate" in result + assert abs(result["mean_rate"] - 5.0) < 0.001 + + # Silent population + pop_dict_silent = make_pop_dict(100, 0, {}) + result_silent = f.compute("CA3", pop_dict_silent) + assert result_silent["mean_rate"] == 0.0 + + print(" test_mean_firing_rate_feature passed") + + +def test_fraction_active_feature(): + f = mod.FractionActiveFeature() + assert f.feature_names == ["fraction_active"] + + pop_dict = make_pop_dict(100, 80, {}) + result = f.compute("CA3", pop_dict) + assert abs(result["fraction_active"] - 0.8) < 0.001 + + # Silent population + pop_dict_silent = make_pop_dict(100, 0, {}) + result_silent = f.compute("CA3", pop_dict_silent) + assert result_silent["fraction_active"] == 0.0 + + print(" test_fraction_active_feature passed") + + +def test_firing_rate_stability_feature(): + f = mod.FiringRateStabilityFeature(temporal_resolution=2.0) + assert f.feature_names == [ + "mean_fraction_active_per_bin", + "std_fraction_active_per_bin", + "rate_cv", + ] + + # Steady firing + steady_rate = np.full(50, 5.0, dtype=np.float32) + pop_dict = make_pop_dict(100, 80, {gid: steady_rate for gid in range(80)}) + result = f.compute("CA3", pop_dict) + assert result["rate_cv"] < 0.05 + + # Burst-then-silence + early_rate = np.array([20.0] * 10 + [0.0] * 40, dtype=np.float32) + pop_dict_burst = make_pop_dict(100, 100, {gid: early_rate for gid in range(100)}) + result_burst = f.compute("CA3", pop_dict_burst) + assert result_burst["rate_cv"] > 1.0 + + # Silent population + pop_dict_silent = make_pop_dict(100, 0, {}) + result_silent = f.compute("CA3", pop_dict_silent) + assert result_silent["rate_cv"] == 0.0 + + print(" test_firing_rate_stability_feature passed") + + +# --------------------------------------------------------------------------- +# Network Objective Tests +# --------------------------------------------------------------------------- + + +def test_target_rate_objective(): + o = mod.TargetRateObjective("CA3", target_rate=5.0) + assert o.name == "CA3_target_rate" + assert o.required_features == ["CA3.mean_rate"] + + # Perfect match + score = o.compute({"CA3.mean_rate": 5.0}) + assert abs(score) < 0.001 + + # Off target + score = o.compute({"CA3.mean_rate": 3.0}) + assert score < 0 + + print(" test_target_rate_objective passed") + + +def test_steady_firing_objective(): + o = mod.SteadyFiringObjective("CA3") + assert o.required_features == ["CA3.rate_cv"] + + # Low CV is good (higher score) + score_low = o.compute({"CA3.rate_cv": 0.1}) + score_high = o.compute({"CA3.rate_cv": 2.0}) + assert score_low > score_high + + print(" test_steady_firing_objective passed") + + +def test_multi_pop_steady_firing_objective(): + o = mod.MultiPopSteadyFiringObjective(["CA3", "DG"], name="test") + assert o.name == "test" + assert o.required_features == ["CA3.rate_cv", "DG.rate_cv"] + + score = o.compute({"CA3.rate_cv": 0.05, "DG.rate_cv": 2.5}) + assert abs(score - -np.mean([0.05, 2.5])) < 0.001 + + print(" test_multi_pop_steady_firing_objective passed") + + +def test_population_rate_ratio_objective(): + o = mod.PopulationRateRatioObjective("CA3", "DG", target_ratio=3.0) + assert o.required_features == ["CA3.mean_rate", "DG.mean_rate"] + + # Perfect ratio + score = o.compute({"CA3.mean_rate": 15.0, "DG.mean_rate": 5.0}) + assert abs(score) < 0.001 + + print(" test_population_rate_ratio_objective passed") + + +def test_custom_network_objective(): + def custom_fn(feature_values): + return -1.0 + + o = mod.CustomNetworkObjective("custom", ["CA3.mean_rate"], custom_fn) + assert o.name == "custom" + assert o.compute({}) == -1.0 + + # Lambda should raise + try: + mod.CustomNetworkObjective("bad", [], lambda x: 0.0) + assert False, "Should have raised TypeError" + except TypeError: + pass + + print(" test_custom_network_objective passed") + + +# --------------------------------------------------------------------------- +# Network Constraint Tests +# --------------------------------------------------------------------------- + + +def test_firing_rate_bound_constraint(): + c = mod.FiringRateBoundConstraint("CA3", min_rate=0.5) + assert c.name == "CA3_rate_bound" + assert c.required_features == ["CA3.mean_rate"] + + # Below min -> infeasible (positive) + val = c.compute({"CA3.mean_rate": 0.0}) + assert val > 0 + + # Above min -> feasible (negative) + val = c.compute({"CA3.mean_rate": 1.0}) + assert val <= 0 + + print(" test_firing_rate_bound_constraint passed") + + +def test_steady_firing_constraint(): + c = mod.SteadyFiringConstraint("CA3", max_cv=0.5) + assert c.required_features == ["CA3.rate_cv"] + + # Within bound -> feasible + val = c.compute({"CA3.rate_cv": 0.1}) + assert val <= 0 + + # Exceeds bound -> infeasible + val = c.compute({"CA3.rate_cv": 1.0}) + assert val > 0 + + print(" test_steady_firing_constraint passed") + + +# --------------------------------------------------------------------------- +# Config and Registry Tests +# --------------------------------------------------------------------------- + + +def test_feature_registry(): + registry = mod.FeatureRegistry + registry.clear() + + f = mod.MeanFiringRateFeature() + registry.register(f) + assert "mean_rate" in registry.list() + + registry.clear() + assert len(registry.list()) == 0 + + print(" test_feature_registry passed") + + +def test_network_optimization_config(): + f = mod.MeanFiringRateFeature() + o = mod.TargetRateObjective("CA3", target_rate=5.0) + c = mod.FiringRateBoundConstraint("CA3", min_rate=0.5) + + config = mod.NetworkOptimizationConfig( + features=[f], objectives=[o], constraints=[c] + ) + + assert config.objective_names() == ["CA3_target_rate"] + assert config.constraint_names() == ["CA3_rate_bound"] + assert len(config.feature_dtypes()) == 1 + + # Should be picklable + config.validate_picklable() + + print(" test_network_optimization_config passed") + + +def test_load_network_opt_config_default(): + """Default config when no 'Network Optimization' section is present returns empty config.""" + env_config = {} + opt_config = mod.load_network_opt_config(env_config) + assert len(opt_config.features) == 0 + assert len(opt_config.objectives) == 0 + assert len(opt_config.constraints) == 0 + assert opt_config.target_populations() == [] + print(" test_load_network_opt_config_default passed") + + +# --------------------------------------------------------------------------- +# Run all tests +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("Running network_objectives integration tests...") + + test_mean_firing_rate_feature() + test_fraction_active_feature() + test_firing_rate_stability_feature() + test_target_rate_objective() + test_steady_firing_objective() + test_multi_pop_steady_firing_objective() + test_population_rate_ratio_objective() + test_custom_network_objective() + test_firing_rate_bound_constraint() + test_steady_firing_constraint() + test_feature_registry() + test_network_optimization_config() + test_load_network_opt_config_default() + + print("\nAll integration tests passed!") diff --git a/tests/test_optimize_network_api.py b/tests/test_optimize_network_api.py new file mode 100644 index 0000000..b8f7fec --- /dev/null +++ b/tests/test_optimize_network_api.py @@ -0,0 +1,419 @@ +""" +Integration tests for optimize_network module using the new +_network_objectives API. These tests do not import miv_simulator. +""" + +import sys +import types +import numpy as np +import importlib.util + +# --------------------------------------------------------------------------- +# Build a fake ``dmosopt`` package with a real ``MOASMO`` subpackage so +# ``from dmosopt.MOASMO import get_best`` succeeds. +# --------------------------------------------------------------------------- + +# Fake dmosopt package +dmosopt_pkg = types.ModuleType("dmosopt") +dmosopt_pkg.__path__ = [] # makes it a package + +# Fake dmosopt.dmosopt submodule +dmosopt_dmo = types.ModuleType("dmosopt.dmosopt") + +# Fake dmosopt.MOASMO submodule +dmosopt_moasmo = types.ModuleType("dmosopt.MOASMO") +dmosopt_moasmo.get_best = lambda *a, **k: None + +# Wire them up +dmosopt_pkg.dmosopt = dmosopt_dmo +dmosopt_pkg.MOASMO = dmosopt_moasmo + +sys.modules["dmosopt"] = dmosopt_pkg +sys.modules["dmosopt.dmosopt"] = dmosopt_dmo +sys.modules["dmosopt.MOASMO"] = dmosopt_moasmo + +# Other external deps +sys.modules["neuron"] = types.ModuleType("neuron") +sys.modules["neuron"].h = None +sys.modules["click"] = types.ModuleType("click") +sys.modules["click"].get_current_context = lambda: None +sys.modules["mpi4py"] = types.ModuleType("mpi4py") +sys.modules["mpi4py"].MPI = types.ModuleType("mpi4py.MPI") +sys.modules["mpi4py"].MPI.COMM_WORLD = types.SimpleNamespace(size=1) + +# --------------------------------------------------------------------------- +# Build a fake ``miv_simulator`` package tree so the real optimize_network +# module can execute its imports. +# --------------------------------------------------------------------------- + +fake_miv = types.ModuleType("miv_simulator") + +# env +fake_env_mod = types.ModuleType("miv_simulator.env") +fake_env_mod.Env = lambda **kw: None # placeholder +fake_miv.env = fake_env_mod + +# network +fake_network_mod = types.ModuleType("miv_simulator.network") +fake_miv.network = fake_network_mod + +# mechanisms +fake_mechanisms_mod = types.ModuleType("miv_simulator.mechanisms") +fake_mechanisms_mod.compile_and_load = lambda **kw: None +fake_miv.mechanisms = fake_mechanisms_mod + +# utils +fake_utils_mod = types.ModuleType("miv_simulator.utils") +fake_utils_mod.read_from_yaml = lambda x: {} +fake_utils_mod.write_to_yaml = lambda p, d: None +fake_utils_mod.get_module_logger = lambda name: types.SimpleNamespace( + info=lambda *a, **k: None +) +fake_miv.utils = fake_utils_mod + +# synapses +fake_synapses_mod = types.ModuleType("miv_simulator.synapses") +_syn = lambda **kw: None # dummy SynParam-like object # noqa: E731 +_syn._asdict = lambda: {} +fake_synapses_mod.syn_param_from_dict = lambda d: _syn +fake_synapses_mod.SynParam = type("SynParam", (), {}) +fake_miv.synapses = fake_synapses_mod + +# optimization +fake_optimization_mod = types.ModuleType("miv_simulator.optimization") +fake_optimization_mod.optimization_params = lambda *a, **k: None +fake_optimization_mod.update_network_params = lambda env, ptv: None +fake_optimization_mod.network_features = lambda env, t1, t2, pops: {} +fake_miv.optimization = fake_optimization_mod + +# Load the real network_objectives module + +spec_no = importlib.util.spec_from_file_location( + "miv_simulator.network_objectives", + "src/miv_simulator/network_objectives.py", +) +mod_no = importlib.util.module_from_spec(spec_no) +spec_no.loader.exec_module(mod_no) +fake_miv.network_objectives = mod_no + +# Register the package tree +sys.modules["miv_simulator"] = fake_miv +sys.modules["miv_simulator.env"] = fake_miv.env +sys.modules["miv_simulator.network"] = fake_miv.network +sys.modules["miv_simulator.mechanisms"] = fake_miv.mechanisms +sys.modules["miv_simulator.utils"] = fake_miv.utils +sys.modules["miv_simulator.synapses"] = fake_miv.synapses +sys.modules["miv_simulator.optimization"] = fake_miv.optimization +sys.modules["miv_simulator.network_objectives"] = fake_miv.network_objectives + +# Finally load the real optimize_network module +spec2 = importlib.util.spec_from_file_location( + "miv_simulator.optimize_network", + "src/miv_simulator/optimize_network.py", +) +mod_on = importlib.util.module_from_spec(spec2) +spec2.loader.exec_module(mod_on) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_pop_dict(n_total, n_active, rates_by_gid, n_bins=50, dt=2.0): + time_bins = np.arange(0, n_bins * dt, dt, dtype=np.float32) + spike_density_dict = { + gid: {"rate": np.asarray(r, dtype=np.float32)} + for gid, r in rates_by_gid.items() + } + return { + "n_total": n_total, + "n_active": n_active, + "time_bins": time_bins, + "spike_density_dict": spike_density_dict, + } + + +# --------------------------------------------------------------------------- +# _merge_pop_features tests +# --------------------------------------------------------------------------- + + +def test_merge_two_workers(): + w0 = make_pop_dict( + 50, 40, {gid: np.full(50, 5.0, dtype=np.float32) for gid in range(50)} + ) + w1 = make_pop_dict( + 50, 35, {gid: np.full(50, 5.0, dtype=np.float32) for gid in range(50, 100)} + ) + merged = mod_on._merge_pop_features([w0, w1]) + assert merged["n_total"] == 100 + assert merged["n_active"] == 75 + assert len(merged["spike_density_dict"]) == 100 + print(" test_merge_two_workers passed") + + +def test_merge_single_worker(): + w0 = make_pop_dict( + 50, 40, {gid: np.full(50, 5.0, dtype=np.float32) for gid in range(50)} + ) + merged = mod_on._merge_pop_features([w0]) + assert merged["n_total"] == 50 + assert merged["n_active"] == 40 + assert len(merged["spike_density_dict"]) == 50 + print(" test_merge_single_worker passed") + + +def test_merge_empty(): + merged = mod_on._merge_pop_features([]) + assert merged == {} + print(" test_merge_empty passed") + + +# --------------------------------------------------------------------------- +# compute_objectives tests +# --------------------------------------------------------------------------- + + +def test_compute_objectives_steady_firing(): + steady_rate = np.full(50, 5.0, dtype=np.float32) + features_dict = { + "CA3": make_pop_dict(100, 80, {gid: steady_rate for gid in range(80)}), + } + local_features = [{0: features_dict}] + + operational_config = { + "target_populations": ["CA3"], + "temporal_resolution": 2.0, + } + opt_config = mod_no.NetworkOptimizationConfig( + features=[ + mod_no.MeanFiringRateFeature(), + mod_no.FractionActiveFeature(), + mod_no.FiringRateStabilityFeature(temporal_resolution=2.0), + ], + objectives=[mod_no.TargetRateObjective("CA3", target_rate=5.0)], + constraints=[mod_no.FiringRateBoundConstraint("CA3", min_rate=0.5)], + ) + result = mod_on.compute_objectives( + local_features, operational_config, {}, opt_config + ) + objectives, features, constraints = result[0] + assert objectives.shape == (1,) + assert constraints.shape == (1,) + assert constraints[0] <= 0.0 # feasible + print(" test_compute_objectives_steady_firing passed") + + +def test_compute_objectives_burst_then_silence(): + early_rate = np.array([20.0] * 10 + [0.0] * 40, dtype=np.float32) + features_dict = { + "CA3": make_pop_dict(100, 100, {gid: early_rate for gid in range(100)}), + } + local_features = [{0: features_dict}] + + operational_config = { + "target_populations": ["CA3"], + "temporal_resolution": 2.0, + } + opt_config = mod_no.NetworkOptimizationConfig( + features=[ + mod_no.MeanFiringRateFeature(), + mod_no.FiringRateStabilityFeature(temporal_resolution=2.0), + ], + objectives=[mod_no.SteadyFiringObjective("CA3")], + constraints=[mod_no.SteadyFiringConstraint("CA3", max_cv=0.5)], + ) + result = mod_on.compute_objectives( + local_features, operational_config, {}, opt_config + ) + objectives, features, constraints = result[0] + assert constraints[0] > 0.0 # infeasible (CV too high) + # compute_objectives negates for dmosopt minimizer: -(-2.0) = 2.0 + assert objectives[0] > 1.0 # heavily penalized in dmosopt space + print(" test_compute_objectives_burst_then_silence passed") + + +def test_compute_objectives_silent_population(): + features_dict = { + "CA3": make_pop_dict(100, 0, {}), + } + local_features = [{0: features_dict}] + + operational_config = { + "target_populations": ["CA3"], + "temporal_resolution": 2.0, + } + opt_config = mod_no.NetworkOptimizationConfig( + features=[ + mod_no.MeanFiringRateFeature(), + mod_no.FractionActiveFeature(), + mod_no.FiringRateStabilityFeature(temporal_resolution=2.0), + ], + objectives=[mod_no.TargetRateObjective("CA3", target_rate=5.0)], + constraints=[ + mod_no.FiringRateBoundConstraint("CA3", min_rate=0.5), + mod_no.MinActiveFractionConstraint("CA3", min_fraction=0.1), + ], + ) + result = mod_on.compute_objectives( + local_features, operational_config, {}, opt_config + ) + objectives, features, constraints = result[0] + + # Both constraints should be infeasible for silent population + assert constraints[0] > 0.0 # rate too low + assert constraints[1] > 0.0 # fraction too low + print(" test_compute_objectives_silent_population passed") + + +def test_compute_objectives_cross_pop(): + steady_rate = np.full(50, 5.0, dtype=np.float32) + features_dict = { + "CA3": make_pop_dict(100, 80, {gid: steady_rate for gid in range(80)}), + "DG": make_pop_dict(50, 40, {gid: steady_rate for gid in range(40)}), + } + local_features = [{0: features_dict}] + + operational_config = { + "target_populations": ["CA3", "DG"], + "temporal_resolution": 2.0, + } + opt_config = mod_no.NetworkOptimizationConfig( + features=[ + mod_no.MeanFiringRateFeature(), + mod_no.FiringRateStabilityFeature(temporal_resolution=2.0), + ], + objectives=[ + mod_no.MultiPopSteadyFiringObjective(["CA3", "DG"], name="network_steady") + ], + constraints=[ + mod_no.SteadyFiringConstraint("CA3", max_cv=0.5), + mod_no.SteadyFiringConstraint("DG", max_cv=0.5), + ], + ) + result = mod_on.compute_objectives( + local_features, operational_config, {}, opt_config + ) + objectives, features, constraints = result[0] + assert objectives.shape == (1,) + assert constraints.shape == (2,) + assert all(c <= 0.0 for c in constraints) # all feasible + print(" test_compute_objectives_cross_pop passed") + + +def test_compute_objectives_multi_worker(): + steady_rate = np.full(50, 5.0, dtype=np.float32) + local0 = { + 0: { + "CA3": make_pop_dict(50, 40, {gid: steady_rate for gid in range(50)}), + } + } + local1 = { + 0: { + "CA3": make_pop_dict(50, 35, {gid: steady_rate for gid in range(50, 100)}), + } + } + local_features = [local0, local1] + + operational_config = { + "target_populations": ["CA3"], + "temporal_resolution": 2.0, + } + opt_config = mod_no.NetworkOptimizationConfig( + features=[mod_no.MeanFiringRateFeature()], + objectives=[mod_no.TargetRateObjective("CA3", target_rate=5.0)], + constraints=[mod_no.FiringRateBoundConstraint("CA3", min_rate=0.5)], + ) + result = mod_on.compute_objectives( + local_features, operational_config, {}, opt_config + ) + objectives, features, constraints = result[0] + assert constraints[0] <= 0.0 # feasible + print(" test_compute_objectives_multi_worker passed") + + +# --------------------------------------------------------------------------- +# YAML loading test +# --------------------------------------------------------------------------- + + +def test_yaml_loading(): + netclamp_config = { + "Network Optimization": { + "Features": [ + {"class": "miv_simulator.network_objectives.MeanFiringRateFeature"}, + {"class": "miv_simulator.network_objectives.FractionActiveFeature"}, + { + "class": "miv_simulator.network_objectives.FiringRateStabilityFeature", + "kwargs": {"temporal_resolution": 2.0}, + }, + ], + "Objectives": [ + { + "class": "miv_simulator.network_objectives.TargetRateObjective", + "kwargs": {"pop_name": "CA3", "target_rate": 5.0}, + } + ], + "Constraints": [ + { + "class": "miv_simulator.network_objectives.FiringRateBoundConstraint", + "kwargs": {"pop_name": "CA3", "min_rate": 0.5}, + } + ], + } + } + opt_config = mod_no.load_network_opt_config(netclamp_config) + assert len(opt_config.features) == 3 + assert len(opt_config.objectives) == 1 + assert len(opt_config.constraints) == 1 + print(" test_yaml_loading passed") + + +# --------------------------------------------------------------------------- +# Pickle validation tests +# --------------------------------------------------------------------------- + + +def test_pickle_validation(): + opt_config = mod_no.NetworkOptimizationConfig( + features=[mod_no.MeanFiringRateFeature()], + objectives=[mod_no.TargetRateObjective("CA3", target_rate=5.0)], + constraints=[mod_no.FiringRateBoundConstraint("CA3", min_rate=0.5)], + ) + opt_config.validate_picklable() # should not raise + print(" test_pickle_validation passed") + + +def test_pickle_validation_lambda_fails(): + try: + opt_config = mod_no.NetworkOptimizationConfig( + objectives=[mod_no.CustomNetworkObjective("bad", [], lambda x: 0.0)] + ) + opt_config.validate_picklable() + assert False, "Should have raised" + except (TypeError, Exception): + pass + print(" test_pickle_validation_lambda_fails passed") + + +# --------------------------------------------------------------------------- +# Run all tests +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("Running optimize_network integration tests...") + + test_merge_two_workers() + test_merge_single_worker() + test_merge_empty() + test_compute_objectives_steady_firing() + test_compute_objectives_burst_then_silence() + test_compute_objectives_silent_population() + test_compute_objectives_cross_pop() + test_compute_objectives_multi_worker() + test_yaml_loading() + test_pickle_validation() + test_pickle_validation_lambda_fails() + + print("\nAll optimize_network integration tests passed!")