Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions element_array_ephys/readers/kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,28 @@ def get_best_channel(self, unit):
]
channel_templates = self.data["templates"][template_idx, :, :]
max_channel_idx = np.abs(channel_templates).max(axis=0).argmax()
max_channel = self.data["channel_map"][max_channel_idx]

# Handle sparse templates (SpikeInterface export_to_phy format)
# templates_ind maps (template_idx, sparse_channel_idx) -> actual channel_idx
# Value of -1 indicates padding (no channel at that position)
if "templates_ind" in self.data:
# Use templates_ind to get actual channel index for this template
actual_channel_idx = self.data["templates_ind"][template_idx, max_channel_idx]
if actual_channel_idx >= 0:
max_channel = self.data["channel_map"][actual_channel_idx]
else:
# Fallback if the max is in a padded position (shouldn't happen normally)
log.warning(
f"Unit {unit}: max amplitude in padded channel position, "
"falling back to first valid channel"
)
valid_channels = self.data["templates_ind"][template_idx]
valid_idx = valid_channels[valid_channels >= 0][0]
max_channel = self.data["channel_map"][valid_idx]
max_channel_idx = valid_idx
else:
# Dense templates (native Kilosort format) - use channel_map directly
max_channel = self.data["channel_map"][max_channel_idx]

return max_channel, max_channel_idx

Expand Down Expand Up @@ -179,9 +200,28 @@ def extract_spike_depths(self):
self._data["spike_depths"] = None

# ---- extract spike sites ----
# For each template, find the channel with maximum amplitude
max_site_ind = np.argmax(np.abs(self.data["templates"]).max(axis=1), axis=1)
spike_site_ind = max_site_ind[self.data["spike_templates"]]
self._data["spike_sites"] = self.data["channel_map"][spike_site_ind]

# Handle sparse templates (SpikeInterface export_to_phy format)
# templates_ind maps (template_idx, sparse_channel_idx) -> actual channel_idx
if "templates_ind" in self.data:
# Map sparse indices to actual channel indices using templates_ind
# templates_ind shape: (n_templates, max_sparse_channels)
templates_ind = self.data["templates_ind"]
# For each template, get the actual channel index at the max position
actual_channel_indices = templates_ind[
np.arange(len(max_site_ind)), max_site_ind
]
# Map spike templates to their actual channel indices
spike_actual_channel_ind = actual_channel_indices[
self.data["spike_templates"]
]
self._data["spike_sites"] = self.data["channel_map"][spike_actual_channel_ind]
else:
# Dense templates (native Kilosort format) - use channel_map directly
spike_site_ind = max_site_ind[self.data["spike_templates"]]
self._data["spike_sites"] = self.data["channel_map"][spike_site_ind]


def extract_clustering_info(cluster_output_dir):
Expand Down
Loading