From 32578f6a79cfb8f6fadfe10ee3893677e81ddfb4 Mon Sep 17 00:00:00 2001 From: ttngu207 Date: Tue, 20 Jan 2026 14:46:54 -0600 Subject: [PATCH] Fix sparse template handling in kilosort reader for SI-exported Phy data SpikeInterface's export_to_phy exports sparse templates by default (since v0.101.0), with templates.npy shape (n_templates, n_samples, max_sparse_channels) instead of the full (n_templates, n_samples, n_channels) format from native Kilosort. The companion file templates_ind.npy maps (template_idx, sparse_channel_idx) to actual channel indices, with -1 indicating padding. This fix updates get_best_channel() and extract_spike_depths() to: - Check if templates_ind exists (indicates SI-exported sparse format) - Use templates_ind to map sparse indices to actual channel indices - Fall back to original behavior for native Kilosort (dense) format Without this fix, spike_sites and best_channel values are incorrect when reading SI-exported Phy curations, as argmax returns indices into the sparse representation rather than actual channel indices. Related: dj-sciops/nei_nienborg#111 (Issue #2) Co-Authored-By: Claude Opus 4.5 --- element_array_ephys/readers/kilosort.py | 46 +++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/element_array_ephys/readers/kilosort.py b/element_array_ephys/readers/kilosort.py index 5e51519e..aa6c1ac7 100644 --- a/element_array_ephys/readers/kilosort.py +++ b/element_array_ephys/readers/kilosort.py @@ -151,7 +151,28 @@ def get_best_channel(self, unit): ] channel_templates = self.data["templates"][template_idx, :, :] max_channel_idx = np.abs(channel_templates).max(axis=0).argmax() - max_channel = self.data["channel_map"][max_channel_idx] + + # Handle sparse templates (SpikeInterface export_to_phy format) + # templates_ind maps (template_idx, sparse_channel_idx) -> actual channel_idx + # Value of -1 indicates padding (no channel at that position) + if "templates_ind" in self.data: + # Use templates_ind to get actual channel index for this template + actual_channel_idx = self.data["templates_ind"][template_idx, max_channel_idx] + if actual_channel_idx >= 0: + max_channel = self.data["channel_map"][actual_channel_idx] + else: + # Fallback if the max is in a padded position (shouldn't happen normally) + log.warning( + f"Unit {unit}: max amplitude in padded channel position, " + "falling back to first valid channel" + ) + valid_channels = self.data["templates_ind"][template_idx] + valid_idx = valid_channels[valid_channels >= 0][0] + max_channel = self.data["channel_map"][valid_idx] + max_channel_idx = valid_idx + else: + # Dense templates (native Kilosort format) - use channel_map directly + max_channel = self.data["channel_map"][max_channel_idx] return max_channel, max_channel_idx @@ -179,9 +200,28 @@ def extract_spike_depths(self): self._data["spike_depths"] = None # ---- extract spike sites ---- + # For each template, find the channel with maximum amplitude max_site_ind = np.argmax(np.abs(self.data["templates"]).max(axis=1), axis=1) - spike_site_ind = max_site_ind[self.data["spike_templates"]] - self._data["spike_sites"] = self.data["channel_map"][spike_site_ind] + + # Handle sparse templates (SpikeInterface export_to_phy format) + # templates_ind maps (template_idx, sparse_channel_idx) -> actual channel_idx + if "templates_ind" in self.data: + # Map sparse indices to actual channel indices using templates_ind + # templates_ind shape: (n_templates, max_sparse_channels) + templates_ind = self.data["templates_ind"] + # For each template, get the actual channel index at the max position + actual_channel_indices = templates_ind[ + np.arange(len(max_site_ind)), max_site_ind + ] + # Map spike templates to their actual channel indices + spike_actual_channel_ind = actual_channel_indices[ + self.data["spike_templates"] + ] + self._data["spike_sites"] = self.data["channel_map"][spike_actual_channel_ind] + else: + # Dense templates (native Kilosort format) - use channel_map directly + spike_site_ind = max_site_ind[self.data["spike_templates"]] + self._data["spike_sites"] = self.data["channel_map"][spike_site_ind] def extract_clustering_info(cluster_output_dir):