From 255a4c84ef3eeeb91816422338664833b7d2f919 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 7 Nov 2024 17:04:40 +0100 Subject: [PATCH 1/9] add: wip version of tutorial on views --- docs/tutorials/00_jaxley_api.ipynb | 1260 ++++++++++++++++++++++++++++ 1 file changed, 1260 insertions(+) create mode 100644 docs/tutorials/00_jaxley_api.ipynb diff --git a/docs/tutorials/00_jaxley_api.ipynb b/docs/tutorials/00_jaxley_api.ipynb new file mode 100644 index 000000000..88e615793 --- /dev/null +++ b/docs/tutorials/00_jaxley_api.ipynb @@ -0,0 +1,1260 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "597dfe2a-d5fe-4e3d-8fb5-bb415126b81a", + "metadata": {}, + "source": [ + "# Basics of Jaxley" + ] + }, + { + "cell_type": "markdown", + "id": "c9db67ff-6334-4435-9092-e7c71ec71a93", + "metadata": {}, + "source": [ + "In this tutorial, we will introduce you to the basic concepts of Jaxley.\n", + "You will learn about:\n", + "\n", + "- Modules\n", + " - nodes\n", + " - edges\n", + "- Channels\n", + "- Synapses\n", + "- Views\n", + " - Groups\n", + "\n", + "Here is a code snippet which you will learn to understand in this tutorial:\n", + "```python\n", + "import jaxley as jx\n", + "from jaxley.channels import Na, K, Leak\n", + "from jaxley.synapses import IonotropicSynapse\n", + "from jaxley.connect import connect\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "\n", + "# Assembling different Modules into a Network\n", + "comp = jx.Compartment()\n", + "branch = jx.Branch(comp, nseg=1)\n", + "cell = jx.Cell(branch, parents=[-1, 0, 0])\n", + "net = jx.Network([cell]*3)\n", + "\n", + "# Navigating and inspecting the Modules using Views\n", + "cell0 = net.cell(0)\n", + "cell0.nodes\n", + "\n", + "# How to group together parts of Modules\n", + "net.cell(1).add_to_group(\"cell1\")\n", + "\n", + "# connecting two cells using a Synapse\n", + "pre_comp = cell0.branch(1).comp(0)\n", + "post_comp = net.cell1.branch(0).comp(0)\n", + "\n", + "connect(pre_comp, post_comp)\n", + "\n", + "# inserting channels in the membrane\n", + "with net.cell(0) as cell0:\n", + " cell0.insert(Na())\n", + " cell0.insert(K())\n", + "\n", + "with net.cell(1) as cell1:\n", + " cell0.insert(Leak())\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "7177950f-d702-4d8d-b69e-bfb06677037f", + "metadata": {}, + "source": [ + "First, we import the relevant libraries:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "deb594f4", + "metadata": {}, + "outputs": [], + "source": [ + "from jax import config\n", + "config.update(\"jax_enable_x64\", True)\n", + "config.update(\"jax_platform_name\", \"cpu\")\n", + "\n", + "import jaxley as jx\n", + "from jaxley.channels import Na, K, Leak\n", + "from jaxley.synapses import IonotropicSynapse\n", + "from jaxley.connect import connect\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "id": "415b6741", + "metadata": {}, + "source": [ + "# Modules\n", + "\n", + "In Jaxley, we heavily rely on the concept of Modules to build biophyiscal models of neural systems at various scales.\n", + "Jaxley implements 4 Module types:\n", + "- `Compartment`\n", + "- `Branch`\n", + "- `Cell`\n", + "- `Network`\n", + "\n", + "Modules can be connected together to build increasingly detailed and complex models. `Compartment` -> `Branch` -> `Cell` -> `Network`." + ] + }, + { + "cell_type": "markdown", + "id": "7480ea5e", + "metadata": {}, + "source": [ + "`Compartment`s are the atoms of biophysical models in Jaxley. All mechanisms and synaptic connections live on the level of `Compartment`s and can already be simulated using `jx.integrate` on their own. Everything you do in Jaxley starts with a `Compartment`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "535057cf", + "metadata": {}, + "outputs": [], + "source": [ + "comp = jx.Compartment() # single compartment model" + ] + }, + { + "cell_type": "markdown", + "id": "469e25a3", + "metadata": {}, + "source": [ + "Mutliple `Compartments` can be connected together to form longer, linear segments / cables, we call `Branch`es and are essentially equivalent to sections in `NEURON`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0476a173", + "metadata": {}, + "outputs": [], + "source": [ + "nseg = 4\n", + "branch = jx.Branch([comp]*nseg)" + ] + }, + { + "cell_type": "markdown", + "id": "63c35e7b", + "metadata": {}, + "source": [ + "In order to construct cell morphologies in Jaxley, multiple `Branches` can to be connected together using the `Cell` primitive." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6f4c5202", + "metadata": {}, + "outputs": [], + "source": [ + "parents = [-1,0,0] # soma = -1, since it has no parents and both dendrites connect to the soma (0). \n", + "cell = jx.Cell([branch]*len(parents), parents)" + ] + }, + { + "cell_type": "markdown", + "id": "3830cf70", + "metadata": {}, + "source": [ + "Finally, several `Cell`s can be grouped together to form a `Network`, which can than be connected together using `Synpase`s." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3399a716", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 6, 24)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ncells = 2\n", + "net = jx.Network([cell]*ncells)\n", + "\n", + "net.shape # shows you the num_cells, num_branches, num_comps" + ] + }, + { + "cell_type": "markdown", + "id": "30eb8fd5", + "metadata": {}, + "source": [ + "`Module`s carry around the information about their current state and parameters in two Dataframes called `nodes` and `edges`.\n", + "`nodes` contains all the information that we associate with compartments in the model (each row corresponds to a compartment) and `edges` all the information relevant to synapses.\n", + "\n", + "This means that you can easily keep track of the current state of your `Module` and how it changes at all times." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "701364b0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
local_cell_indexlocal_branch_indexlocal_comp_indexlengthradiusaxial_resistivitycapacitancevglobal_cell_indexglobal_branch_indexglobal_comp_indexcontrolled_by_param
000010.01.05000.01.0-70.00000
100110.01.05000.01.0-70.00010
200210.01.05000.01.0-70.00020
300310.01.05000.01.0-70.00030
401010.01.05000.01.0-70.00140
501110.01.05000.01.0-70.00150
601210.01.05000.01.0-70.00160
701310.01.05000.01.0-70.00170
802010.01.05000.01.0-70.00280
902110.01.05000.01.0-70.00290
1002210.01.05000.01.0-70.002100
1102310.01.05000.01.0-70.002110
1210010.01.05000.01.0-70.013120
1310110.01.05000.01.0-70.013130
1410210.01.05000.01.0-70.013140
1510310.01.05000.01.0-70.013150
1611010.01.05000.01.0-70.014160
1711110.01.05000.01.0-70.014170
1811210.01.05000.01.0-70.014180
1911310.01.05000.01.0-70.014190
2012010.01.05000.01.0-70.015200
2112110.01.05000.01.0-70.015210
2212210.01.05000.01.0-70.015220
2312310.01.05000.01.0-70.015230
\n", + "
" + ], + "text/plain": [ + " local_cell_index local_branch_index local_comp_index length radius \\\n", + "0 0 0 0 10.0 1.0 \n", + "1 0 0 1 10.0 1.0 \n", + "2 0 0 2 10.0 1.0 \n", + "3 0 0 3 10.0 1.0 \n", + "4 0 1 0 10.0 1.0 \n", + "5 0 1 1 10.0 1.0 \n", + "6 0 1 2 10.0 1.0 \n", + "7 0 1 3 10.0 1.0 \n", + "8 0 2 0 10.0 1.0 \n", + "9 0 2 1 10.0 1.0 \n", + "10 0 2 2 10.0 1.0 \n", + "11 0 2 3 10.0 1.0 \n", + "12 1 0 0 10.0 1.0 \n", + "13 1 0 1 10.0 1.0 \n", + "14 1 0 2 10.0 1.0 \n", + "15 1 0 3 10.0 1.0 \n", + "16 1 1 0 10.0 1.0 \n", + "17 1 1 1 10.0 1.0 \n", + "18 1 1 2 10.0 1.0 \n", + "19 1 1 3 10.0 1.0 \n", + "20 1 2 0 10.0 1.0 \n", + "21 1 2 1 10.0 1.0 \n", + "22 1 2 2 10.0 1.0 \n", + "23 1 2 3 10.0 1.0 \n", + "\n", + " axial_resistivity capacitance v global_cell_index \\\n", + "0 5000.0 1.0 -70.0 0 \n", + "1 5000.0 1.0 -70.0 0 \n", + "2 5000.0 1.0 -70.0 0 \n", + "3 5000.0 1.0 -70.0 0 \n", + "4 5000.0 1.0 -70.0 0 \n", + "5 5000.0 1.0 -70.0 0 \n", + "6 5000.0 1.0 -70.0 0 \n", + "7 5000.0 1.0 -70.0 0 \n", + "8 5000.0 1.0 -70.0 0 \n", + "9 5000.0 1.0 -70.0 0 \n", + "10 5000.0 1.0 -70.0 0 \n", + "11 5000.0 1.0 -70.0 0 \n", + "12 5000.0 1.0 -70.0 1 \n", + "13 5000.0 1.0 -70.0 1 \n", + "14 5000.0 1.0 -70.0 1 \n", + "15 5000.0 1.0 -70.0 1 \n", + "16 5000.0 1.0 -70.0 1 \n", + "17 5000.0 1.0 -70.0 1 \n", + "18 5000.0 1.0 -70.0 1 \n", + "19 5000.0 1.0 -70.0 1 \n", + "20 5000.0 1.0 -70.0 1 \n", + "21 5000.0 1.0 -70.0 1 \n", + "22 5000.0 1.0 -70.0 1 \n", + "23 5000.0 1.0 -70.0 1 \n", + "\n", + " global_branch_index global_comp_index controlled_by_param \n", + "0 0 0 0 \n", + "1 0 1 0 \n", + "2 0 2 0 \n", + "3 0 3 0 \n", + "4 1 4 0 \n", + "5 1 5 0 \n", + "6 1 6 0 \n", + "7 1 7 0 \n", + "8 2 8 0 \n", + "9 2 9 0 \n", + "10 2 10 0 \n", + "11 2 11 0 \n", + "12 3 12 0 \n", + "13 3 13 0 \n", + "14 3 14 0 \n", + "15 3 15 0 \n", + "16 4 16 0 \n", + "17 4 17 0 \n", + "18 4 18 0 \n", + "19 4 19 0 \n", + "20 5 20 0 \n", + "21 5 21 0 \n", + "22 5 22 0 \n", + "23 5 23 0 " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.nodes" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "eeaaefd1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
global_edge_indexglobal_pre_comp_indexglobal_post_comp_indexpre_locspost_locstypetype_ind
\n", + "
" + ], + "text/plain": [ + "Empty DataFrame\n", + "Columns: [global_edge_index, global_pre_comp_index, global_post_comp_index, pre_locs, post_locs, type, type_ind]\n", + "Index: []" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.edges.head() # this is currently empty since we have not made any connections yet" + ] + }, + { + "cell_type": "markdown", + "id": "2904c0c5", + "metadata": {}, + "source": [ + "# Views" + ] + }, + { + "cell_type": "markdown", + "id": "fd0d4fa3", + "metadata": {}, + "source": [ + "Since these models can become arbitrarily complex, Jaxley utilizes so called `View`s to make working with `Modules` easy and intuitive. \n", + "\n", + "The simplest way to navigate Modules is by navigating them via the hierachy that we introduces above. Let's see how this works for a `Nework`." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "d6abcdec", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "View with 0 different channels. Use `.nodes` for details." + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.cell(0) # View of the 0th cell of the network\n", + "net.cell(0).branch(0) # View of the 1st branch of the 0th cell of the network\n", + "net.cell(0).branch(1).comp(0) # View of the 0th comp of the 1st branch of the 0th cell of the network\n", + "\n", + "# several types of indices are supported (lists, ranges, ...)\n", + "net.cell([0,1]).branch(\"all\").comp(0) # View of all 0th comps of all branches of cell 0 and 1\n", + "\n", + "branch.loc(0.1) # equivalent to `NEURON`s `loc`. Assumes branches are continous from 0-1.\n", + "\n", + "net[0,0,0] # Modules/Views can also be lazily indexed\n", + "\n", + "cell0 = net.cell(0) # views can be assigned to variables and only track the parts of the Module they belong to\n", + "cell0.branch(1).comp(0) # Views can be continuely indexed" + ] + }, + { + "cell_type": "markdown", + "id": "7283d2d4", + "metadata": {}, + "source": [ + "_In case you need even more flexibility in how you select parts of a Module, Jaxley provides a `select` method, to give full control over the exact parts of the `nodes` and `edges` that are part of a `View`. On examples of how this can be used, see [](MISSING)._" + ] + }, + { + "cell_type": "markdown", + "id": "4f8922c4", + "metadata": {}, + "source": [ + "Views behave very similarly to `Module`s, i.e. the `cell0` (the 0th cell of the network) from the example above handles like the `cell` we instantiated earlier, which then became cell 0 of the network. As such `cell0` also has a `nodes` attribute, which keeps track of it's part of the network." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d5502655", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
local_cell_indexlocal_branch_indexlocal_comp_indexlengthradiusaxial_resistivitycapacitancevglobal_cell_indexglobal_branch_indexglobal_comp_indexcontrolled_by_param
000010.01.05000.01.0-70.00000
100110.01.05000.01.0-70.00010
200210.01.05000.01.0-70.00020
300310.01.05000.01.0-70.00030
401010.01.05000.01.0-70.00140
501110.01.05000.01.0-70.00150
601210.01.05000.01.0-70.00160
701310.01.05000.01.0-70.00170
802010.01.05000.01.0-70.00280
902110.01.05000.01.0-70.00290
1002210.01.05000.01.0-70.002100
1102310.01.05000.01.0-70.002110
\n", + "
" + ], + "text/plain": [ + " local_cell_index local_branch_index local_comp_index length radius \\\n", + "0 0 0 0 10.0 1.0 \n", + "1 0 0 1 10.0 1.0 \n", + "2 0 0 2 10.0 1.0 \n", + "3 0 0 3 10.0 1.0 \n", + "4 0 1 0 10.0 1.0 \n", + "5 0 1 1 10.0 1.0 \n", + "6 0 1 2 10.0 1.0 \n", + "7 0 1 3 10.0 1.0 \n", + "8 0 2 0 10.0 1.0 \n", + "9 0 2 1 10.0 1.0 \n", + "10 0 2 2 10.0 1.0 \n", + "11 0 2 3 10.0 1.0 \n", + "\n", + " axial_resistivity capacitance v global_cell_index \\\n", + "0 5000.0 1.0 -70.0 0 \n", + "1 5000.0 1.0 -70.0 0 \n", + "2 5000.0 1.0 -70.0 0 \n", + "3 5000.0 1.0 -70.0 0 \n", + "4 5000.0 1.0 -70.0 0 \n", + "5 5000.0 1.0 -70.0 0 \n", + "6 5000.0 1.0 -70.0 0 \n", + "7 5000.0 1.0 -70.0 0 \n", + "8 5000.0 1.0 -70.0 0 \n", + "9 5000.0 1.0 -70.0 0 \n", + "10 5000.0 1.0 -70.0 0 \n", + "11 5000.0 1.0 -70.0 0 \n", + "\n", + " global_branch_index global_comp_index controlled_by_param \n", + "0 0 0 0 \n", + "1 0 1 0 \n", + "2 0 2 0 \n", + "3 0 3 0 \n", + "4 1 4 0 \n", + "5 1 5 0 \n", + "6 1 6 0 \n", + "7 1 7 0 \n", + "8 2 8 0 \n", + "9 2 9 0 \n", + "10 2 10 0 \n", + "11 2 11 0 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cell0.nodes" + ] + }, + { + "cell_type": "markdown", + "id": "76e72d1c", + "metadata": {}, + "source": [ + "Assigning `View`s to a variable makes it easuer to reuse parts of a `Module` later or to highlight them. However, this can become messy and we might need access to such a `View` more readily. For this purpose Jaxley implements so called groups that can be used to assign any `View` of a `Module` to an attribute, i.e. the soma." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "41d38b22", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Groups {'somas': array([ 0, 12])}\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
local_cell_indexlocal_branch_indexlocal_comp_indexlengthradiusaxial_resistivitycapacitancevglobal_cell_indexglobal_branch_indexglobal_comp_indexcontrolled_by_param
000010.01.05000.01.0-70.00000
1210010.01.05000.01.0-70.013120
\n", + "
" + ], + "text/plain": [ + " local_cell_index local_branch_index local_comp_index length radius \\\n", + "0 0 0 0 10.0 1.0 \n", + "12 1 0 0 10.0 1.0 \n", + "\n", + " axial_resistivity capacitance v global_cell_index \\\n", + "0 5000.0 1.0 -70.0 0 \n", + "12 5000.0 1.0 -70.0 1 \n", + "\n", + " global_branch_index global_comp_index controlled_by_param \n", + "0 0 0 0 \n", + "12 3 12 0 " + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.cell(\"all\").branch(0).comp(0).add_to_group(\"somas\")\n", + "print(\"Groups\", net.groups) # list the indices of the nodes dataframe that are part of the group\n", + "\n", + "somas = net.somas # returns a View with only a subset of nodes\n", + "somas.nodes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edac8921", + "metadata": {}, + "outputs": [], + "source": [ + "# connecting two cells using a Synapse\n", + "pre_comp = cell0.branch(1).comp(0)\n", + "post_comp = net.cell1.branch(0).comp(0)\n", + "\n", + "connect(pre_comp, post_comp)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jaxley", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 979a58941b8ec27de913021e3cc25e48c22939aa Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 8 Nov 2024 11:41:03 +0100 Subject: [PATCH 2/9] fix: added fixes for clamping and stimulating of synapses --- jaxley/integrate.py | 8 ++++---- jaxley/modules/base.py | 38 ++++++++++++++++++++++++++------------ 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/jaxley/integrate.py b/jaxley/integrate.py index 45645d05d..c068ec15b 100644 --- a/jaxley/integrate.py +++ b/jaxley/integrate.py @@ -118,11 +118,11 @@ def add_stimuli( if data_stimuli is not None: externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]]) external_inds["i"] = jnp.concatenate( - [external_inds["i"], data_stimuli[2].global_comp_index.to_numpy()] + [external_inds["i"], data_stimuli[2].index.to_numpy()] ) else: externals["i"] = data_stimuli[1] - external_inds["i"] = data_stimuli[2].global_comp_index.to_numpy() + external_inds["i"] = data_stimuli[2].index.to_numpy() return externals, external_inds @@ -148,11 +148,11 @@ def add_clamps( if state_name in externals.keys(): externals[state_name] = jnp.concatenate([externals[state_name], clamps]) external_inds[state_name] = jnp.concatenate( - [external_inds[state_name], inds.global_comp_index.to_numpy()] + [external_inds[state_name], inds.index.to_numpy()] ) else: externals[state_name] = clamps - external_inds[state_name] = inds.global_comp_index.to_numpy() + external_inds[state_name] = inds.index.to_numpy() return externals, external_inds diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 12bb8f0dc..f852a1251 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1514,7 +1514,8 @@ def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True) This function sets external states for the compartments. """ - if state_name not in self.nodes.columns: + + if state_name not in list(self.nodes.columns) + list(self.edges.columns): raise KeyError(f"{state_name} is not a recognized state in this module.") self._external_input(state_name, state_array, verbose=verbose) @@ -1526,13 +1527,11 @@ def _external_input( ): values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0) batch_size = values.shape[0] - num_inserted = len(self._nodes_in_view) - is_multiple = num_inserted == batch_size - values = ( - values - if is_multiple - else jnp.repeat(values, len(self._nodes_in_view), axis=0) + num_inserted = ( + len(self._nodes_in_view) if key in self.nodes else len(self._edges_in_view) ) + is_multiple = num_inserted == batch_size + values = values if is_multiple else jnp.repeat(values, num_inserted, axis=0) assert batch_size in [ 1, num_inserted, @@ -1546,8 +1545,14 @@ def _external_input( [self.base.external_inds[key], self._nodes_in_view] ) else: - self.base.externals[key] = values - self.base.external_inds[key] = self._nodes_in_view + if key in self.base.nodes.columns: + self.base.externals[key] = values + self.base.external_inds[key] = self._nodes_in_view + elif key in self.base.edges.columns: + self.base.externals[key] = values + self.base.external_inds[key] = self._edges_in_view + else: + raise KeyError(f"Key '{key}' not found in nodes or edges") if verbose: print( @@ -1588,8 +1593,11 @@ def data_clamp( verbose: Whether or not to print the number of inserted clamps. `False` by default because this method is meant to be jitted. """ + if state_name not in list(self.nodes.columns) + list(self.edges.columns): + raise KeyError(f"{state_name} is not a recognized state in this module.") + data = self.nodes if state_name in self.nodes.columns else self.edges return self._data_external_input( - state_name, state_array, data_clamps, self.nodes, verbose=verbose + state_name, state_array, data_clamps, data, verbose=verbose ) def _data_external_input( @@ -1606,10 +1614,16 @@ def _data_external_input( else jnp.expand_dims(state_array, axis=0) ) batch_size = state_array.shape[0] - num_inserted = len(self._nodes_in_view) + num_inserted = ( + len(self._nodes_in_view) + if state_name in self.nodes + else len(self._edges_in_view) + ) is_multiple = num_inserted == batch_size state_array = ( - state_array if is_multiple else jnp.repeat(state_array, len(view), axis=0) + state_array + if is_multiple + else jnp.repeat(state_array, num_inserted, axis=0) ) assert batch_size in [ 1, From 843c38076cb79f3236cf1b6c695d401806604e7a Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 8 Nov 2024 12:25:06 +0100 Subject: [PATCH 3/9] fix: allow all states --- jaxley/modules/base.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index f852a1251..2afaa3c1b 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1179,6 +1179,15 @@ def add_to_group(self, group_name: str): np.concatenate([self.base.groups[group_name], self._nodes_in_view]) ) + def _get_state_names(self) -> Tuple[List, List]: + """Collect all recordable / clampable states in the membrane and synapses. + + Returns states seperated by comps and edges.""" + channel_states = [name for c in self.channels for name in c.channel_states] + synapse_states = [name for s in self.synapses for name in s.synapse_states] + membrane_states = ["v"] + self.membrane_current_names + return channel_states + membrane_states, synapse_states + def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """Get all trainable parameters. @@ -1447,10 +1456,10 @@ def _init_morph_for_debugging(self): self.base.debug_states["par_inds"] = self.base.par_inds def record(self, state: str = "v", verbose=True): - in_view = None - in_view = self._edges_in_view if state in self.edges.columns else in_view - in_view = self._nodes_in_view if state in self.nodes.columns else in_view - assert in_view is not None, "State not found in nodes or edges." + comp_states, edge_states = self._get_state_names() + if state not in comp_states + edge_states: + raise KeyError(f"{state} is not a recognized state in this module.") + in_view = self._nodes_in_view if state in comp_states else self._edges_in_view new_recs = pd.DataFrame(in_view, columns=["rec_index"]) new_recs["state"] = state @@ -1514,9 +1523,6 @@ def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True) This function sets external states for the compartments. """ - - if state_name not in list(self.nodes.columns) + list(self.edges.columns): - raise KeyError(f"{state_name} is not a recognized state in this module.") self._external_input(state_name, state_array, verbose=verbose) def _external_input( @@ -1525,10 +1531,13 @@ def _external_input( values: Optional[jnp.ndarray], verbose: bool = True, ): + comp_states, edge_states = self._get_state_names() + if key not in comp_states + edge_states: + raise KeyError(f"{key} is not a recognized state in this module.") values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0) batch_size = values.shape[0] num_inserted = ( - len(self._nodes_in_view) if key in self.nodes else len(self._edges_in_view) + len(self._nodes_in_view) if key in comp_states else len(self._edges_in_view) ) is_multiple = num_inserted == batch_size values = values if is_multiple else jnp.repeat(values, num_inserted, axis=0) @@ -1545,15 +1554,12 @@ def _external_input( [self.base.external_inds[key], self._nodes_in_view] ) else: - if key in self.base.nodes.columns: + if key in comp_states: self.base.externals[key] = values self.base.external_inds[key] = self._nodes_in_view - elif key in self.base.edges.columns: + else: self.base.externals[key] = values self.base.external_inds[key] = self._edges_in_view - else: - raise KeyError(f"Key '{key}' not found in nodes or edges") - if verbose: print( f"Added {num_inserted} external_states. See `.externals` for details." @@ -1593,9 +1599,10 @@ def data_clamp( verbose: Whether or not to print the number of inserted clamps. `False` by default because this method is meant to be jitted. """ - if state_name not in list(self.nodes.columns) + list(self.edges.columns): + comp_states, edge_states = self._get_state_names() + if state_name not in comp_states + edge_states: raise KeyError(f"{state_name} is not a recognized state in this module.") - data = self.nodes if state_name in self.nodes.columns else self.edges + data = self.nodes if state_name in comp_states else self.edges return self._data_external_input( state_name, state_array, data_clamps, data, verbose=verbose ) @@ -1608,6 +1615,7 @@ def _data_external_input( view: pd.DataFrame, verbose: bool = False, ): + comp_states, edge_states = self._get_state_names() state_array = ( state_array if state_array.ndim == 2 @@ -1616,7 +1624,7 @@ def _data_external_input( batch_size = state_array.shape[0] num_inserted = ( len(self._nodes_in_view) - if state_name in self.nodes + if state_name in comp_states else len(self._edges_in_view) ) is_multiple = num_inserted == batch_size From c27cf96472c730f042cf757006b332717f50c23b Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 8 Nov 2024 12:26:21 +0100 Subject: [PATCH 4/9] fix: add current to states --- jaxley/modules/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 2afaa3c1b..cc7e393fc 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1185,7 +1185,7 @@ def _get_state_names(self) -> Tuple[List, List]: Returns states seperated by comps and edges.""" channel_states = [name for c in self.channels for name in c.channel_states] synapse_states = [name for s in self.synapses for name in s.synapse_states] - membrane_states = ["v"] + self.membrane_current_names + membrane_states = ["v", "i"] + self.membrane_current_names return channel_states + membrane_states, synapse_states def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: From b3e89598a33469e52f1a130ddceeee9e0e606155 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 8 Nov 2024 12:47:44 +0100 Subject: [PATCH 5/9] fix: add tests for current and synapse clamping --- tests/test_clamp.py | 51 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/test_clamp.py b/tests/test_clamp.py index c4ed7265e..3b3be8430 100644 --- a/tests/test_clamp.py +++ b/tests/test_clamp.py @@ -3,6 +3,9 @@ import jax +from jaxley.connect import connect +from jaxley.synapses.ionotropic import IonotropicSynapse + jax.config.update("jax_enable_x64", True) jax.config.update("jax_platform_name", "cpu") from typing import Optional @@ -25,6 +28,54 @@ def test_clamp_pointneuron(): assert np.all(v[:, 1:] == -50.0) +def test_clamp_currents(): + comp = jx.Compartment() + comp.insert(HH()) + comp.record() + + # test clamp + comp.clamp("i_HH", 1.0 * jnp.ones((1000,))) + i1 = jx.integrate(comp, t_max=1.0) + assert np.all(i1[:, 1:] == 1.0) + + # test data clamp + data_clamps = None + ipts = 1.0 * jnp.ones((1000,)) + data_clamps = comp.data_clamp("i_HH", ipts, data_clamps=data_clamps) + + i2 = jx.integrate(comp, data_clamps=data_clamps, t_max=1.0) + assert np.all(i2[:, 1:] == 1.0) + + assert np.all(np.isclose(i1, i2)) + + +def test_clamp_synapse(): + comp = jx.Compartment() + branch = jx.Branch(comp, 1) + cell1 = jx.Cell(branch, [-1]) + cell2 = jx.Cell(branch, [-1]) + net = jx.Network([cell1, cell2]) + connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse()) + net.record("IonotropicSynapse_s") + + # test clamp + net.clamp("IonotropicSynapse_s", 1.0 * jnp.ones((1000,))) + s1 = jx.integrate(net, t_max=1.0) + assert np.all(s1[:, 1:] == 1.0) + + net.delete_clamps() + + # test data clamp + data_clamps = None + ipts = 1.0 * jnp.ones((1000,)) + data_clamps = net.data_clamp("IonotropicSynapse_s", ipts, data_clamps=data_clamps) + + s2 = jx.integrate(net, data_clamps=data_clamps, t_max=1.0) + assert np.all(s2[:, 1:] == 1.0) + + assert np.all(np.isclose(s1, s2)) + + def test_clamp_multicompartment(): comp = jx.Compartment() branch = jx.Branch(comp, 4) From b28b5f0bf6598c4ade62e176fa6baad602c4b2ef Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 8 Nov 2024 12:54:39 +0100 Subject: [PATCH 6/9] fix: fix current clamp test --- tests/test_clamp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_clamp.py b/tests/test_clamp.py index 3b3be8430..8253cd5bb 100644 --- a/tests/test_clamp.py +++ b/tests/test_clamp.py @@ -31,7 +31,7 @@ def test_clamp_pointneuron(): def test_clamp_currents(): comp = jx.Compartment() comp.insert(HH()) - comp.record() + comp.record("i_HH") # test clamp comp.clamp("i_HH", 1.0 * jnp.ones((1000,))) From 975aed14d3f7ce70f6f1691530ff108cca604a54 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 8 Nov 2024 12:55:40 +0100 Subject: [PATCH 7/9] rm: rm accidently commited stash from different branch --- docs/tutorials/00_jaxley_api.ipynb | 1260 ---------------------------- 1 file changed, 1260 deletions(-) delete mode 100644 docs/tutorials/00_jaxley_api.ipynb diff --git a/docs/tutorials/00_jaxley_api.ipynb b/docs/tutorials/00_jaxley_api.ipynb deleted file mode 100644 index 88e615793..000000000 --- a/docs/tutorials/00_jaxley_api.ipynb +++ /dev/null @@ -1,1260 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "597dfe2a-d5fe-4e3d-8fb5-bb415126b81a", - "metadata": {}, - "source": [ - "# Basics of Jaxley" - ] - }, - { - "cell_type": "markdown", - "id": "c9db67ff-6334-4435-9092-e7c71ec71a93", - "metadata": {}, - "source": [ - "In this tutorial, we will introduce you to the basic concepts of Jaxley.\n", - "You will learn about:\n", - "\n", - "- Modules\n", - " - nodes\n", - " - edges\n", - "- Channels\n", - "- Synapses\n", - "- Views\n", - " - Groups\n", - "\n", - "Here is a code snippet which you will learn to understand in this tutorial:\n", - "```python\n", - "import jaxley as jx\n", - "from jaxley.channels import Na, K, Leak\n", - "from jaxley.synapses import IonotropicSynapse\n", - "from jaxley.connect import connect\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "\n", - "\n", - "# Assembling different Modules into a Network\n", - "comp = jx.Compartment()\n", - "branch = jx.Branch(comp, nseg=1)\n", - "cell = jx.Cell(branch, parents=[-1, 0, 0])\n", - "net = jx.Network([cell]*3)\n", - "\n", - "# Navigating and inspecting the Modules using Views\n", - "cell0 = net.cell(0)\n", - "cell0.nodes\n", - "\n", - "# How to group together parts of Modules\n", - "net.cell(1).add_to_group(\"cell1\")\n", - "\n", - "# connecting two cells using a Synapse\n", - "pre_comp = cell0.branch(1).comp(0)\n", - "post_comp = net.cell1.branch(0).comp(0)\n", - "\n", - "connect(pre_comp, post_comp)\n", - "\n", - "# inserting channels in the membrane\n", - "with net.cell(0) as cell0:\n", - " cell0.insert(Na())\n", - " cell0.insert(K())\n", - "\n", - "with net.cell(1) as cell1:\n", - " cell0.insert(Leak())\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "7177950f-d702-4d8d-b69e-bfb06677037f", - "metadata": {}, - "source": [ - "First, we import the relevant libraries:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "deb594f4", - "metadata": {}, - "outputs": [], - "source": [ - "from jax import config\n", - "config.update(\"jax_enable_x64\", True)\n", - "config.update(\"jax_platform_name\", \"cpu\")\n", - "\n", - "import jaxley as jx\n", - "from jaxley.channels import Na, K, Leak\n", - "from jaxley.synapses import IonotropicSynapse\n", - "from jaxley.connect import connect\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np" - ] - }, - { - "cell_type": "markdown", - "id": "415b6741", - "metadata": {}, - "source": [ - "# Modules\n", - "\n", - "In Jaxley, we heavily rely on the concept of Modules to build biophyiscal models of neural systems at various scales.\n", - "Jaxley implements 4 Module types:\n", - "- `Compartment`\n", - "- `Branch`\n", - "- `Cell`\n", - "- `Network`\n", - "\n", - "Modules can be connected together to build increasingly detailed and complex models. `Compartment` -> `Branch` -> `Cell` -> `Network`." - ] - }, - { - "cell_type": "markdown", - "id": "7480ea5e", - "metadata": {}, - "source": [ - "`Compartment`s are the atoms of biophysical models in Jaxley. All mechanisms and synaptic connections live on the level of `Compartment`s and can already be simulated using `jx.integrate` on their own. Everything you do in Jaxley starts with a `Compartment`." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "535057cf", - "metadata": {}, - "outputs": [], - "source": [ - "comp = jx.Compartment() # single compartment model" - ] - }, - { - "cell_type": "markdown", - "id": "469e25a3", - "metadata": {}, - "source": [ - "Mutliple `Compartments` can be connected together to form longer, linear segments / cables, we call `Branch`es and are essentially equivalent to sections in `NEURON`." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "0476a173", - "metadata": {}, - "outputs": [], - "source": [ - "nseg = 4\n", - "branch = jx.Branch([comp]*nseg)" - ] - }, - { - "cell_type": "markdown", - "id": "63c35e7b", - "metadata": {}, - "source": [ - "In order to construct cell morphologies in Jaxley, multiple `Branches` can to be connected together using the `Cell` primitive." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "6f4c5202", - "metadata": {}, - "outputs": [], - "source": [ - "parents = [-1,0,0] # soma = -1, since it has no parents and both dendrites connect to the soma (0). \n", - "cell = jx.Cell([branch]*len(parents), parents)" - ] - }, - { - "cell_type": "markdown", - "id": "3830cf70", - "metadata": {}, - "source": [ - "Finally, several `Cell`s can be grouped together to form a `Network`, which can than be connected together using `Synpase`s." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "3399a716", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(2, 6, 24)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ncells = 2\n", - "net = jx.Network([cell]*ncells)\n", - "\n", - "net.shape # shows you the num_cells, num_branches, num_comps" - ] - }, - { - "cell_type": "markdown", - "id": "30eb8fd5", - "metadata": {}, - "source": [ - "`Module`s carry around the information about their current state and parameters in two Dataframes called `nodes` and `edges`.\n", - "`nodes` contains all the information that we associate with compartments in the model (each row corresponds to a compartment) and `edges` all the information relevant to synapses.\n", - "\n", - "This means that you can easily keep track of the current state of your `Module` and how it changes at all times." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "701364b0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
local_cell_indexlocal_branch_indexlocal_comp_indexlengthradiusaxial_resistivitycapacitancevglobal_cell_indexglobal_branch_indexglobal_comp_indexcontrolled_by_param
000010.01.05000.01.0-70.00000
100110.01.05000.01.0-70.00010
200210.01.05000.01.0-70.00020
300310.01.05000.01.0-70.00030
401010.01.05000.01.0-70.00140
501110.01.05000.01.0-70.00150
601210.01.05000.01.0-70.00160
701310.01.05000.01.0-70.00170
802010.01.05000.01.0-70.00280
902110.01.05000.01.0-70.00290
1002210.01.05000.01.0-70.002100
1102310.01.05000.01.0-70.002110
1210010.01.05000.01.0-70.013120
1310110.01.05000.01.0-70.013130
1410210.01.05000.01.0-70.013140
1510310.01.05000.01.0-70.013150
1611010.01.05000.01.0-70.014160
1711110.01.05000.01.0-70.014170
1811210.01.05000.01.0-70.014180
1911310.01.05000.01.0-70.014190
2012010.01.05000.01.0-70.015200
2112110.01.05000.01.0-70.015210
2212210.01.05000.01.0-70.015220
2312310.01.05000.01.0-70.015230
\n", - "
" - ], - "text/plain": [ - " local_cell_index local_branch_index local_comp_index length radius \\\n", - "0 0 0 0 10.0 1.0 \n", - "1 0 0 1 10.0 1.0 \n", - "2 0 0 2 10.0 1.0 \n", - "3 0 0 3 10.0 1.0 \n", - "4 0 1 0 10.0 1.0 \n", - "5 0 1 1 10.0 1.0 \n", - "6 0 1 2 10.0 1.0 \n", - "7 0 1 3 10.0 1.0 \n", - "8 0 2 0 10.0 1.0 \n", - "9 0 2 1 10.0 1.0 \n", - "10 0 2 2 10.0 1.0 \n", - "11 0 2 3 10.0 1.0 \n", - "12 1 0 0 10.0 1.0 \n", - "13 1 0 1 10.0 1.0 \n", - "14 1 0 2 10.0 1.0 \n", - "15 1 0 3 10.0 1.0 \n", - "16 1 1 0 10.0 1.0 \n", - "17 1 1 1 10.0 1.0 \n", - "18 1 1 2 10.0 1.0 \n", - "19 1 1 3 10.0 1.0 \n", - "20 1 2 0 10.0 1.0 \n", - "21 1 2 1 10.0 1.0 \n", - "22 1 2 2 10.0 1.0 \n", - "23 1 2 3 10.0 1.0 \n", - "\n", - " axial_resistivity capacitance v global_cell_index \\\n", - "0 5000.0 1.0 -70.0 0 \n", - "1 5000.0 1.0 -70.0 0 \n", - "2 5000.0 1.0 -70.0 0 \n", - "3 5000.0 1.0 -70.0 0 \n", - "4 5000.0 1.0 -70.0 0 \n", - "5 5000.0 1.0 -70.0 0 \n", - "6 5000.0 1.0 -70.0 0 \n", - "7 5000.0 1.0 -70.0 0 \n", - "8 5000.0 1.0 -70.0 0 \n", - "9 5000.0 1.0 -70.0 0 \n", - "10 5000.0 1.0 -70.0 0 \n", - "11 5000.0 1.0 -70.0 0 \n", - "12 5000.0 1.0 -70.0 1 \n", - "13 5000.0 1.0 -70.0 1 \n", - "14 5000.0 1.0 -70.0 1 \n", - "15 5000.0 1.0 -70.0 1 \n", - "16 5000.0 1.0 -70.0 1 \n", - "17 5000.0 1.0 -70.0 1 \n", - "18 5000.0 1.0 -70.0 1 \n", - "19 5000.0 1.0 -70.0 1 \n", - "20 5000.0 1.0 -70.0 1 \n", - "21 5000.0 1.0 -70.0 1 \n", - "22 5000.0 1.0 -70.0 1 \n", - "23 5000.0 1.0 -70.0 1 \n", - "\n", - " global_branch_index global_comp_index controlled_by_param \n", - "0 0 0 0 \n", - "1 0 1 0 \n", - "2 0 2 0 \n", - "3 0 3 0 \n", - "4 1 4 0 \n", - "5 1 5 0 \n", - "6 1 6 0 \n", - "7 1 7 0 \n", - "8 2 8 0 \n", - "9 2 9 0 \n", - "10 2 10 0 \n", - "11 2 11 0 \n", - "12 3 12 0 \n", - "13 3 13 0 \n", - "14 3 14 0 \n", - "15 3 15 0 \n", - "16 4 16 0 \n", - "17 4 17 0 \n", - "18 4 18 0 \n", - "19 4 19 0 \n", - "20 5 20 0 \n", - "21 5 21 0 \n", - "22 5 22 0 \n", - "23 5 23 0 " - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "net.nodes" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "eeaaefd1", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
global_edge_indexglobal_pre_comp_indexglobal_post_comp_indexpre_locspost_locstypetype_ind
\n", - "
" - ], - "text/plain": [ - "Empty DataFrame\n", - "Columns: [global_edge_index, global_pre_comp_index, global_post_comp_index, pre_locs, post_locs, type, type_ind]\n", - "Index: []" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "net.edges.head() # this is currently empty since we have not made any connections yet" - ] - }, - { - "cell_type": "markdown", - "id": "2904c0c5", - "metadata": {}, - "source": [ - "# Views" - ] - }, - { - "cell_type": "markdown", - "id": "fd0d4fa3", - "metadata": {}, - "source": [ - "Since these models can become arbitrarily complex, Jaxley utilizes so called `View`s to make working with `Modules` easy and intuitive. \n", - "\n", - "The simplest way to navigate Modules is by navigating them via the hierachy that we introduces above. Let's see how this works for a `Nework`." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "d6abcdec", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "View with 0 different channels. Use `.nodes` for details." - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "net.cell(0) # View of the 0th cell of the network\n", - "net.cell(0).branch(0) # View of the 1st branch of the 0th cell of the network\n", - "net.cell(0).branch(1).comp(0) # View of the 0th comp of the 1st branch of the 0th cell of the network\n", - "\n", - "# several types of indices are supported (lists, ranges, ...)\n", - "net.cell([0,1]).branch(\"all\").comp(0) # View of all 0th comps of all branches of cell 0 and 1\n", - "\n", - "branch.loc(0.1) # equivalent to `NEURON`s `loc`. Assumes branches are continous from 0-1.\n", - "\n", - "net[0,0,0] # Modules/Views can also be lazily indexed\n", - "\n", - "cell0 = net.cell(0) # views can be assigned to variables and only track the parts of the Module they belong to\n", - "cell0.branch(1).comp(0) # Views can be continuely indexed" - ] - }, - { - "cell_type": "markdown", - "id": "7283d2d4", - "metadata": {}, - "source": [ - "_In case you need even more flexibility in how you select parts of a Module, Jaxley provides a `select` method, to give full control over the exact parts of the `nodes` and `edges` that are part of a `View`. On examples of how this can be used, see [](MISSING)._" - ] - }, - { - "cell_type": "markdown", - "id": "4f8922c4", - "metadata": {}, - "source": [ - "Views behave very similarly to `Module`s, i.e. the `cell0` (the 0th cell of the network) from the example above handles like the `cell` we instantiated earlier, which then became cell 0 of the network. As such `cell0` also has a `nodes` attribute, which keeps track of it's part of the network." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "d5502655", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
local_cell_indexlocal_branch_indexlocal_comp_indexlengthradiusaxial_resistivitycapacitancevglobal_cell_indexglobal_branch_indexglobal_comp_indexcontrolled_by_param
000010.01.05000.01.0-70.00000
100110.01.05000.01.0-70.00010
200210.01.05000.01.0-70.00020
300310.01.05000.01.0-70.00030
401010.01.05000.01.0-70.00140
501110.01.05000.01.0-70.00150
601210.01.05000.01.0-70.00160
701310.01.05000.01.0-70.00170
802010.01.05000.01.0-70.00280
902110.01.05000.01.0-70.00290
1002210.01.05000.01.0-70.002100
1102310.01.05000.01.0-70.002110
\n", - "
" - ], - "text/plain": [ - " local_cell_index local_branch_index local_comp_index length radius \\\n", - "0 0 0 0 10.0 1.0 \n", - "1 0 0 1 10.0 1.0 \n", - "2 0 0 2 10.0 1.0 \n", - "3 0 0 3 10.0 1.0 \n", - "4 0 1 0 10.0 1.0 \n", - "5 0 1 1 10.0 1.0 \n", - "6 0 1 2 10.0 1.0 \n", - "7 0 1 3 10.0 1.0 \n", - "8 0 2 0 10.0 1.0 \n", - "9 0 2 1 10.0 1.0 \n", - "10 0 2 2 10.0 1.0 \n", - "11 0 2 3 10.0 1.0 \n", - "\n", - " axial_resistivity capacitance v global_cell_index \\\n", - "0 5000.0 1.0 -70.0 0 \n", - "1 5000.0 1.0 -70.0 0 \n", - "2 5000.0 1.0 -70.0 0 \n", - "3 5000.0 1.0 -70.0 0 \n", - "4 5000.0 1.0 -70.0 0 \n", - "5 5000.0 1.0 -70.0 0 \n", - "6 5000.0 1.0 -70.0 0 \n", - "7 5000.0 1.0 -70.0 0 \n", - "8 5000.0 1.0 -70.0 0 \n", - "9 5000.0 1.0 -70.0 0 \n", - "10 5000.0 1.0 -70.0 0 \n", - "11 5000.0 1.0 -70.0 0 \n", - "\n", - " global_branch_index global_comp_index controlled_by_param \n", - "0 0 0 0 \n", - "1 0 1 0 \n", - "2 0 2 0 \n", - "3 0 3 0 \n", - "4 1 4 0 \n", - "5 1 5 0 \n", - "6 1 6 0 \n", - "7 1 7 0 \n", - "8 2 8 0 \n", - "9 2 9 0 \n", - "10 2 10 0 \n", - "11 2 11 0 " - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cell0.nodes" - ] - }, - { - "cell_type": "markdown", - "id": "76e72d1c", - "metadata": {}, - "source": [ - "Assigning `View`s to a variable makes it easuer to reuse parts of a `Module` later or to highlight them. However, this can become messy and we might need access to such a `View` more readily. For this purpose Jaxley implements so called groups that can be used to assign any `View` of a `Module` to an attribute, i.e. the soma." - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "41d38b22", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Groups {'somas': array([ 0, 12])}\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
local_cell_indexlocal_branch_indexlocal_comp_indexlengthradiusaxial_resistivitycapacitancevglobal_cell_indexglobal_branch_indexglobal_comp_indexcontrolled_by_param
000010.01.05000.01.0-70.00000
1210010.01.05000.01.0-70.013120
\n", - "
" - ], - "text/plain": [ - " local_cell_index local_branch_index local_comp_index length radius \\\n", - "0 0 0 0 10.0 1.0 \n", - "12 1 0 0 10.0 1.0 \n", - "\n", - " axial_resistivity capacitance v global_cell_index \\\n", - "0 5000.0 1.0 -70.0 0 \n", - "12 5000.0 1.0 -70.0 1 \n", - "\n", - " global_branch_index global_comp_index controlled_by_param \n", - "0 0 0 0 \n", - "12 3 12 0 " - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "net.cell(\"all\").branch(0).comp(0).add_to_group(\"somas\")\n", - "print(\"Groups\", net.groups) # list the indices of the nodes dataframe that are part of the group\n", - "\n", - "somas = net.somas # returns a View with only a subset of nodes\n", - "somas.nodes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "edac8921", - "metadata": {}, - "outputs": [], - "source": [ - "# connecting two cells using a Synapse\n", - "pre_comp = cell0.branch(1).comp(0)\n", - "post_comp = net.cell1.branch(0).comp(0)\n", - "\n", - "connect(pre_comp, post_comp)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "jaxley", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.1" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 439d59aa491bdad3791bfcc57af0b5a17cbb75c4 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 8 Nov 2024 13:14:34 +0100 Subject: [PATCH 8/9] enh: allow to remove multiple clamp --- jaxley/modules/base.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index cc7e393fc..0cdf5d5c0 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1660,23 +1660,27 @@ def delete_stimuli(self): """Removes all stimuli from the module.""" self.delete_clamps("i") - def delete_clamps(self, state_name: str): + def delete_clamps(self, state_name: Optional[str] = None): """Removes all clamps of the given state from the module.""" - if state_name in self.externals: - keep_inds = ~np.isin( - self.base.external_inds[state_name], self._nodes_in_view - ) - base_exts = self.base.externals - base_exts_inds = self.base.external_inds - if np.all(~keep_inds): - base_exts.pop(state_name, None) - base_exts_inds.pop(state_name, None) + all_externals = list(self.externals.keys()) + all_externals.remove("i") + state_names = all_externals if state_name is None else [state_name] + for state_name in state_names: + if state_name in self.externals: + keep_inds = ~np.isin( + self.base.external_inds[state_name], self._nodes_in_view + ) + base_exts = self.base.externals + base_exts_inds = self.base.external_inds + if np.all(~keep_inds): + base_exts.pop(state_name, None) + base_exts_inds.pop(state_name, None) + else: + base_exts[state_name] = base_exts[state_name][keep_inds] + base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds] + self._update_view() else: - base_exts[state_name] = base_exts[state_name][keep_inds] - base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds] - self._update_view() - else: - pass # does not have to be deleted if not in externals + pass # does not have to be deleted if not in externals def insert(self, channel: Channel): """Insert a channel into the module. From 1339917efac5ef210b0381dbfd288c1662ca2432 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 8 Nov 2024 13:27:52 +0100 Subject: [PATCH 9/9] fix: make tests pass --- jaxley/modules/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 0cdf5d5c0..92a73aed7 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1663,7 +1663,8 @@ def delete_stimuli(self): def delete_clamps(self, state_name: Optional[str] = None): """Removes all clamps of the given state from the module.""" all_externals = list(self.externals.keys()) - all_externals.remove("i") + if "i" in all_externals: + all_externals.remove("i") state_names = all_externals if state_name is None else [state_name] for state_name in state_names: if state_name in self.externals: