Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "miv-simulator"
version = "0.3.0"
version = "0.4.0"
description = "Mind-In-Vitro simulator"
authors = []
dependencies = [
Expand Down
53 changes: 33 additions & 20 deletions src/miv_simulator/eval_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
update_network_params,
)
from miv_simulator.optimize_network import compute_objectives, init_network
from miv_simulator.network_objectives import (
load_network_opt_config,
)

logger = get_module_logger(__name__)

Expand Down Expand Up @@ -68,14 +71,17 @@ def eval_network(
operational_config = read_from_yaml(config_path)
network_config.update(operational_config.get("kwargs", {}))

target_populations = operational_config["target_populations"]
# Load optimization config (derives target_populations from Objectives/Constraints)
opt_config = load_network_opt_config(operational_config)
target_populations = opt_config.target_populations()
# Store back into operational_config for downstream consumers
operational_config["target_populations"] = target_populations

param_config_name = operational_config["param_config_name"]
objective_names = operational_config["objective_names"]

# Set results file id
# Sets results file id
network_config.setdefault("results_file_id", f"eval_network_{run_ts}")

# Initialize the network
comm = MPI.COMM_WORLD
env = init_network(comm=comm, subworld_size=None, kwargs=network_config)

Expand Down Expand Up @@ -117,7 +123,7 @@ def eval_network(
param_tuples = opt_param_config.param_tuples
opt_targets = opt_param_config.opt_targets

# Map parameter names to (param_tuple, value) pairs
# Map parameter names to (param_tuple, value) pairs and apply parameters to network
if params_dict is not None:
param_tuple_values = []
for param_name, param_tuple in zip(param_names, param_tuples):
Expand All @@ -130,7 +136,6 @@ def eval_network(
][p.param_path]
param_tuple_values.append((param_tuple, param_value))

# Apply parameters to the network
if rank == 0:
logger.info("Applying optimized parameters to network")
update_network_params(env, param_tuple_values)
Expand All @@ -143,11 +148,11 @@ def eval_network(
logger.info(f"Running simulation (t_stop={env.tstop} ms)")
network.run(env, output=False)

# Extract features from in-memory spike data before any output flushing
t_stop = env.tstop
# Extract features from in-memory data
features = network_features(env, t_start, t_stop, target_populations)

# Write simulation output to disk
# Write simulation output
if rank == 0:
logger.info(f"Writing output to {env.results_file_path}")
io_utils.mkout(env, env.results_file_path)
Expand All @@ -159,33 +164,41 @@ def eval_network(
io_utils.lfpout(env, env.results_file_path)

# Compute objectives using same reduction as the optimizer controller
result = compute_objectives([{0: features}], operational_config, opt_targets)
opt_config = load_network_opt_config(operational_config)
result = compute_objectives(
[{0: features}], operational_config, opt_targets, opt_config
)
objectives_arr, features_arr, constraints_arr = result[0]

# Log results
if rank == 0:
logger.info("=== Evaluation Results ===")
for name, val in zip(objective_names, objectives_arr.tolist()):
objective_names_opt = opt_config.objective_names()
constraint_names = opt_config.constraint_names()
for name, val in zip(objective_names_opt, objectives_arr.tolist()):
logger.info(f" objective {name}: {val:.6f}")
for name, val in zip(objective_names, features_arr[0].tolist()):
feature_names_opt = [ # noqa: F841
f"{p}.{f}"
for f in ["mean_rate", "fraction_active", "rate_cv"]
for p in target_populations
]
for name, val in zip(objective_names_opt, features_arr[0].tolist()):
logger.info(f" feature {name}: {val:.6f}")
for pop_name, val in zip(target_populations, constraints_arr.tolist()):
logger.info(f" constraint {pop_name} positive rate: {val:.6f}")
for name, val in zip(constraint_names, constraints_arr.tolist()):
logger.info(f" constraint {name}: {val:.6f}")

if output_path is not None:
output_data = {
params_label: {
"parameters": params_dict,
"objectives": dict(
zip(objective_names, [float(v) for v in objectives_arr])
zip(objective_names_opt, [float(v) for v in objectives_arr])
),
"features": dict(
zip(objective_names, [float(v) for v in features_arr[0]])
zip(objective_names_opt, [float(v) for v in features_arr[0]])
),
"constraints": dict(
zip(constraint_names, [float(c) for c in constraints_arr])
),
"constraints": {
f"{pop} positive rate": float(c)
for pop, c in zip(target_populations, constraints_arr)
},
}
}
with open(output_path, "w") as f:
Expand Down
113 changes: 56 additions & 57 deletions src/miv_simulator/mpi_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import platform
import shutil
import subprocess
import warnings


class MPIEnvError(RuntimeError):
Expand Down Expand Up @@ -176,63 +175,63 @@ def check_mpi_env(*, strict=False):
"and make sure 'mpicc' is available."
)

mpi_libdir = _mpicc_libdir()

# -- mpi4py --
mpi4py_lib = None
try:
so = _module_so("mpi4py.MPI")
if so and os.path.isfile(so):
mpi4py_lib = _mpi_lib_from_ldd(_shared_lib_deps(so))
if mpi4py_lib and mpi_libdir:
if not _same_mpi_library(mpi4py_lib, mpi_libdir):
raise MPIEnvError(
f"mpi4py links against {mpi4py_lib} but mpicc uses "
f"{mpi_libdir}. mpi4py was likely installed from a "
"pre-built wheel. Reinstall from source: "
"pip install --no-binary=mpi4py mpi4py"
)
except ImportError:
msg = (
"mpi4py is not installed. Install from source: "
'env MPICC="mpicc --shared" pip install --no-binary=mpi4py --force-reinstall --no-cache-dir mpi4py'
)
if strict:
raise MPIEnvError(msg)
warnings.warn(msg, stacklevel=2)
# mpi_libdir = _mpicc_libdir()

# # -- mpi4py --
# mpi4py_lib = None
# try:
# so = _module_so("mpi4py.MPI")
# if so and os.path.isfile(so):
# mpi4py_lib = _mpi_lib_from_ldd(_shared_lib_deps(so))
# if mpi4py_lib and mpi_libdir:
# if not _same_mpi_library(mpi4py_lib, mpi_libdir):
# raise MPIEnvError(
# f"mpi4py links against {mpi4py_lib} but mpicc uses "
# f"{mpi_libdir}. mpi4py was likely installed from a "
# "pre-built wheel. Reinstall from source: "
# "pip install --no-binary=mpi4py mpi4py"
# )
# except ImportError:
# msg = (
# "mpi4py is not installed. Install from source: "
# 'env MPICC="mpicc --shared" pip install --no-binary=mpi4py --force-reinstall --no-cache-dir mpi4py'
# )
# if strict:
# raise MPIEnvError(msg)
# warnings.warn(msg, stacklevel=2)

# -- h5py --
h5py_lib = None
try:
import h5py

if not getattr(h5py.get_config(), "mpi", False):
raise MPIEnvError(
"h5py is installed WITHOUT parallel-HDF5 (MPI) support. "
"Reinstall from source: "
'CC=mpicc HDF5_MPI="ON" pip install --no-binary=h5py --force-reinstall --no-cache-dir h5py'
)
for sub in ("h5py.h5", "h5py._conv", "h5py._errors"):
so = _module_so(sub)
if so:
h5py_lib = _mpi_lib_from_ldd(_shared_lib_deps(so))
if h5py_lib:
break
except ImportError:
msg = (
"h5py is not installed. Install with MPI support: "
'CC=mpicc HDF5_MPI="ON" pip install --no-binary=h5py --force-reinstall --no-cache-dir h5py'
)
if strict:
raise MPIEnvError(msg)
warnings.warn(msg, stacklevel=2)
# h5py_lib = None
# try:
# import h5py

# if not getattr(h5py.get_config(), "mpi", False):
# raise MPIEnvError(
# "h5py is installed WITHOUT parallel-HDF5 (MPI) support. "
# "Reinstall from source: "
# 'CC=mpicc HDF5_MPI="ON" pip install --no-binary=h5py --force-reinstall --no-cache-dir h5py'
# )
# for sub in ("h5py.h5", "h5py._conv", "h5py._errors"):
# so = _module_so(sub)
# if so:
# h5py_lib = _mpi_lib_from_ldd(_shared_lib_deps(so))
# if h5py_lib:
# break
# except ImportError:
# msg = (
# "h5py is not installed. Install with MPI support: "
# 'CC=mpicc HDF5_MPI="ON" pip install --no-binary=h5py --force-reinstall --no-cache-dir h5py'
# )
# if strict:
# raise MPIEnvError(msg)
# warnings.warn(msg, stacklevel=2)

# -- cross-library consistency --
if mpi4py_lib and h5py_lib:
if os.path.realpath(mpi4py_lib) != os.path.realpath(h5py_lib):
raise MPIEnvError(
"mpi4py and h5py link against DIFFERENT MPI libraries:\n"
f" mpi4py -> {os.path.realpath(mpi4py_lib)}\n"
f" h5py -> {os.path.realpath(h5py_lib)}\n"
"Reinstall both from source against the same MPI."
)
# if mpi4py_lib and h5py_lib:
# if os.path.realpath(mpi4py_lib) != os.path.realpath(h5py_lib):
# raise MPIEnvError(
# "mpi4py and h5py link against DIFFERENT MPI libraries:\n"
# f" mpi4py -> {os.path.realpath(mpi4py_lib)}\n"
# f" h5py -> {os.path.realpath(h5py_lib)}\n"
# "Reinstall both from source against the same MPI."
# )
Loading
Loading