Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 35 additions & 25 deletions element_array_ephys/spike_sorting/si_spike_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,65 +143,75 @@ def make(self, key):
) # get the row indices of the port

# Create SI recording extractor object
si_extractor: si.extractors.neoextractors = (
si.extractors.extractorlist.recording_extractor_full_dict[
acq_software.replace(" ", "").lower()
]
) # data extractor object
from spikeinterface.extractors.extractor_classes import (
recording_extractor_full_dict,
)

acq_key = acq_software.replace(" ", "").lower()
try:
si_extractor = recording_extractor_full_dict[acq_key]
except KeyError:
raise ValueError(f"Unsupported acquisition software: {acq_software}")

files, file_times = (
ephys.EphysRawFile
& key
& f"file_time BETWEEN '{key['start_time']}' AND '{key['end_time']}'"
).fetch("file_path", "file_time", order_by="file_time")

si_recording = None
# Detect stream name from the first file
first_file_path = find_full_path(ephys.get_ephys_root_data_dir(), files[0])
available_streams = si_extractor.get_streams(first_file_path)[0]
amplifier_streams = [s for s in available_streams if "amplifier" in s.lower()]

if not amplifier_streams:
raise ValueError(
f"No amplifier stream found in {first_file_path}. "
f"Available streams: {available_streams}"
)

stream_name = amplifier_streams[0]

# Read data. Concatenate if multiple files are found.
si_recording = None
for file_path in (
find_full_path(ephys.get_ephys_root_data_dir(), f) for f in files
):
if not si_recording:
stream_name = [
s
for s in si_extractor.get_streams(file_path)[0]
if "amplifier" in s
][0]
si_recording: si.BaseRecording = si_extractor(
file_path, stream_name=stream_name
)
si_recording = si_extractor(file_path, stream_name=stream_name)
else:
si_recording: si.BaseRecording = si.concatenate_recordings(
si_recording = si.concatenate_recordings(
[
si_recording,
si_extractor(file_path, stream_name=stream_name),
]
)

# Restrict to channels from the target port
si_recording = si_recording.channel_slice(
si_recording.channel_ids[port_indices]
) # select only the port data
)

# Create SI probe object
# Apply probe geometry
si_probe = readers.probe_geometry.to_probeinterface(electrodes_df)
si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values)
si_recording.set_probe(probe=si_probe, in_place=True)

# Account for additional electrodes being removed
# Remove unused electrodes
if unused_electrodes:
chn_ids_to_remove = [
f"{probe_info['port_id']}-{electrodes_df.channel_idx.iloc[elec]:03d}"
for elec in unused_electrodes
]
else:
chn_ids_to_remove = []

si_recording = si_recording.remove_channels(
remove_channel_ids=chn_ids_to_remove
)
si_recording = si_recording.remove_channels(
remove_channel_ids=chn_ids_to_remove
)

# Run preprocessing and save results to output folder
# Preprocess
si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"])
si_recording = si_preproc_func(si_recording)

# Save to pickle
si_recording.dump_to_pickle(file_path=recording_file, relative_to=output_dir)

self.insert1(
Expand Down