
Increasing cell count in network increases time to convert recorded arrays #559

@ntolley


I've found some peculiar behavior when simulating networks with a large number of cells. While increasing the cell count in the network has a negligible impact on the simulation time, it seems that internally the recordings need to be "converted" before further analysis.

This is seen whenever you use a function that accesses data in the recorded output (like plt.plot() or np.array()). Even stranger, it does not depend on the size of the recording. In the example below I simulate networks with increasing numbers of neurons, but only record the voltage of the first neuron.

import jaxley as jx
from jax import jit
import time
from jaxley.channels import Na
from jax import config
import numpy as np
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

cell = jx.Cell()
cell.insert(Na())

# Store simulation times, and output conversion times
sim_time_list, array_time_list = list(), list()
cell_num_sweep = np.arange(10, 1000, 100)
for num_cells in cell_num_sweep:

    net = jx.Network([cell for _ in range(num_cells)])
    def simulate(params):
        return jx.integrate(net, params=params, t_max=1_000.0, delta_t=0.025)
    jitted_simulate = jit(simulate)
    
    params = net.get_parameters()

    # Only record 1 variable
    net.delete_recordings()
    net.cell(0).record('v')

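    # Time the jitted call (for a freshly jitted function, this includes tracing and compilation)
    start_time = time.time()
    v = jitted_simulate(params)
    simulate_time = time.time() - start_time
    print(f"Simulate time {simulate_time}")

    # Time the conversion of the returned JAX array to a NumPy array
    start_time = time.time()
    v_new = np.array(v)
    array_time = time.time() - start_time
    print(f"Array convert time {array_time}")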

    sim_time_list.append(simulate_time)
    array_time_list.append(array_time)

It seems that the time to convert the simulated output from a jax.Array to a NumPy array increases linearly with the number of cells, even though the recorded vector has the same size in every run and the simulation time is not affected:

import matplotlib.pyplot as plt

plt.plot(cell_num_sweep, array_time_list, label='Array Convert')
plt.plot(cell_num_sweep, sim_time_list, label='Sim Time')
plt.xlabel('Num Cells', fontsize=15)
plt.ylabel('Time (s)', fontsize=15)
plt.legend(fontsize=12)

[Figure: 'Array Convert' and 'Sim Time' (s) plotted against number of cells]
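One thing that may be worth ruling out, assuming JAX's asynchronous dispatch is at play: a jitted call can return before its computation has finished, and np.array(v) then has to wait for the result, so work that scales with the number of cells could end up being charged to the "conversion" step. In addition, because simulate is redefined in every loop iteration, each first call recompiles, which may dominate the measured "simulate" time. The sketch below does not use Jaxley; make_simulate and time_phases are hypothetical helpers, and the lax.scan over num_cells independent units is only a stand-in for a network simulation that records a single cell. It times compilation, blocked execution (block_until_ready()), and the device-to-host conversion as separate phases.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def make_simulate(num_cells, num_steps=2_000, delta_t=0.025):
    """Hypothetical stand-in for a network simulation (not Jaxley).

    Integrates `num_cells` independent leaky units with forward Euler but
    returns only the trace of unit 0, mimicking "record one cell".
    """

    def simulate(v0):
        def step(v, _):
            v = v + delta_t * (-v + jnp.tanh(v))
            return v, v[0]  # carry all units, record only unit 0

        _, trace = jax.lax.scan(step, v0, None, length=num_steps)
        return trace

    return jax.jit(simulate)


def time_phases(num_cells):
    """Time compilation, execution, and conversion as separate phases."""
    simulate = make_simulate(num_cells)
    v0 = jnp.ones(num_cells)

    # 1) Ahead-of-time tracing and compilation.
    t0 = time.perf_counter()
    compiled = simulate.lower(v0).compile()
    compile_time = time.perf_counter() - t0

    # 2) Execution: block until the result has actually been computed.
    t0 = time.perf_counter()
    out = compiled(v0).block_until_ready()
    run_time = time.perf_counter() - t0

    # 3) Conversion of an already-computed jax.Array to a NumPy array.
    t0 = time.perf_counter()
    out_np = np.asarray(out)
    convert_time = time.perf_counter() - t0

    return out_np, compile_time, run_time, convert_time


for num_cells in (10, 1_000, 100_000):
    _, c, r, a = time_phases(num_cells)
    print(f"{num_cells:>7} cells: compile {c:.3f}s, run {r:.3f}s, convert {a:.6f}s")
```

If the same separation is applied to the reproducer (for example, assuming jx.integrate returns a single jax.Array, timing v = jitted_simulate(params).block_until_ready() instead of the bare call) and the "Array convert time" then stays flat, the linear scaling would belong to the simulation itself rather than to the conversion.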
