I've found some peculiar behavior when simulating networks with a large number of cells. While increasing the cell count in the network has a negligible impact on the simulation time, it seems that internally the recordings need to be "converted" before further analysis.
This shows up whenever you use a function that accesses data in the recorded output (like `plt.plot()` or `np.array()`). Even stranger, it does not depend on the size of the recording. In the example below I simulate networks with increasing numbers of neurons, but only record the voltage of the first neuron.
```python
import time

import numpy as np
from jax import config, jit

import jaxley as jx
from jaxley.channels import Na

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

# One cell with a sodium channel, reused for every network
cell = jx.Cell()
cell.insert(Na())

# Store simulation times and output conversion times
sim_time_list, array_time_list = list(), list()
cell_num_sweep = np.arange(10, 1000, 100)

for num_cells in cell_num_sweep:
    net = jx.Network([cell for _ in range(num_cells)])

    def simulate(params):
        return jx.integrate(net, params=params, t_max=1_000.0, delta_t=0.025)

    jitted_simulate = jit(simulate)
    params = net.get_parameters()

    # Only record 1 variable: the voltage of the first cell
    net.delete_recordings()
    net.cell(0).record("v")

    # Time the call to the jitted simulation
    start_time = time.time()
    v = jitted_simulate(params)
    simulate_time = time.time() - start_time
    print(f"Simulate time {simulate_time}")

    # Time the conversion of the recorded output to a NumPy array
    start_time = time.time()
    v_new = np.array(v)
    array_time = time.time() - start_time
    print(f"Array convert time {array_time}")

    sim_time_list.append(simulate_time)
    array_time_list.append(array_time)
```
Increasing the number of cells seems to increase the time needed to convert the simulated output from a JAX array (`jax.Array`) to a NumPy array (`np.array`) linearly, even though the recorded vector has the same size in every run and the simulation time is not affected:
```python
import matplotlib.pyplot as plt

plt.plot(cell_num_sweep, array_time_list, label="Array Convert")
plt.plot(cell_num_sweep, sim_time_list, label="Sim Time")
plt.xlabel("Num Cells", fontsize=15)
plt.ylabel("Time (s)", fontsize=15)
plt.legend(fontsize=12)
plt.show()
```
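One possible explanation, assuming standard JAX behavior: JAX dispatches jitted computations asynchronously, so `jitted_simulate(params)` can return before the integration has actually finished, and `np.array(v)` (or `plt.plot`) then blocks until the result is ready. If that is what happens here, the "conversion" time would really be the simulation time. In addition, a new `jit` is created in every loop iteration, so the timed call also includes tracing and compilation. The sketch below does not use jaxley: `toy_simulate` is a hypothetical stand-in whose cost grows with the input size but which returns a single value, and `time_call` is a hypothetical helper that times a call with `block_until_ready()`, separating the first (compile + run) call from a second, cached call.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def toy_simulate(x):
    # Hypothetical stand-in for jitted_simulate (not jaxley code):
    # the work grows with x.size, but only one value is "recorded".
    def step(_, state):
        return jnp.tanh(state * 1.0001 + 0.01)

    final = jax.lax.fori_loop(0, 200, step, x)
    return final[:1]


def time_call(fn, x):
    """Call fn(x) and return (output, seconds), waiting until the result is ready."""
    start = time.perf_counter()
    out = fn(x).block_until_ready()
    return out, time.perf_counter() - start


for n in (1_000, 100_000, 1_000_000):
    x = jnp.ones(n)
    # First call for a new shape: tracing + compilation + execution
    _, first_time = time_call(toy_simulate, x)
    # Second call with the same shape: cached executable, execution only
    out, run_time = time_call(toy_simulate, x)
    # The result is already computed, so this measures only the copy to NumPy
    start = time.perf_counter()
    out_np = np.asarray(out)
    convert_time = time.perf_counter() - start
    print(f"n={n}: compile+run {first_time:.3f}s, run {run_time:.3f}s, "
          f"convert {convert_time:.6f}s")
```

If asynchronous dispatch is the cause, calling `.block_until_ready()` on the output of `jitted_simulate(params)` in the loop above should move the growing cost from "Array convert time" into "Simulate time".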
