Simpler edge indexing#487
Conversation
|
michaeldeistler
left a comment
There was a problem hiding this comment.
Nice!
I am a bit worried about memory consumption though. If every jaxedge has every state, then there will a lot of NaNs each of which contains memory. I think that, in the worst case, memory increases by N**2, where is the number of types of synapses.
Can we have a zoom about this? E.g. next week?
|
True! I did not think about this. I have two thoughts though. Why not do this for nodes on a per mechanism basis as well. And is there not a more straight forward way to avoid nans while still using global indexes? |
|
I agree that we should also do it like this for nodes. Let's zoom next week! |
f3ec50a to
aa4ae5f
Compare
| def dtype_aware_concat(dfs): | ||
| concat_df = pd.concat(dfs, ignore_index=True) | ||
| # replace nans with Nones | ||
| # this correctly casts float(None) -> NaN, bool(None) -> NaN, etc. | ||
| concat_df[concat_df.isna()] = None | ||
| for col in concat_df.columns[concat_df.dtypes == "object"]: | ||
| for df in dfs: | ||
| if col in df.columns: | ||
| concat_df[col] = concat_df[col].astype(df[col].dtype) | ||
| break # first match is sufficient | ||
| return concat_df |
There was a problem hiding this comment.
fixes a bug, where dtype of bool columns was changed upon concatenation in:
comp1 = jx.Compartment()
comp2 = jx.Compartment()
comp2.insert(HH())
branch = jx.Branch([comp1, comp2)] #-> branch.nodes["HH"].dtype was `object` in this case| self.synapse_param_names = [] | ||
| self.synapse_state_names = [] | ||
| self.synapse_names = [] |
|
|
||
| def __dir__(self): | ||
| base_dir = object.__dir__(self) | ||
| return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys())) |
7b4e771 to
b58c3dd
Compare
|
Main changes:
In general I think the code is a bit cleaner and more consolidated now (-150 line diff). Will also run regression tests. Maybe its even faster as well. I am not super happy about the Lemme know what you think or if you have any questions. The only thing I touched in terms of channels / synapse rewrite is the params/states rename. Does this warrant a seperate branch? |
michaeldeistler
left a comment
There was a problem hiding this comment.
Thanks a lot @jnsbck!
I am already off for the holidays so will not be able to do a detailed review, but:
- I think all your high-level explanations sound good!
- Yes, let's make a new
v1.0branch and merge it there because it breaks all channel models.
And yes, please run regression tests and ensure that all is good! Good to go then! Thanks!
a0183be to
e17f925
Compare
|
/test_regression |
Regression Test Results✅ Process completed |
e17f925 to
62351b8
Compare
…er unecessary things
|
/test_regression |
Regression Test Results❌ Process completed |
|
I have done some refactoring and am now much happier with how the code looks. Locally tests pass, but since this is not a pull request to main (but v1.0), tests are not being run here. I also started another regression tests (see above), which I hope will be as quick as prev, i.e. faster than whats in Would still be great if you could have a more thorough look before I merge it, when your back in the office. |
|
...more refactoring and consolidation imo. Somewhat forward thinking, now /test_regression |
Regression Test Results❌ Process completed |
| # Update states of the channels. | ||
| channel_nodes = self.base.nodes | ||
| states = self.base._get_states_from_nodes_and_edges() | ||
| self.to_jax() # Create `.jax` from `.nodes` and `.edges`. |
There was a problem hiding this comment.
why does this run here and not in jx.integrate (where it was before)?
import jaxley as jx
from jaxley.channels.pospischil import Leak, Na, K
comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=3)
branch.comp([0, 1]).insert(Na())
branch.comp([1, 2]).insert(K())
branch.to_jax()
print(branch.jax["nodes"]["params"]["vt"]) |
Currently synapse and channel parameters / states are handled differently. While channel params are referred to with global indices, synapse params are referenced on a per synapse basis. This leads to very different implementations of synapse and channel updates. This PR makes the synapse indexing global, which simplifies several aspects of the code and allows for more function reuse.
This could potentially be simplified even further, such that channels and synapses can be handled through mostly the same functions.