diff --git a/dev.ipynb b/dev.ipynb new file mode 100644 index 000000000..409104d71 --- /dev/null +++ b/dev.ipynb @@ -0,0 +1,853 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# %load_ext autoreload\n", + "# %autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from jax import config\n", + "config.update(\"jax_enable_x64\", True)\n", + "config.update(\"jax_platform_name\", \"cpu\")\n", + "\n", + "import os\n", + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\".8\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import pandas as pd\n", + "\n", + "import networkx as nx\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from jax import vmap\n", + "\n", + "from typing import Optional, List, Dict, Any, Union, Set, Tuple\n", + "from matplotlib.axes import Axes\n", + "from dataclasses import dataclass, field" + ] + }, + { + "cell_type": "code", + "execution_count": 549, + "metadata": {}, + "outputs": [], + "source": [ + "def pandas_to_nx(\n", + " node_attrs: pd.DataFrame, edge_attrs: pd.DataFrame, global_attrs: pd.Series\n", + ") -> nx.DiGraph:\n", + " \"\"\"Convert node_attrs, edge_attrs and global_attrs from pandas datatypes to a NetworkX DiGraph.\n", + "\n", + " Args:\n", + " node_attrs: DataFrame containing node attributes\n", + " edge_attrs: DataFrame containing edge attributes\n", + " global_attrs: Series containing global graph attributes\n", + "\n", + " Returns:\n", + " A directed graph with nodes, edges and global attributes from the input data.\n", + " \"\"\"\n", + " has_edge_attrs = None if edge_attrs.empty else True\n", + " G = nx.from_pandas_edgelist(\n", + " edge_attrs.reset_index(),\n", + " source=\"level_0\",\n", + " target=\"level_1\",\n", + " edge_attr=has_edge_attrs,\n", + " create_using=nx.DiGraph(),\n", + " )\n", + "\n", + " nx.set_node_attributes(G, node_attrs.to_dict(orient=\"index\"))\n", + " G.graph.update(global_attrs.to_dict())\n", + " return G\n", + "\n", + "\n", + "def nx_to_pandas(G: nx.DiGraph) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]:\n", + " \"\"\"Convert a NetworkX DiGraph to pandas datatypes.\n", + "\n", + " Args:\n", + " G: Input directed graph\n", + "\n", + " Returns:\n", + " Tuple containing:\n", + " - DataFrame of node attributes\n", + " - DataFrame of edge attributes\n", + " - Series of global graph attributes\n", + " \"\"\"\n", + " edge_df = nx.to_pandas_edgelist(G).set_index([\"source\", \"target\"])\n", + " edge_df.index.names = [None, None]\n", + " node_df = pd.DataFrame.from_dict(dict(G.nodes(data=True)), orient=\"index\")\n", + "\n", + " return node_df, edge_df, pd.Series(G.graph)\n", + "\n", + "\n", + "def swc_to_nx(fname: str, num_lines: Optional[int] = None) -> nx.DiGraph:\n", + " \"\"\"Read a SWC morphology file into a NetworkX DiGraph.\n", + "\n", + " Args:\n", + " fname: Path to the SWC file\n", + " num_lines: Number of lines to read from the file\n", + "\n", + " Returns:\n", + " A directed graph representing the morphology where:\n", + " - Nodes have attributes: id, x, y, z, r (radius)\n", + " - Edges represent parent-child relationships\n", + " \"\"\"\n", + " i_id_xyzr_p = np.loadtxt(fname)[:num_lines]\n", + "\n", + " graph = nx.DiGraph()\n", + " for i, id, x, y, z, r, p in i_id_xyzr_p.tolist(): # tolist: np.float64 -> float\n", + " graph.add_node(int(i), **{\"id\": int(id), \"x\": x, \"y\": y, \"z\": z, \"r\": r})\n", + " if p != -1:\n", + " graph.add_edge(int(p), int(i))\n", + " return graph\n", + "\n", + "\n", + "def nx_to_jax(G: nx.DiGraph) -> jax.tree_util.PyTree:\n", + " \"\"\"Convert a NetworkX DiGraph to a Jax tree.\n", + "\n", + " Args:\n", + " G: Input directed graph\n", + "\n", + " Returns:\n", + " A Jax tree representing the morphology.\n", + " \"\"\"\n", + "\n", + " inds, jax_node_attrs = jax.tree_util.tree_map(lambda *args: jnp.array(args), *G.nodes(data=True))\n", + " jax_node_attrs[\"index\"] = jnp.array(inds)\n", + "\n", + " *inds, jax_edge_attrs = jax.tree_util.tree_map(lambda *args: jnp.array(args), *G.edges(data=True))\n", + " jax_edge_attrs[\"index_pre\"] = jnp.array(inds[0])\n", + " jax_edge_attrs[\"index_post\"] = jnp.array(inds[1])\n", + "\n", + " jax_global_attrs = {k: jnp.array(v) for k, v in G.graph.items()}\n", + "\n", + " return jax_node_attrs, jax_edge_attrs, jax_global_attrs\n", + "\n", + "def jax_to_nx(jax_node_attrs: jax.tree_util.PyTree, jax_edge_attrs: jax.tree_util.PyTree, jax_global_attrs: jax.tree_util.PyTree) -> nx.DiGraph:\n", + " \"\"\"Convert a Jax tree to a NetworkX DiGraph.\n", + "\n", + " Args:\n", + " jax_node_attrs: Jax tree of node attributes\n", + " jax_edge_attrs: Jax tree of edge attributes\n", + " jax_global_attrs: Jax tree of global graph attributes\n", + "\n", + " Returns:\n", + " A NetworkX DiGraph representing the morphology.\n", + " \"\"\"\n", + " node_df, edge_df, global_attrs = jax_to_pandas(jax_node_attrs, jax_edge_attrs, jax_global_attrs)\n", + " return pandas_to_nx(node_df, edge_df, global_attrs)\n", + "\n", + "def jax_to_pandas(jax_node_attrs: jax.tree_util.PyTree, jax_edge_attrs: jax.tree_util.PyTree, jax_global_attrs: jax.tree_util.PyTree) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]:\n", + " \"\"\"Convert a Jax tree to pandas datatypes.\n", + "\n", + " Args:\n", + " jax_node_attrs: Jax tree of node attributes\n", + " jax_edge_attrs: Jax tree of edge attributes\n", + " jax_global_attrs: Jax tree of global graph attributes\n", + "\n", + " Returns:\n", + " Tuple containing:\n", + " - DataFrame of node attributes\n", + " - DataFrame of edge attributes\n", + " - Series of global graph attributes\n", + " \"\"\"\n", + "\n", + " node_index = np.array(jax_node_attrs[(\"index\",)])\n", + " node_attrs_df = pd.DataFrame({k:v.tolist() for k, v in jax_node_attrs.items()}, index=node_index).drop(columns=[(\"index\",)])\n", + " \n", + " # edge_index = pd.MultiIndex.from_arrays(np.vstack([jax_edge_attrs[\"index_pre\"], jax_edge_attrs[\"index_post\"]]))\n", + " # edge_attrs_df = pd.DataFrame({k:v.tolist() for k, v in jax_edge_attrs.items()}, index = edge_index).drop(columns=[\"index_pre\", \"index_post\"])\n", + " edge_index = np.array(jax_edge_attrs[\"index\"])\n", + " edge_attrs_df = pd.DataFrame({k:v.tolist() for k, v in jax_edge_attrs.items()}, index = edge_index).drop(columns=[\"index\"])\n", + "\n", + " global_attrs_df = pd.Series(jax_global_attrs)\n", + "\n", + " return node_attrs_df, edge_attrs_df, global_attrs_df\n", + "\n", + "def pandas_to_jax(node_df: pd.DataFrame, edge_df: pd.DataFrame, global_attrs: pd.Series) -> jax.tree_util.PyTree:\n", + " \"\"\"Convert pandas datatypes to a Jax tree.\n", + "\n", + " Args:\n", + " node_df: DataFrame of node attributes\n", + " edge_df: DataFrame of edge attributes\n", + " global_attrs: Series of global graph attributes\n", + " \"\"\"\n", + " node_attrs = node_df.to_dict(orient=\"index\")\n", + " edge_attrs = edge_df.to_dict(orient=\"index\")\n", + "\n", + " inds, jax_node_attrs = jax.tree_util.tree_map(lambda *args: jnp.array(args), *node_attrs.items())\n", + " jax_node_attrs[(\"index\", )] = jnp.array(inds)\n", + "\n", + " # *inds, jax_edge_attrs = jax.tree_util.tree_map(lambda *args: jnp.array(args), *edge_attrs.items())\n", + " # jax_edge_attrs[\"index_pre\"] = jnp.array(inds[0])\n", + " # jax_edge_attrs[\"index_post\"] = jnp.array(inds[1])\n", + " inds, jax_edge_attrs = jax.tree_util.tree_map(lambda *args: jnp.array(args), *edge_attrs.items())\n", + " jax_edge_attrs[\"index\"] = jnp.array(inds)\n", + "\n", + " jax_global_attrs = {k: jnp.array(v) for k, v in global_attrs.items()}\n", + "\n", + " return jax_node_attrs, jax_edge_attrs, jax_global_attrs" + ] + }, + { + "cell_type": "code", + "execution_count": 683, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Union, Any, Callable\n", + "\n", + "def tree_filter(tree, condition, do_x = None, do_y = None):\n", + " do_x = lambda x: None if do_x is None else do_x\n", + " do_y = lambda x: None if do_y is None else do_y\n", + "\n", + " update_if = lambda path, val: do_x(val) if condition(path, val) else do_y(val)\n", + " return jax.tree_util.tree_map_with_path(update_if, tree)\n", + "\n", + "def tree_apply_at(tree, func_mapper: Union[dict[str, callable], callable]):\n", + " if isinstance(func_mapper, Callable):\n", + " return jax.tree.map(lambda x: func_mapper(x), tree)\n", + " \n", + " def update_if_key_matches(path, value):\n", + " if (key := path[0].key) in func_mapper:\n", + " return func_mapper[key](value)\n", + " return value\n", + " return jax.tree_util.tree_map_with_path(update_if_key_matches, tree)\n", + "\n", + "def tree_set_at(tree, keys_values: Union[dict[str, Any], Any], inds = None):\n", + " inds = slice(None) if inds is None else inds\n", + " setter = lambda v: lambda x: x.at[inds].set(v)\n", + " if not isinstance(keys_values, dict):\n", + " return tree_apply_at(tree, setter(keys_values))\n", + " return tree_apply_at(tree, {k: setter(v) for k, v in keys_values.items()})\n", + "\n", + "def tree_get_at(tree, keys = None, inds = None):\n", + " inds = slice(None) if inds is None else inds\n", + " getter = lambda x: x.at[inds].get()\n", + " if keys is None:\n", + " return tree_apply_at(tree, getter)\n", + " return tree_apply_at(tree, {k: getter for k in keys})\n", + "\n", + "def tree_concat(tree, other, axis=0):\n", + " return jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=axis), tree, other)\n", + "\n", + "has_top_level_key = lambda d, key: any(k[0] == key for k in d)" + ] + }, + { + "cell_type": "code", + "execution_count": 618, + "metadata": {}, + "outputs": [], + "source": [ + "from dataclasses import dataclass, field, replace\n", + "\n", + "@dataclass\n", + "class TestModule:\n", + "\n", + " node_attrs: dict[Any, Any]\n", + " edge_attrs: dict[Any, Any]\n", + " global_attrs: dict[Any, Any]\n", + "\n", + " base: TestModule = None\n", + "\n", + " def __post_init__(self):\n", + " if self.base is None:\n", + " self.base = self\n", + " \n", + " if \"externals\" not in self.global_attrs:\n", + " self.global_attrs[\"externals\"] = {}\n", + " if \"channels\" not in self.global_attrs:\n", + " self.global_attrs[\"channels\"] = {}\n", + " if \"synapses\" not in self.global_attrs:\n", + " self.global_attrs[\"synapses\"] = {}\n", + "\n", + " @property\n", + " def _nodes_in_view(self):\n", + " return self.node_attrs[(\"index\",)]\n", + " \n", + " @property\n", + " def _edges_in_view(self):\n", + " return self.edge_attrs[\"index\"]\n", + "\n", + " @property\n", + " def _num_nodes(self):\n", + " return len(self._nodes_in_view)\n", + " \n", + " @property\n", + " def _num_edges(self):\n", + " return len(self._edges_in_view)\n", + "\n", + " def __repr__(self):\n", + " node_keys = list(self.node_attrs.keys())\n", + " edge_keys = list(self.edge_attrs.keys())\n", + " global_keys = list(self.global_attrs.keys())\n", + " return f\"TestModule(node_attrs={self._num_nodes}*{node_keys}, edge_attrs={self._num_edges}*{edge_keys}, global_attrs={global_keys})\"\n", + "\n", + " def _select_nodes(self, keys = None, inds = None):\n", + " return tree_get_at(self.node_attrs, keys=keys, inds=inds)\n", + "\n", + " def _select_edges(self, keys = None, inds = None):\n", + " return tree_get_at(self.edge_attrs, keys=keys, inds=inds)\n", + " \n", + " def _set_node_attrs(self, keys_values, inds = None):\n", + " inds = self._nodes_in_view if inds is None else inds\n", + " node_attrs = tree_set_at(self.base.node_attrs, keys_values=keys_values, inds=inds)\n", + " updated_base = replace(self.base, node_attrs=node_attrs)\n", + " updated_node_attrs = updated_base._select_nodes(inds=inds)\n", + " updated_view = replace(self, base=updated_base, node_attrs=updated_node_attrs)\n", + " return updated_view\n", + " \n", + " def _set_edge_attrs(self, keys_values, inds = None):\n", + " inds = self._edges_in_view if inds is None else inds\n", + " edge_attrs = tree_set_at(self.base.edge_attrs, keys_values=keys_values, inds=inds)\n", + " updated_base = replace(self.base, edge_attrs=edge_attrs)\n", + " updated_edge_attrs = updated_base._select_edges(inds=inds)\n", + " updated_view = replace(self, base=updated_base, edge_attrs=updated_edge_attrs)\n", + " return updated_view\n", + "\n", + " def select(self, nodes=None, edges=None):\n", + " node_inds = self._nodes_in_view if nodes is None else nodes\n", + " edge_inds = self._edges_in_view if edges is None else edges\n", + "\n", + " node_attrs = self._select_nodes(inds=node_inds)\n", + " edge_attrs = self._select_edges(inds=edge_inds)\n", + "\n", + " return replace(self, node_attrs=node_attrs, edge_attrs=edge_attrs, base=self.base)\n", + " \n", + " def _init_node_attrs(self, d: dict[str, Any], inds = None, init_value = jnp.nan):\n", + " inds = self._nodes_in_view if inds is None else inds\n", + " data_type = lambda x: x.dtype if isinstance(x, jnp.ndarray) else np.dtype(type(x))\n", + " init_node_attrs = jax.tree.map(lambda x: jnp.stack([init_value*x]*self.base._num_nodes, axis=0, dtype=data_type(x)), d)\n", + " base_node_attrs = self.base.node_attrs.copy()\n", + " base_node_attrs.update(init_node_attrs)\n", + " updated_base = replace(self.base, node_attrs=base_node_attrs)\n", + " init_view = replace(self, base=updated_base)\n", + " return init_view._set_node_attrs(keys_values=d, inds=inds)\n", + " \n", + " def _init_edge_attrs(self, d: dict[str, Any], inds = None, init_value = jnp.nan):\n", + " inds = self._edges_in_view if inds is None else inds\n", + " data_type = lambda x: x.dtype if isinstance(x, jnp.ndarray) else np.dtype(type(x))\n", + " init_edge_attrs = jax.tree.map(lambda x: jnp.stack([init_value*x]*self.base._num_edges, axis=0, dtype=data_type(x)), d)\n", + " base_edge_attrs = self.base.edge_attrs.copy()\n", + " base_edge_attrs.update(init_edge_attrs)\n", + " updated_base = replace(self.base, edge_attrs=base_edge_attrs)\n", + " init_view = replace(self, base=updated_base)\n", + " return init_view._set_edge_attrs(keys_values=d, inds=inds)\n", + " \n", + " @property\n", + " def pandas(self):\n", + " return jax_to_pandas(self.node_attrs, self.edge_attrs, self.global_attrs)\n", + " \n", + " @property\n", + " def nodes(self):\n", + " return self.pandas[0]\n", + " \n", + " @property\n", + " def edges(self):\n", + " return self.pandas[1]\n", + " \n", + " @property\n", + " def globals(self):\n", + " return self.pandas[2]\n", + " \n", + " @property\n", + " def recordings(self):\n", + " df = pd.DataFrame.from_dict(self.globals[\"recordings\"], orient=\"index\").T.set_index(\"index\")\n", + " df.index.name = None\n", + " return df\n", + " \n", + " @property\n", + " def externals(self):\n", + " df = pd.DataFrame.from_dict(self.globals[\"externals\"], orient=\"index\").T.set_index(\"index\")\n", + " df.index.name = None\n", + " return df\n", + " \n", + " def set(self, key, value):\n", + " is_morph_key = (\"morphology\", key) in self.node_attrs\n", + " is_channel_key = (\"channels\", key) in self.node_attrs\n", + "\n", + " if is_morph_key or is_channel_key:\n", + " key = (\"morphology\", key) if is_morph_key else (\"channels\", key)\n", + " updated_view = self._set_node_attrs(keys_values={key: value}, inds=self._nodes_in_view)\n", + " return updated_view\n", + " \n", + " elif key in self.edge_attrs:\n", + " updated_base = self.base._set_edges([key], [value], self._edges_in_view)\n", + " \n", + " updated_edge_attrs = updated_base._select_edges(inds=self._edges_in_view)\n", + " updated_view = replace(self, base=updated_base, edge_attrs=updated_edge_attrs)\n", + " return updated_view\n", + " \n", + " else:\n", + " raise ValueError(f\"Key {key} not found or not mutable via `.set()`\")\n", + " \n", + " def insert(self, channel):\n", + " base_global_attrs = self.base.global_attrs.copy()\n", + " channel_param_states = {**channel.states, **channel.params, channel.name: True}\n", + " channel_setter = {(\"channels\", k): v for k,v in channel_param_states.items()}\n", + " if not channel.name in base_global_attrs[\"channels\"]:\n", + " base_global_attrs[\"channels\"][channel.name] = channel\n", + " updated_base = replace(self.base, global_attrs=base_global_attrs)\n", + " updated_view = replace(self, base=updated_base)\n", + " updated_view = updated_view._init_node_attrs(channel_setter)\n", + " return updated_view\n", + " else:\n", + " return self._set_node_attrs(channel_setter)\n", + " \n", + " def record(self, key):\n", + " if (\"recordings\", key) not in self.base.node_attrs:\n", + " return self._init_node_attrs({(\"recordings\", key): True}, init_value=False)\n", + " else:\n", + " return self._set_node_attrs({(\"recordings\", key): True})\n", + "\n", + " def add_to_group(self, group):\n", + " if (\"groups\", group) not in self.base.node_attrs:\n", + " return self._init_node_attrs({(\"groups\", group): True}, init_value=False)\n", + " else:\n", + " return self._set_node_attrs({(\"groups\", group): True})\n", + "\n", + " def stimulate(self, key, values):\n", + " \n", + " if (\"externals\", key) not in self.base.node_attrs:\n", + " updated_view = self._init_node_attrs({(\"externals\", key): True}, init_value=False)\n", + " else:\n", + " updated_view = self._set_node_attrs({(\"externals\", key): True})\n", + " \n", + " for idx in self._nodes_in_view:\n", + " updated_view.base.global_attrs[\"externals\"][(key, int(idx))] = values\n", + " updated_view.global_attrs[\"externals\"][(key, int(idx))] = values\n", + "\n", + " return updated_view\n", + "\n", + "\n", + "def connect(pre_view, post_view, synapse):\n", + " base_global_attrs = pre_view.base.global_attrs.copy()\n", + " synapse_param_states = {**synapse.states, **synapse.params, synapse.name: True}\n", + " synapse_setter = {k: v for k,v in synapse_param_states.items()}\n", + " \n", + " if not synapse.name in base_global_attrs[\"synapses\"]:\n", + " base_global_attrs[\"synapses\"][synapse.name] = synapse\n", + " updated_base = replace(pre_view.base, global_attrs=base_global_attrs)\n", + " updated_view = replace(pre_view, base=updated_base)\n", + " updated_view = updated_view._init_edge_attrs(synapse_setter)\n", + " return updated_view\n", + " else:\n", + " return pre_view._set_node_attrs(synapse_setter)" + ] + }, + { + "cell_type": "code", + "execution_count": 717, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'index': Array([ 0, 1, 2, ..., 2627, 2628, -1], dtype=int64),\n", + " 'index_post': Array([ 2, 852, 1412, ..., 2629, 2630, -1], dtype=int64),\n", + " 'index_pre': Array([ 1, 1, 1, ..., 2628, 2629, -1], dtype=int64),\n", + " 'synapse': Array([False, False, False, ..., False, False, False], dtype=bool)}" + ] + }, + "execution_count": 717, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def fill_with(x):\n", + " dtype_str = str(x.dtype)\n", + " if \"float\" in dtype_str:\n", + " return jnp.nan\n", + " elif \"bool\" in dtype_str:\n", + " return False\n", + " elif \"int\" in dtype_str:\n", + " return -1\n", + " else:\n", + " raise TypeError(f\"Unsupported dtype for padding: {x.dtype}\")\n", + "\n", + "new_row = {k: jnp.full((1, *v.shape[1:]), fill_with(v)).astype(v.dtype) for k, v in cell.edge_attrs.items()}\n", + "tree_concat(cell.edge_attrs, new_row)" + ] + }, + { + "cell_type": "code", + "execution_count": 620, + "metadata": {}, + "outputs": [], + "source": [ + "class TestChannel:\n", + " def __init__(self, name=None):\n", + " self.name = name if name is not None else \"TestChannel\"\n", + " self.states = {\"m\": 0.0, \"h\": 0.0}\n", + " self.params = {\"E\": jnp.array([0.0]), \"g\": jnp.array([[0.0, 0.0], [0.0, 0.0]])}\n", + " # self.params = {\"E\": 0.0, \"g\": 0.0}\n", + "\n", + "class TestSynapse:\n", + " def __init__(self, name=None):\n", + " self.name = name if name is not None else \"TestSynapse\"\n", + " self.states = {\"g\": 0.0}\n", + " self.params = {\"E\": jnp.array([0.0]), \"g\": jnp.array([0.0])}\n", + " # self.params = {\"E\": 0.0, \"g\": 0.0}" + ] + }, + { + "cell_type": "code", + "execution_count": 668, + "metadata": {}, + "outputs": [], + "source": [ + "G = swc_to_nx(\"../jaxley/tests/swc_files/morph_ca1_n120.swc\")\n", + "jax_node_attrs, jax_edge_attrs, jax_global_attrs = nx_to_jax(G)\n", + "jax_node_attrs[\"index\"] = jax_node_attrs[\"index\"] -1 # TMP FIX FOR INDEXING, otherwise index drift for select\n", + "\n", + "jax_node_attrs = {(\"morphology\", k) if k != \"index\" else (k,): v for k, v in jax_node_attrs.items()}\n", + "jax_edge_attrs[\"index\"] = jnp.arange(len(jax_edge_attrs[\"index_pre\"]))\n", + "jax_edge_attrs[\"synapse\"] = jnp.full_like(jax_edge_attrs[\"index\"], False, dtype=bool)\n", + "\n", + "cell = TestModule(jax_node_attrs, jax_edge_attrs, jax_global_attrs)\n", + "view = cell.select(nodes=jnp.array([1, 2, 4, 6, 7]))\n", + "\n", + "view = view.set(\"r\", 20)\n", + "view = view.insert(TestChannel())\n", + "view = view.record(\"i\")\n", + "view = view.add_to_group(\"test\")\n", + "view = view.stimulate(\"i\", jnp.array([0.0, 0.0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 623, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
channelsexternalsgroupsmorphologyrecordings
ETestChannelghmitestidrxyzi
0[nan]True[[nan, nan], [nan, nan]]NaNNaNFalseFalse18.1190.000.000.00False
1[0.0]True[[0.0, 0.0], [0.0, 0.0]]0.00.0TrueTrue120.0001.85-4.030.00True
2[0.0]True[[0.0, 0.0], [0.0, 0.0]]0.00.0TrueTrue120.0001.98-6.000.00True
3[nan]True[[nan, nan], [nan, nan]]NaNNaNFalseFalse15.8101.99-7.260.00False
4[0.0]True[[0.0, 0.0], [0.0, 0.0]]0.00.0TrueTrue120.0002.17-8.490.00True
..........................................
2625[nan]True[[nan, nan], [nan, nan]]NaNNaNFalseFalse30.550137.4897.8033.36False
2626[nan]True[[nan, nan], [nan, nan]]NaNNaNFalseFalse30.550139.56102.0433.36False
2627[nan]True[[nan, nan], [nan, nan]]NaNNaNFalseFalse30.550139.26106.1330.31False
2628[nan]True[[nan, nan], [nan, nan]]NaNNaNFalseFalse30.550138.86112.1944.47False
2629[nan]True[[nan, nan], [nan, nan]]NaNNaNFalseFalse30.550138.77112.3444.47False
\n", + "

2630 rows × 13 columns

\n", + "
" + ], + "text/plain": [ + " channels externals \\\n", + " E TestChannel g h m i \n", + "0 [nan] True [[nan, nan], [nan, nan]] NaN NaN False \n", + "1 [0.0] True [[0.0, 0.0], [0.0, 0.0]] 0.0 0.0 True \n", + "2 [0.0] True [[0.0, 0.0], [0.0, 0.0]] 0.0 0.0 True \n", + "3 [nan] True [[nan, nan], [nan, nan]] NaN NaN False \n", + "4 [0.0] True [[0.0, 0.0], [0.0, 0.0]] 0.0 0.0 True \n", + "... ... ... ... ... ... ... \n", + "2625 [nan] True [[nan, nan], [nan, nan]] NaN NaN False \n", + "2626 [nan] True [[nan, nan], [nan, nan]] NaN NaN False \n", + "2627 [nan] True [[nan, nan], [nan, nan]] NaN NaN False \n", + "2628 [nan] True [[nan, nan], [nan, nan]] NaN NaN False \n", + "2629 [nan] True [[nan, nan], [nan, nan]] NaN NaN False \n", + "\n", + " groups morphology recordings \n", + " test id r x y z i \n", + "0 False 1 8.119 0.00 0.00 0.00 False \n", + "1 True 1 20.000 1.85 -4.03 0.00 True \n", + "2 True 1 20.000 1.98 -6.00 0.00 True \n", + "3 False 1 5.810 1.99 -7.26 0.00 False \n", + "4 True 1 20.000 2.17 -8.49 0.00 True \n", + "... ... ... ... ... ... ... ... \n", + "2625 False 3 0.550 137.48 97.80 33.36 False \n", + "2626 False 3 0.550 139.56 102.04 33.36 False \n", + "2627 False 3 0.550 139.26 106.13 30.31 False \n", + "2628 False 3 0.550 138.86 112.19 44.47 False \n", + "2629 False 3 0.550 138.77 112.34 44.47 False \n", + "\n", + "[2630 rows x 13 columns]" + ] + }, + "execution_count": 623, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "view.base.nodes" + ] + }, + { + "cell_type": "code", + "execution_count": 722, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def test():\n", + " view = cell.select(nodes=jnp.array([1, 2, 4, 6, 7]))\n", + " view = view.set(\"r\", 20)\n", + " view = view.insert(TestChannel())\n", + " view = view.record(\"i\")\n", + " view = view.add_to_group(\"test\")\n", + " # view = view.stimulate(\"i\", jnp.array([0.0, 0.0]))\n", + "\n", + "# connect(cell.select(nodes=jnp.array([1, 2])), cell.select(nodes=jnp.array([3, 4])), TestSynapse())\n", + "\n", + "# cell.select(nodes=jnp.array([1, 2, 3, 4, 5])).select(nodes=jnp.array([1, 3])).insert(TestChannel()) # this is slow!\n", + "# cell.select(nodes=jnp.array([2, 3, 4, 5])).record(\"v\")\n", + "# cell.select(nodes=jnp.array([2, 3, 4, 5])).stimulate(\"i\", jnp.array([0.0, 0.0, 0.0, 0.0]))\n", + "# cell.select(nodes=jnp.array([2, 3, 4, 5])).add_to_group(\"test\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}