diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 901106f9..d9cdebe8 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -143,11 +143,15 @@ def make(self, key): ) # get the row indices of the port # Create SI recording extractor object - si_extractor: si.extractors.neoextractors = ( - si.extractors.extractorlist.recording_extractor_full_dict[ - acq_software.replace(" ", "").lower() - ] - ) # data extractor object + from spikeinterface.extractors.extractor_classes import ( + recording_extractor_full_dict, + ) + + acq_key = acq_software.replace(" ", "").lower() + try: + si_extractor = recording_extractor_full_dict[acq_key] + except KeyError: + raise ValueError(f"Unsupported acquisition software: {acq_software}") files, file_times = ( ephys.EphysRawFile @@ -155,53 +159,59 @@ def make(self, key): & f"file_time BETWEEN '{key['start_time']}' AND '{key['end_time']}'" ).fetch("file_path", "file_time", order_by="file_time") - si_recording = None + # Detect stream name from the first file + first_file_path = find_full_path(ephys.get_ephys_root_data_dir(), files[0]) + available_streams = si_extractor.get_streams(first_file_path)[0] + amplifier_streams = [s for s in available_streams if "amplifier" in s.lower()] + + if not amplifier_streams: + raise ValueError( + f"No amplifier stream found in {first_file_path}. " + f"Available streams: {available_streams}" + ) + + stream_name = amplifier_streams[0] + # Read data. Concatenate if multiple files are found. + si_recording = None for file_path in ( find_full_path(ephys.get_ephys_root_data_dir(), f) for f in files ): if not si_recording: - stream_name = [ - s - for s in si_extractor.get_streams(file_path)[0] - if "amplifier" in s - ][0] - si_recording: si.BaseRecording = si_extractor( - file_path, stream_name=stream_name - ) + si_recording = si_extractor(file_path, stream_name=stream_name) else: - si_recording: si.BaseRecording = si.concatenate_recordings( + si_recording = si.concatenate_recordings( [ si_recording, si_extractor(file_path, stream_name=stream_name), ] ) + # Restrict to channels from the target port si_recording = si_recording.channel_slice( si_recording.channel_ids[port_indices] - ) # select only the port data + ) - # Create SI probe object + # Apply probe geometry si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values) si_recording.set_probe(probe=si_probe, in_place=True) - # Account for additional electrodes being removed + # Remove unused electrodes if unused_electrodes: chn_ids_to_remove = [ f"{probe_info['port_id']}-{electrodes_df.channel_idx.iloc[elec]:03d}" for elec in unused_electrodes ] - else: - chn_ids_to_remove = [] - - si_recording = si_recording.remove_channels( - remove_channel_ids=chn_ids_to_remove - ) + si_recording = si_recording.remove_channels( + remove_channel_ids=chn_ids_to_remove + ) - # Run preprocessing and save results to output folder + # Preprocess si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"]) si_recording = si_preproc_func(si_recording) + + # Save to pickle si_recording.dump_to_pickle(file_path=recording_file, relative_to=output_dir) self.insert1(