From 6efb231f8e1c2574a513c8d71d8c522d5f6c8dcc Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Fri, 25 Jul 2025 11:53:01 +0100 Subject: [PATCH 1/2] fix: `_build_si_recording_object` for new version of SI where extractorlist no longer exists --- .../spike_sorting/si_spike_sorting.py | 68 ++++++++++--------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 901106f9..aa5d7d37 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -143,11 +143,15 @@ def make(self, key): ) # get the row indices of the port # Create SI recording extractor object - si_extractor: si.extractors.neoextractors = ( - si.extractors.extractorlist.recording_extractor_full_dict[ - acq_software.replace(" ", "").lower() - ] - ) # data extractor object + from spikeinterface.extractors.extractor_classes import ( + recording_extractor_full_dict, + ) + + acq_key = acq_software.replace(" ", "").lower() + try: + si_extractor = recording_extractor_full_dict[acq_key] + except KeyError: + raise ValueError(f"Unsupported acquisition software: {acq_software}") files, file_times = ( ephys.EphysRawFile @@ -155,53 +159,55 @@ def make(self, key): & f"file_time BETWEEN '{key['start_time']}' AND '{key['end_time']}'" ).fetch("file_path", "file_time", order_by="file_time") - si_recording = None + # Detect stream name from the first file + first_file_path = find_full_path(ephys.get_ephys_root_data_dir(), files[0]) + available_streams = si_extractor.get_streams(first_file_path)[0] + amplifier_streams = [s for s in available_streams if "amplifier" in s.lower()] + + if not amplifier_streams: + raise ValueError( + f"No amplifier stream found in {first_file_path}. " + f"Available streams: {available_streams}" + ) + + stream_name = amplifier_streams[0] + # Read data. Concatenate if multiple files are found. + si_recording = None for file_path in ( find_full_path(ephys.get_ephys_root_data_dir(), f) for f in files ): if not si_recording: - stream_name = [ - s - for s in si_extractor.get_streams(file_path)[0] - if "amplifier" in s - ][0] - si_recording: si.BaseRecording = si_extractor( - file_path, stream_name=stream_name - ) + si_recording = si_extractor(file_path, stream_name=stream_name) else: - si_recording: si.BaseRecording = si.concatenate_recordings( - [ - si_recording, - si_extractor(file_path, stream_name=stream_name), - ] - ) - + si_recording = si.concatenate_recordings([ + si_recording, + si_extractor(file_path, stream_name=stream_name), + ]) + + # Restrict to channels from the target port si_recording = si_recording.channel_slice( si_recording.channel_ids[port_indices] - ) # select only the port data + ) - # Create SI probe object + # Apply probe geometry si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values) si_recording.set_probe(probe=si_probe, in_place=True) - # Account for additional electrodes being removed + # Remove unused electrodes if unused_electrodes: chn_ids_to_remove = [ f"{probe_info['port_id']}-{electrodes_df.channel_idx.iloc[elec]:03d}" for elec in unused_electrodes ] - else: - chn_ids_to_remove = [] + si_recording = si_recording.remove_channels(remove_channel_ids=chn_ids_to_remove) - si_recording = si_recording.remove_channels( - remove_channel_ids=chn_ids_to_remove - ) - - # Run preprocessing and save results to output folder + # Preprocess si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"]) si_recording = si_preproc_func(si_recording) + + # Save to pickle si_recording.dump_to_pickle(file_path=recording_file, relative_to=output_dir) self.insert1( From 9d61b777248c524073bf9195e595033e5f5bd555 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Fri, 25 Jul 2025 12:03:00 +0100 Subject: [PATCH 2/2] black formatting --- .../spike_sorting/si_spike_sorting.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index aa5d7d37..d9cdebe8 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -146,7 +146,7 @@ def make(self, key): from spikeinterface.extractors.extractor_classes import ( recording_extractor_full_dict, ) - + acq_key = acq_software.replace(" ", "").lower() try: si_extractor = recording_extractor_full_dict[acq_key] @@ -163,13 +163,13 @@ def make(self, key): first_file_path = find_full_path(ephys.get_ephys_root_data_dir(), files[0]) available_streams = si_extractor.get_streams(first_file_path)[0] amplifier_streams = [s for s in available_streams if "amplifier" in s.lower()] - + if not amplifier_streams: raise ValueError( f"No amplifier stream found in {first_file_path}. " f"Available streams: {available_streams}" ) - + stream_name = amplifier_streams[0] # Read data. Concatenate if multiple files are found. @@ -180,11 +180,13 @@ def make(self, key): if not si_recording: si_recording = si_extractor(file_path, stream_name=stream_name) else: - si_recording = si.concatenate_recordings([ - si_recording, - si_extractor(file_path, stream_name=stream_name), - ]) - + si_recording = si.concatenate_recordings( + [ + si_recording, + si_extractor(file_path, stream_name=stream_name), + ] + ) + # Restrict to channels from the target port si_recording = si_recording.channel_slice( si_recording.channel_ids[port_indices] @@ -201,7 +203,9 @@ def make(self, key): f"{probe_info['port_id']}-{electrodes_df.channel_idx.iloc[elec]:03d}" for elec in unused_electrodes ] - si_recording = si_recording.remove_channels(remove_channel_ids=chn_ids_to_remove) + si_recording = si_recording.remove_channels( + remove_channel_ids=chn_ids_to_remove + ) # Preprocess si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"])