diff --git a/element_array_ephys/ephys.py b/element_array_ephys/ephys.py index f43aa620..3293fd0e 100644 --- a/element_array_ephys/ephys.py +++ b/element_array_ephys/ephys.py @@ -5,6 +5,7 @@ import re from decimal import Decimal from typing import Any, Dict, List, Tuple + import datajoint as dj import numpy as np import pandas as pd @@ -112,6 +113,18 @@ def get_processed_root_data_dir() -> str: return get_ephys_root_data_dir()[0] +def get_sync_ephys_function(ephys_recording_key: dict): + """Retrieve the synchronization function for the specified ephys recording. + + Args: + ephys_recording_key (dict): A dictionary containing the primary key for the EphysRecording table. + + Returns: + A function that can be used to synchronize timestamps of the ephys recording (and spikes) to the primary clock. + """ + return _linking_module.get_sync_ephys_function(ephys_recording_key) + + # ----------------------------- Table declarations ---------------------- @@ -124,7 +137,7 @@ class AcquisitionSoftware(dj.Lookup): """ definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys - acq_software: varchar(24) + acq_software: varchar(24) """ contents = zip(["SpikeGLX", "Open Ephys"]) @@ -261,11 +274,11 @@ class EphysRecording(dj.Imported): definition = """ # Ephys recording from a probe insertion for a given session. - -> ProbeInsertion + -> ProbeInsertion --- -> probe.ElectrodeConfig -> AcquisitionSoftware - sampling_rate: float # (Hz) + sampling_rate: float # (Hz) recording_datetime: datetime # datetime of the recording from this probe recording_duration: float # (seconds) duration of the recording from this probe """ @@ -304,7 +317,12 @@ def make_fetch(self, key: dict) -> Tuple[pathlib.Path, str, List[str], List[str] supported_probe_types = list(probe.ProbeType.fetch("probe_type")) supported_acq_software = list(AcquisitionSoftware.fetch("acq_software")) - return session_dir, inserted_probe_serial_number, supported_probe_types, supported_acq_software + return ( + session_dir, + inserted_probe_serial_number, + supported_probe_types, + supported_acq_software, + ) def make_compute( self, @@ -313,7 +331,13 @@ def make_compute( inserted_probe_serial_number: str, supported_probe_types: List[str], supported_acq_software: List[str], - ) -> Tuple[Dict[str, Any], List[Dict[str, Any]], Dict[str, Any], List[Dict[str, Any]], List[Dict[str, Any]]]: + ) -> Tuple[ + Dict[str, Any], + List[Dict[str, Any]], + Dict[str, Any], + List[Dict[str, Any]], + List[Dict[str, Any]], + ]: """Populates table with electrophysiology recording information.""" # Search session dir and determine acquisition software @@ -495,7 +519,13 @@ def make_compute( f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented." ) - return econfig_entry, econfig_electrodes, ephys_recording_entry, ephys_file_entries, ephys_channel_entries + return ( + econfig_entry, + econfig_electrodes, + ephys_recording_entry, + ephys_file_entries, + ephys_channel_entries, + ) def make_insert( self, @@ -550,21 +580,28 @@ class Electrode(dj.Part): definition = """ -> master - -> probe.ElectrodeConfig.Electrode + -> probe.ElectrodeConfig.Electrode --- - lfp: longblob # (uV) recorded lfp at this electrode + lfp: longblob # (uV) recorded lfp at this electrode """ - # Only store LFP for every 9th channel, due to high channel density, - # close-by channels exhibit highly similar LFP - _skip_channel_counts = 9 - def make(self, key): """Populates the LFP tables.""" acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software") - electrode_keys, lfp = [], [] + # Get probe information to recording object + electrodes_df = ( + ( + EphysRecording.Channel + * probe.ElectrodeConfig.Electrode + * probe.ProbeType.Electrode + & key + ) + .fetch(format="frame") + .reset_index() + ) + electrode_keys, lfp = [], [] if acq_software == "SpikeGLX": spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) @@ -611,52 +648,58 @@ def make(self, key): "data" ][recorded_site] electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) + elif acq_software == "Trellis": + import spikeinterface as si + from spikeinterface import extractors + + si_extractor = ( + si.extractors.neoextractors.blackrock.BlackrockRecordingExtractor + ) - lfp_channel_ind = np.r_[ - len(oe_probe.lfp_meta["channels_indices"]) - - 1 : 0 : -self._skip_channel_counts + nsx2_relpaths = (EphysRecording.EphysFile & key).fetch("file_path") + nsx2_fullpaths = [ + find_full_path(get_ephys_root_data_dir(), f + ".ns2") + for f in nsx2_relpaths ] + si_recs = [] + for f in nsx2_fullpaths: + si_rec = si_extractor(file_path=f, stream_name="nsx2") + # find & remove non-ephys channels + non_ephys_chns = set(si_rec.channel_ids) - set( + (electrodes_df.channel_idx.values + 1).astype(str) + ) + si_recs.append(si_rec.remove_channels(list(non_ephys_chns))) + si_recording = si.concatenate_recordings(si_recs) - lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] # (sample x channel) - lfp = ( - lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind] - ).T # (channel x sample) - lfp_timestamps = oe_probe.lfp_timestamps + ephys_sync_func = get_sync_ephys_function(key) + synced_timestamps = ephys_sync_func(si_recording.get_times()) self.insert1( dict( key, - lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"], - lfp_time_stamps=lfp_timestamps, - lfp_mean=lfp.mean(axis=0), + lfp_sampling_rate=si_recording.sampling_frequency, + lfp_time_stamps=synced_timestamps, + lfp_mean=si_recording.get_traces(return_scaled=True).mean(axis=1), ) ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_keys.extend( - probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind - ) + # single insert in loop to mitigate potential memory issue + for chn_idx in range(si_recording.get_num_channels()): + chn_id = si_recording.channel_ids[chn_idx] + lfp = si_recording.get_traces( + channel_ids=[chn_id], return_scaled=True + ).flatten() + electrode_key = dict( + electrodes_df.query(f"channel_idx == {chn_idx}").iloc[0] + ) + self.Electrode.insert1( + {**key, **electrode_key, "lfp": lfp}, ignore_extra_fields=True + ) else: raise NotImplementedError( f"LFP extraction from acquisition software" f" of type {acq_software} is not yet implemented" ) - # single insert in loop to mitigate potential memory issue - for electrode_key, lfp_trace in zip(electrode_keys, lfp): - self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace}) - # ------------ Clustering -------------- @@ -681,6 +724,7 @@ class ClusteringMethod(dj.Lookup): ("kilosort2", "kilosort2 clustering method"), ("kilosort2.5", "kilosort2.5 clustering method"), ("kilosort3", "kilosort3 clustering method"), + ("kilosort4", "kilosort4 clustering method"), ] @@ -700,7 +744,7 @@ class ClusteringParamSet(dj.Lookup): # Parameter set to be used in a clustering procedure paramset_idx: smallint --- - -> ClusteringMethod + -> ClusteringMethod paramset_desc: varchar(128) param_set_hash: uuid unique index (param_set_hash) @@ -829,7 +873,7 @@ def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False): output_dir = ( processed_dir / session_dir.relative_to(root_dir) - / f'probe_{key["insertion_number"]}' + / f'insertion_{key["insertion_number"]}' / f'{method}_{key["paramset_idx"]}' ) @@ -886,7 +930,7 @@ class Clustering(dj.Imported): # Clustering Procedure -> ClusteringTask --- - clustering_time: datetime # time of generation of this set of clustering results + clustering_time: datetime # time of generation of this set of clustering results package_version='': varchar(16) """ @@ -1002,7 +1046,7 @@ class CuratedClustering(dj.Imported): definition = """ # Clustering results of the spike sorting step. - -> Clustering + -> Clustering """ class Unit(dj.Part): @@ -1019,7 +1063,7 @@ class Unit(dj.Part): spike_depths (longblob): Array of depths associated with each spike, relative to each spike. """ - definition = """ + definition = """ # Properties of a given unit from a round of clustering (and curation) -> master unit: int @@ -1029,14 +1073,21 @@ class Unit(dj.Part): spike_count: int # how many spikes in this recording for this unit spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe + spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe """ - def make(self, key): - """Automated population of Unit information.""" + class ManualLabel(dj.Part): + definition = """ + -> master.Unit + --- + manual_label: varchar(64) # manual label for a particular unit/cluster + """ + + def make_fetch(self, key, **kwargs): clustering_method, output_dir = ( ClusteringTask * ClusteringParamSet & key ).fetch1("clustering_method", "clustering_output_dir") + output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) # Get channel and electrode-site mapping @@ -1045,11 +1096,70 @@ def make(self, key): chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True) } - # Get sorter method and create output directory. sorter_name = clustering_method.replace(".", "_") si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" - if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs + si_export_exists = si_sorting_analyzer_dir.exists() + + if si_export_exists: # Read from spikeinterface outputs + # create channel2electrode_map + electrode_map: dict[int, dict] = { + elec["electrode"]: elec for elec in electrode_query.fetch(as_dict=True) + } + + ephys_sync_func = get_sync_ephys_function(key) + + return ( + ( + si_export_exists, + output_dir, + sorter_name, + channel2electrode_map, + electrode_map, + ephys_sync_func, + ), + ) + + else: + logger.warning( + "SI export not found: populating CuratedClustering with Kilosort output. May not sync correctly!" + ) + acq_software, sample_rate = (EphysRecording & key).fetch1( + "acq_software", "sampling_rate" + ) + + return ( + ( + si_export_exists, + output_dir, + sorter_name, + channel2electrode_map, + sample_rate, + ), + ) + + def make_compute(self, key, fetched): + # unpack passed values + if fetched[0]: + ( + si_export_exists, + output_dir, + sorter_name, + channel2electrode_map, + electrode_map, + ephys_sync_func, + ) = fetched + else: + ( + si_export_exists, + output_dir, + sorter_name, + channel2electrode_map, + sample_rate, + ) = fetched + + si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" + if si_export_exists: import spikeinterface as si from spikeinterface import sorters @@ -1061,8 +1171,7 @@ def make(self, key): logger.info( f"No units found in {sorting_file}. Skipping Unit ingestion..." ) - self.insert1(key) - return + return (None,) sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) si_sorting = sorting_analyzer.sorting @@ -1081,10 +1190,17 @@ def make(self, key): spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() # {unit: spike_count} - # update channel2electrode_map to match with probe's channel index + # create channel2electrode_map + # (electrode_map created in make_fetch) + # electrode_map: dict[int, dict] = { + # elec["electrode"]: elec for elec in electrode_query.fetch(as_dict=True) + # } channel2electrode_map = { - idx: channel2electrode_map[int(chn_idx)] - for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids) + chn_idx: electrode_map[int(elec_id)] + for chn_idx, elec_id in zip( + sorting_analyzer.get_probe().device_channel_indices, + sorting_analyzer.get_probe().contact_ids, + ) } # Get unit id to quality label mapping @@ -1107,6 +1223,8 @@ def make(self, key): ) ) + # ephys_sync_func = get_sync_ephys_function(key) # from make_fetch + units = [] for unit_idx, unit_id in enumerate(si_sorting.unit_ids): unit_id = int(unit_id) @@ -1122,6 +1240,7 @@ def make(self, key): spike_times = si_sorting.get_unit_spike_train( unit_id, return_times=True ) + spike_times = ephys_sync_func(spike_times) assert len(spike_times) == len(spike_sites) == len(spike_depths) @@ -1137,11 +1256,12 @@ def make(self, key): "spike_depths": spike_depths, } ) + else: # read from kilosort outputs kilosort_dataset = kilosort.Kilosort(output_dir) - acq_software, sample_rate = (EphysRecording & key).fetch1( - "acq_software", "sampling_rate" - ) + # acq_software, sample_rate = (EphysRecording & key).fetch1( + # "acq_software", "sampling_rate" + # ) sample_rate = kilosort_dataset.data["params"].get( "sample_rate", sample_rate @@ -1208,8 +1328,207 @@ def make(self, key): } ) - self.insert1(key) - self.Unit.insert(units, ignore_extra_fields=True) + return ((units,),) + + def make_insert(self, key, computed): + (units,) = computed + + if units is None: + self.insert1(key) + + # split insert to make transaction more manageable + else: + self.insert1(key) + for i in range(0, len(units), 64): + units_slice = units[i : i + 64] + self.Unit.insert(units_slice, ignore_extra_fields=True) + + # def make(self, key): + # """Automated population of Unit information.""" + # clustering_method, output_dir = ( + # ClusteringTask * ClusteringParamSet & key + # ).fetch1("clustering_method", "clustering_output_dir") + # output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + + # # Get channel and electrode-site mapping + # electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") + # channel2electrode_map: dict[int, dict] = { + # chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True) + # } + # # Get sorter method and create output directory. + # sorter_name = clustering_method.replace(".", "_") + # si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" + + # if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs + # import spikeinterface as si + # from spikeinterface import sorters + + # sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" + # si_sorting_: si.sorters.BaseSorter = si.load_extractor( + # sorting_file, base_folder=output_dir + # ) + # if si_sorting_.unit_ids.size == 0: + # logger.info( + # f"No units found in {sorting_file}. Skipping Unit ingestion..." + # ) + # self.insert1(key) + # return + + # sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) + # si_sorting = sorting_analyzer.sorting + + # # Find representative channel for each unit + # unit_peak_channel: dict[int, np.ndarray] = ( + # si.ChannelSparsity.from_best_channels( + # sorting_analyzer, + # 1, + # ).unit_id_to_channel_indices + # ) + # unit_peak_channel: dict[int, int] = { + # u: chn[0] for u, chn in unit_peak_channel.items() + # } + + # spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() + # # {unit: spike_count} + + # # create channel2electrode_map + # electrode_map: dict[int, dict] = { + # elec["electrode"]: elec for elec in electrode_query.fetch(as_dict=True) + # } + # channel2electrode_map = { + # chn_idx: electrode_map[int(elec_id)] + # for chn_idx, elec_id in zip( + # sorting_analyzer.get_probe().device_channel_indices, + # sorting_analyzer.get_probe().contact_ids, + # ) + # } + + # # Get unit id to quality label mapping + # cluster_quality_label_map = { + # int(unit_id): ( + # si_sorting.get_unit_property(unit_id, "KSLabel") + # if "KSLabel" in si_sorting.get_property_keys() + # else "n.a." + # ) + # for unit_id in si_sorting.unit_ids + # } + + # spike_locations = sorting_analyzer.get_extension("spike_locations") + # extremum_channel_inds = si.template_tools.get_template_extremum_channel( + # sorting_analyzer, outputs="index" + # ) + # spikes_df = pd.DataFrame( + # sorting_analyzer.sorting.to_spike_vector( + # extremum_channel_inds=extremum_channel_inds + # ) + # ) + + # ephys_sync_func = get_sync_ephys_function(key) + + # units = [] + # for unit_idx, unit_id in enumerate(si_sorting.unit_ids): + # unit_id = int(unit_id) + # unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx] + # spike_sites = np.array( + # [ + # channel2electrode_map[chn_idx]["electrode"] + # for chn_idx in unit_spikes_df.channel_index + # ] + # ) + # unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index] + # _, spike_depths = zip(*unit_spikes_loc) # x-coordinates, y-coordinates + # spike_times = si_sorting.get_unit_spike_train( + # unit_id, return_times=True + # ) + # spike_times = ephys_sync_func(spike_times) + + # assert len(spike_times) == len(spike_sites) == len(spike_depths) + + # units.append( + # { + # **key, + # **channel2electrode_map[unit_peak_channel[unit_id]], + # "unit": unit_id, + # "cluster_quality_label": cluster_quality_label_map[unit_id], + # "spike_times": spike_times, + # "spike_count": spike_count_dict[unit_id], + # "spike_sites": spike_sites, + # "spike_depths": spike_depths, + # } + # ) + # else: # read from kilosort outputs + # kilosort_dataset = kilosort.Kilosort(output_dir) + # acq_software, sample_rate = (EphysRecording & key).fetch1( + # "acq_software", "sampling_rate" + # ) + + # sample_rate = kilosort_dataset.data["params"].get( + # "sample_rate", sample_rate + # ) + + # # ---------- Unit ---------- + # # -- Remove 0-spike units + # withspike_idx = [ + # i + # for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) + # if (kilosort_dataset.data["spike_clusters"] == u).any() + # ] + # valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] + # valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] + + # # -- Spike-times -- + # # spike_times_sec_adj > spike_times_sec > spike_times + # spike_time_key = ( + # "spike_times_sec_adj" + # if "spike_times_sec_adj" in kilosort_dataset.data + # else ( + # "spike_times_sec" + # if "spike_times_sec" in kilosort_dataset.data + # else "spike_times" + # ) + # ) + # spike_times = kilosort_dataset.data[spike_time_key] + # kilosort_dataset.extract_spike_depths() + + # # -- Spike-sites and Spike-depths -- + # spike_sites = np.array( + # [ + # channel2electrode_map[s]["electrode"] + # for s in kilosort_dataset.data["spike_sites"] + # ] + # ) + # spike_depths = kilosort_dataset.data["spike_depths"] + + # # -- Insert unit, label, peak-chn + # units = [] + # for unit, unit_lbl in zip(valid_units, valid_unit_labels): + # if (kilosort_dataset.data["spike_clusters"] == unit).any(): + # unit_channel, _ = kilosort_dataset.get_best_channel(unit) + # unit_spike_times = ( + # spike_times[kilosort_dataset.data["spike_clusters"] == unit] + # / sample_rate + # ) + # spike_count = len(unit_spike_times) + + # units.append( + # { + # **key, + # "unit": unit, + # "cluster_quality_label": unit_lbl, + # **channel2electrode_map[unit_channel], + # "spike_times": unit_spike_times, + # "spike_count": spike_count, + # "spike_sites": spike_sites[ + # kilosort_dataset.data["spike_clusters"] == unit + # ], + # "spike_depths": spike_depths[ + # kilosort_dataset.data["spike_clusters"] == unit + # ], + # } + # ) + + # self.insert1(key) + # self.Unit.insert(units, ignore_extra_fields=True) @schema @@ -1257,8 +1576,8 @@ class Waveform(dj.Part): # Spike waveforms and their mean across spikes for the given unit -> master -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- + -> probe.ElectrodeConfig.Electrode + --- waveform_mean: longblob # (uV) mean waveform across spikes of the given unit waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit """ @@ -1298,10 +1617,16 @@ def make(self, key): ) # {unit: peak_channel_index} unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()} - # update channel2electrode_map to match with probe's channel index + # create channel2electrode_map + electrode_map: dict[int, dict] = { + elec["electrode"]: elec for elec in electrode_query.fetch(as_dict=True) + } channel2electrode_map = { - idx: channel2electrode_map[int(chn_idx)] - for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids) + chn_idx: electrode_map[int(elec_id)] + for chn_idx, elec_id in zip( + sorting_analyzer.get_probe().device_channel_indices, + sorting_analyzer.get_probe().contact_ids, + ) } templates = sorting_analyzer.get_extension("templates") @@ -1448,7 +1773,7 @@ class QualityMetrics(dj.Imported): definition = """ # Clusters and waveforms metrics - -> CuratedClustering + -> CuratedClustering """ class Cluster(dj.Part): @@ -1473,26 +1798,26 @@ class Cluster(dj.Part): contamination_rate (float): Frequency of spikes in the refractory period. """ - definition = """ + definition = """ # Cluster metrics for a particular unit -> master -> CuratedClustering.Unit --- - firing_rate=null: float # (Hz) firing rate for a unit + firing_rate=null: float # (Hz) firing rate for a unit snr=null: float # signal-to-noise ratio for a unit presence_ratio=null: float # fraction of time in which spikes are present isi_violation=null: float # rate of ISI violation as a fraction of overall rate number_violation=null: int # total number of ISI violations amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # + l_ratio=null: float # d_prime=null: float # Classification accuracy based on LDA nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster silhouette_score=null: float # Standard metric for cluster overlap max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # + cumulative_drift=null: float # Cumulative change in spike depth throughout recording + contamination_rate=null: float # """ class Waveform(dj.Part): @@ -1512,7 +1837,7 @@ class Waveform(dj.Part): velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. """ - definition = """ + definition = """ # Waveform metrics for a particular unit -> master -> CuratedClustering.Unit diff --git a/element_array_ephys/spike_sorting/ephys_curation.py b/element_array_ephys/spike_sorting/ephys_curation.py new file mode 100644 index 00000000..9382831e --- /dev/null +++ b/element_array_ephys/spike_sorting/ephys_curation.py @@ -0,0 +1,549 @@ +from __future__ import annotations + +import os +import datajoint as dj +from datetime import datetime, timezone +from pathlib import Path +import json +import shutil +import numpy as np +import pandas as pd +from element_interface.utils import dict_to_uuid, find_full_path + +from element_array_ephys import ephys +from element_array_ephys.spike_sorting import si_spike_sorting as ephys_sorter + + +logger = dj.logger + +schema = dj.schema() + + +def activate( + schema_name, + *, + create_schema=True, + create_tables=True, +): + """ + activate(schema_name, *, create_schema=True, create_tables=True) + :param schema_name: schema name on the database server to activate the `ephys_curation` schema + :param create_schema: when True (default), create schema in the database if it does not yet exist. + :param create_tables: when True (default), create tables in the database if they do not yet exist. + """ + + if not ephys.schema.is_activated(): + raise RuntimeError("Please activate the `ephys` schema first.") + if not ephys_sorter.schema.is_activated(): + raise RuntimeError("Please activate the `si_spike_sorting` schema first.") + + schema.activate( + schema_name, + create_schema=create_schema, + create_tables=create_tables, + add_objects={**ephys.__dict__, **ephys_sorter.__dict__}, + ) + try: + ephys.ClusterQualityLabel.insert1( + ("no kilosort label", "unsorted spikes (Phy default)"), skip_duplicates=True + ) + except Exception as e: + logger.warning(f"Failed to insert default ClusterQualityLabel: {e}") + + +@schema +class CurationMethod(dj.Lookup): + definition = """ + # Curation method + curation_method: varchar(16) # method/package used to perform manual curation (e.g. Phy, FigURL, etc.) + """ + contents = [ + ("Phy",), + ] + + +@schema +class ManualCuration(dj.Manual): + definition = """ + # Manual curation from an ephys.Clustering + -> ephys.Clustering + curation_id: int + --- + curation_datetime: datetime # UTC time when the curation was performed + parent_curation_id=-1: int # if -1, this curation is based on the raw spike sorting results + -> CurationMethod # which method/package used for manual curation (inform how to ingest the results) + description="": varchar(1000) # user-defined description/note of the curation + """ + + class File(dj.Part): + definition = """ + -> master + file_name: varchar(255) + --- + file: filepath@ephys-processed + """ + + @classmethod + def prepare_manual_curation( + cls, + key, + *, + parent_curation_id=-1, + curation_method="Phy", + if_exists="skip", + download_binary=False, + ): + """ + Create a new directory to for a new round of manual curation + Download the spike sorting results for new manual curation from the specified "key" and "parent_curation_id". + Store the initial meta information in json file. + Args: + key (dict): PK of the ephys.Clustering table + parent_curation_id (int): curation_id to be used as the starting point for this new curation + curation_method (str): method/package used for manual curation (e.g. Phy) + if_exists (str): overwrite|skip + download_binary (bool): if True, also download the raw ephys (.dat) file + + Returns: + directory where the spike sorting results are downloaded + """ + assert if_exists in [ + "overwrite", + "skip", + ], f"Invalid if_exists: {if_exists}" + + if curation_method != "Phy": + raise ValueError(f"Unsupported curation method: {curation_method}") + + init_datetime = datetime.now(timezone.utc) + + # Download the spike sorting results + assert ephys.CuratedClustering & key, f"Invalid ephys.Clustering key: {key}" + if len(ephys.CuratedClustering.Unit & key) == 0: + logger.warning("This clustering has no units!!!") + + if parent_curation_id == -1: + assert ( + ephys_sorter.SIExport & key + ), "SIExport not found for the specified key" + files_query = ( + ephys_sorter.SIExport.File + & key + & "file_name LIKE 'phy%' AND file_name NOT LIKE '%recording.dat'" + ) + else: + assert cls & { + **key, + "curation_id": parent_curation_id, + }, "ManualCuration not found for the specified key" + files_query = cls.File & {**key, "curation_id": parent_curation_id} + + # create new directory for new curation + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + output_dir = ( + Path(ephys.get_processed_root_data_dir()) / output_dir / "curations" + ) + + dirname = f"{curation_method}_parentID_" + ( + "orig" if parent_curation_id == -1 else f"{parent_curation_id}" + ) + curation_output_dir = output_dir / dirname + curation_output_dir.mkdir(parents=True, exist_ok=True) + + # download spike sorting files and copy to the new directory + logger.info( + f"New manual curation: {curation_output_dir} - Downloading {len(files_query)} files..." + ) + + for f in files_query.fetch("file"): + f = Path(f) + if f.name.startswith(".") and f.suffix in (".json", ".pickle"): + continue + new_f = curation_output_dir / f.name + if not new_f.exists() or if_exists == "overwrite": + shutil.copy2(f, new_f) + + if download_binary: + new_f = curation_output_dir / "recording.dat" + if not new_f.exists() or if_exists == "overwrite": + raw_file_query = ( + ephys_sorter.SIExport.File & key & "file_name = 'phy/recording.dat'" + ) + if raw_file_query: + logger.info("Downloading raw ephys data file...") + f = Path(raw_file_query.fetch1("file")) + shutil.copy2(f, new_f) + else: + logger.warning("Raw ephys data file not found.") + + # write "entry" into a json file + with open(curation_output_dir / ".manual_curation_entry.json", "w") as f: + json.dump( + { + **key, + "parent_curation_id": parent_curation_id, + "curation_method": curation_method, + }, + f, + default=str, + ) + + return curation_output_dir + + @classmethod + def insert_manual_curation( + cls, + curation_output_dir, + *, + key=None, + parent_curation_id=None, + curation_method="Phy", + description="", + delete_local_dir=True, + ): + """ + Insert a new manual curation into the database + 1. Get a new "curation_id" (auto-incremented) + 2. Copy the curation_output_dir to a new directory with the new curation_id + 3. Insert the new curation into the database (excluding the raw recording.dat file) + 4. Delete the old curation directory + 5. Optionally delete the new curation directory + Args: + curation_output_dir: directory where the curation results are stored + key: ephys.Clustering key + parent_curation_id: curation_id of the parent curation + curation_method: method/package used for manual curation (e.g. Phy) + description: user-defined description/note of the curation + delete_local_dir: if True, delete the new curation directory after inserting the curation into the database + + Returns: `key` of the newly inserted manual curation + """ + + curation_output_dir = Path(curation_output_dir) + + # Light logic to safeguard against re-inserting the same curation + if (curation_output_dir / ".manual_curation_entry.json").exists(): + if key is not None: + logger.warning( + ".manual_curation_entry.json already exists. Ignoring inputs." + ) + + with open(curation_output_dir / ".manual_curation_entry.json", "r") as f: + key = json.load(f) + + parent_curation_id = key.pop("parent_curation_id") + curation_method = key.pop("curation_method") + + if "curation_id" in key: + print(f"Manual curation already inserted: {key}") + if ( + dj.utils.user_choice("Insert a new manual curation anyway?") + != "yes" + ): + print("Canceled") + return + else: + if parent_curation_id is None or key is None: + raise ValueError( + ".manual_curation_entry.json not found. `key` AND `parent_curation_id` must be specified" + ) + + if curation_method != "Phy": + raise ValueError(f"Unsupported curation method: {curation_method}") + + from element_array_ephys.readers import kilosort + + kilosort_dataset = kilosort.Kilosort(curation_output_dir) + assert ( + kilosort_dataset.data + ), f"Invalid Phy output directory: {curation_output_dir}" + + curate_datetime = datetime.now(timezone.utc) + curation_id = ( + ephys.Clustering.aggr(cls, count="count(curation_id)", keep_all_rows=True) + & key + ).fetch1("count") + 1 + logger.info(f"New curation id: {curation_id}") + + key["curation_id"] = curation_id + with cls.connection.transaction: + entry = { + **key, + "curation_datetime": curate_datetime, + "parent_curation_id": parent_curation_id, + "curation_method": curation_method, + "description": description, + } + cls.insert1(entry) + + # rename curation_output_dir folder into curation_id (skip the raw recording.dat file) + new_curation_output_dir = ( + curation_output_dir.parent + / f"{curation_method}_curationID_{curation_id}" + ) + new_curation_output_dir.mkdir(parents=True, exist_ok=True) + for f in curation_output_dir.glob("*"): + if f.is_file() and f.name != "recording.dat": + shutil.copy2( + f, new_curation_output_dir / f.relative_to(curation_output_dir) + ) + elif f.is_dir(): + shutil.copytree( + f, new_curation_output_dir / f.relative_to(curation_output_dir) + ) + + logger.info(f"Inserting files from {new_curation_output_dir}...") + cls.File.insert( + [ + { + **key, + "file_name": f.relative_to(new_curation_output_dir).as_posix(), + "file": f, + } + for f in new_curation_output_dir.rglob("*") + if f.is_file() + and f.name not in ("recording.dat", ".manual_curation_entry.json") + ] + ) + + logger.info(f"New manual curation successfully inserted: {key}") + + # write "entry" into a json file + with open(new_curation_output_dir / ".manual_curation_entry.json", "w") as f: + json.dump(entry, f, default=str) + + # Delete the old curation directory + try: + shutil.rmtree(curation_output_dir) + except Exception as e: + logger.error( + f"Failed to fully delete the old curation directory (please try to delete manually):\n\t{curation_output_dir}\n{e}" + ) + + if delete_local_dir: + try: + shutil.rmtree(new_curation_output_dir) + except Exception as e: + logger.error( + f"Failed to fully delete the new curation directory (please try to delete manually):\n\t{new_curation_output_dir}\n{e}" + ) + + return key + + +@schema +class OfficialCuration(dj.Manual): + definition = """ + -> ephys.Clustering + --- + -> ManualCuration + """ + + +@schema +class ApplyOfficialCuration(dj.Imported): + definition = """ + -> OfficialCuration + --- + execution_time: datetime # datetime of the start of this step + new_unit_count: int # number of new units added + removed_unit_count: int # number of units removed + """ + + @property + def key_source(self): + return OfficialCuration & ephys.CuratedClustering + + def make(self, key): + """ + High level logic + Step 1: delete units from ephys.CuratedClustering.Unit that are not in the new curation (merged or split) + Step 2: add new entries for new units (newly merged or split) + + Note: when replacing an OfficialCuration, manual steps must be taken + for a reset ingestion of ephys.CuratedClustering and below + - delete: (ephys.CuratedClustering & key).delete() + - repopulate: calls populate for `CuratedClustering`, `WaveformSet`, `QualityMetrics` + """ + from element_array_ephys.readers import kilosort + + # Resolve full key including curation_id from OfficialCuration + # (key only contains Clustering PK; curation_id is a dependent attribute) + official_key = (OfficialCuration & key).fetch1() + + curated_files = (ManualCuration.File & official_key).fetch("file") + curation_output_dir = ( + next(Path(f) for f in curated_files if Path(f).name == "params.py") + ).parent + + curation_method = (ManualCuration & official_key).fetch1( + "curation_method" + ) + + if curation_method != "Phy": + raise ValueError(f"Unsupported curation method: {curation_method}") + + clus_key = (ephys.Clustering & key).fetch1("KEY") + + kilosort_dataset = kilosort.Kilosort(curation_output_dir) + + orig_si_unit_map = { + u: i + for i, u in enumerate( + (ephys.CuratedClustering.Unit & clus_key).fetch("unit", order_by="unit") + ) + } + + new_si_unit_map = pd.read_csv( + curation_output_dir / "cluster_si_unit_id.tsv", sep="\t", index_col=1 + ).to_dict()["cluster_id"] + + # find set of units that are in the original curation but not in the new + removed_units = set(orig_si_unit_map) - set(new_si_unit_map) + # find set of units that are in the new curation but not in the original + new_units = set(kilosort_dataset.data["cluster_ids"]) - set( + orig_si_unit_map.values() + ) + new_si_unit_map.update( + {i + max(orig_si_unit_map) + 1: u for i, u in enumerate(new_units)} + ) + new_si_unit_reverse_map = {v: k for k, v in new_si_unit_map.items()} + + # Get channel and electrode-site mapping + # For SI-processed data, Kilosort's spike_sites contains device_channel_indices + # (sequential 0, 1, 2...) which may differ from the original channel_idx in + # EphysRecording.Channel (e.g., 32-63 for multi-insertion probes like MBA). + # Use the sorting_analyzer's probe to correctly map channel indices to electrodes. + electrode_query = (ephys.EphysRecording.Channel & clus_key).proj( + ..., "-channel_name" + ) + + # Get sorting_analyzer to access the probe's channel mapping + clustering_method, output_dir = ( + ephys.ClusteringTask * ephys.ClusteringParamSet & clus_key + ).fetch1("clustering_method", "clustering_output_dir") + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + sorter_name = clustering_method.replace(".", "_") + si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" + + # If sorting_analyzer not available locally, fetch from external storage (S3) + if not si_sorting_analyzer_dir.exists(): + _ = (ephys_sorter.PostProcessing.File & clus_key).fetch("file") + + if si_sorting_analyzer_dir.exists(): + import spikeinterface as si + + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) + # Create electrode_map keyed by electrode ID + electrode_map: dict[int, dict] = {} + for elec in electrode_query.fetch(as_dict=True): + elec.pop("channel_idx") + electrode_map[elec["electrode"]] = elec + # Map device_channel_indices to electrode info using probe's contact_ids + channel2electrode_map: dict[int, dict] = { + chn_idx: electrode_map[int(elec_id)] + for chn_idx, elec_id in zip( + sorting_analyzer.get_probe().device_channel_indices, + sorting_analyzer.get_probe().contact_ids, + ) + } + else: + # Fallback for non-SI processed data - use original channel_idx + channel2electrode_map: dict[int, dict] = { + chn["channel_idx"]: chn for chn in electrode_query.fetch(as_dict=True) + } + + sample_rate = kilosort_dataset.data["params"]["sample_rate"] + spike_times = kilosort_dataset.data["spike_times"] + kilosort_dataset.extract_spike_depths() + # -- Spike-sites and Spike-depths -- + spike_sites = np.array( + [ + channel2electrode_map[s]["electrode"] + for s in kilosort_dataset.data["spike_sites"] + ] + ) + spike_depths = kilosort_dataset.data["spike_depths"] + + # -- Remove units + with dj.config(safemode=False): + ( + ephys.CuratedClustering.Unit + & clus_key + & [{"unit": u} for u in removed_units] + ).delete(force=True) + + # -- Insert unit, label, peak-chn + for cluster_id, cluster_group in zip( + kilosort_dataset.data["cluster_ids"], + kilosort_dataset.data["cluster_groups"], + ): + unit_key = {**clus_key, "unit": new_si_unit_reverse_map[cluster_id]} + if cluster_id in new_units: + # add new unit entry + unit_channel, _ = kilosort_dataset.get_best_channel(cluster_id) + unit_spike_times = ( + spike_times[kilosort_dataset.data["spike_clusters"] == cluster_id] + / sample_rate + ) + spike_count = len(unit_spike_times) + + ephys.CuratedClustering.Unit.insert1( + { + **unit_key, + "cluster_quality_label": "no kilosort label", # new units will have "no kilosort label" label + **channel2electrode_map[unit_channel], + "spike_times": unit_spike_times, + "spike_count": spike_count, + "spike_sites": spike_sites[ + kilosort_dataset.data["spike_clusters"] == cluster_id + ], + "spike_depths": spike_depths[ + kilosort_dataset.data["spike_clusters"] == cluster_id + ], + }, + allow_direct_insert=True, + ) + + # insert the new unit label + ephys.CuratedClustering.ManualLabel.insert1( + {**unit_key, "manual_label": cluster_group}, allow_direct_insert=True + ) + + self.insert1( + { + **key, + "execution_time": datetime.now(timezone.utc), + "new_unit_count": len(new_units), + "removed_unit_count": len(removed_units), + } + ) + + +def launch_phy(key, parent_curation_id=-1, download_binary=False): + """ + Select a spike sorting key for manual curation + 1. download the spike sorting results + 2. launch phy + 3. commit new curation locally + 4. insert new curation into the database + Args: + key: ephys.Clustering key + parent_curation_id: if -1, this curation is based on the raw spike sorting results + download_raw: if True, also download the raw ephys (.dat) file + do_insert: if True, insert the new curation into the database (upload result files) + """ + from phy.apps.template import template_gui + + curation_output_dir = ManualCuration.prepare_manual_curation( + key, parent_curation_id=parent_curation_id, download_binary=download_binary + ) + + template_gui(curation_output_dir / "params.py") + + description = input("Curation description: ") + + ManualCuration.insert_manual_curation( + curation_output_dir, + description=description, + ) diff --git a/element_array_ephys/spike_sorting/infer_map.py b/element_array_ephys/spike_sorting/infer_map.py new file mode 100644 index 00000000..28ba4dcd --- /dev/null +++ b/element_array_ephys/spike_sorting/infer_map.py @@ -0,0 +1,106 @@ +"""Functions to infer a rough channel map from ephys data, used for microwire brush arrays""" + +import numpy as np +from scipy.spatial.distance import pdist +from sklearn.manifold import Isomap + + +def infer_map(signal, knn_k=5, scale_method="max", scale_params=None): + """Use Isomap algorithm to infer 2D channel map from correlations within signal array + + args: + signal (np.ndarray): 2D numpy array of shape (n_samples, n_channels) + knn_k (int): number of nearest neighbors for isomap knn graph + scale_method (str): fed to scale_map, see that function for details + scale_params (dict): fed to scale_map, see that function for details + + returns: + X (np.ndarray): channel map coordinates, of shape (n_channels, 2) + """ + # default scaling + if scale_method == "max" and not scale_params: + scale_params = {"max_dist": 350} + + C = signal_to_corr(signal) + D = corr_to_dist(C) + X = dist_to_coords(D, knn_k) + X = scale_map(X, scale_method, scale_params) + X = rotate_pc1_vertical(X) + return X + + +def signal_to_corr(signal): + """Generate correlations from signal""" + C = np.corrcoef(signal, rowvar=False) + return C + + +def corr_to_dist(C): + """Convert pairwise channel correlation matrix C to distance matrix D""" + D = 1 - C + return D + + +def dist_to_coords(D, knn_k): + """Use Isomap to infer channel coordinates from locally valid, approximate distances""" + isomap = Isomap( + n_neighbors=knn_k, n_components=2, metric="precomputed", eigen_solver="dense" + ) + X = isomap.fit_transform(D) + return X + + +def scale_map(X, scale_method, scale_params): + """Scale inferred channel map (scale all distances by a linear factor) + + args: + X (np.ndarray): (N x 2) map of inferred channel coordinates + scale_method (str): method for determining linear factor + scale_params (dict): free parameters for scale method + + returns: + X (np.ndarray): scaled coordinates + + See helper functions (scale_by_*) for scale method details + """ + if scale_method == "radius": + return scale_by_radius(X, **scale_params) + elif scale_method == "max": + return scale_by_max(X, **scale_params) + else: + raise NotImplementedError(f"{scale_method} is not a valid scale method") + + +def scale_by_radius(X, n_chan, r): + """Scale the inferred coordinates (X) so that no more than n_chan channels lie within a circle of radius r (approx.)""" + from sklearn.neighbors import kneighbors_graph + + G = kneighbors_graph(X, n_chan, mode="distance", include_self=True).toarray() + # get smallest distance to the n_chan-th nearest neighbor + min_dist = G.max(axis=1).min() + # scale s.t. the smallest distance to the n_chan-th nearest neighbor is r + X_out = X / min_dist * r + return X_out + + +def scale_by_max(X, max_dist): + """Scale the inferred coordinates (X) so that the maximum distance between any two points is max_dist""" + X = X / pdist(X).max() * max_dist + return X + + +def rotate_pc1_vertical(X): + """Helper function used to ensure similar/identical channel maps are visibly similar/identical + + Note that rotation does not alter how the channel map is processed by kilosort, which cares only about distances between channels. + """ + from sklearn.decomposition import PCA + + pc1 = PCA(1).fit(X).components_[0] + theta = np.arctan(pc1[1] / pc1[0]) - np.pi / 2 + + def R(x): + return np.array([[np.cos(x), np.sin(x)], [-np.sin(x), np.cos(x)]]) + + X_rot = np.matmul(R(theta), X.T).T + return X_rot diff --git a/element_array_ephys/spike_sorting/si_preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py index 22adbdca..8bdf40b6 100644 --- a/element_array_ephys/spike_sorting/si_preprocessing.py +++ b/element_array_ephys/spike_sorting/si_preprocessing.py @@ -1,3 +1,4 @@ +import numpy as np import spikeinterface as si from spikeinterface import preprocessing @@ -35,3 +36,61 @@ def IBLdestriping_modified(recording): recording, operator="median", reference="global" ) return recording + + +def NienborgLab_preproc(recording): + """Preprocessing pipeline for 32chn ephys data from Trellis.""" + recording = si.preprocessing.bandpass_filter( + recording=recording, freq_min=300, freq_max=6000 + ) + recording = si.preprocessing.common_reference( + recording=recording, operator="median" + ) + return recording + + +def MBA_infer_map(recording, **infer_map_kwargs): + """Preprocessing pipeline to infer channel map for microwire brush array data""" + from element_array_ephys.spike_sorting.infer_map import infer_map + + recording = si.preprocessing.bandpass_filter( + recording=recording, freq_min=300, freq_max=6000 + ) + if recording.get_num_channels() <= 32: + recording = si.preprocessing.common_reference( + recording=recording, operator="median" + ) + else: + # do common average referencing on each group of 32 channels + group_ids = np.arange(recording.get_num_channels(), dtype=int) // 32 + recording.set_property("group", group_ids) + split_recording_dict = recording.split_by("group") + split_recording_dict = si.preprocessing.common_reference(split_recording_dict) + recording = si.aggregate_channels(split_recording_dict) + + fs = recording.get_sampling_frequency() + if recording.get_duration() > 120: + # extract second minute of recording (arbitrary) + start_frame = 60 * fs + end_frame = 120 * fs + else: + # extract first minute or whole recording + start_frame = 0 + end_frame = int(min(60, recording.get_duration()) * fs) + + signal = recording.get_traces(start_frame=start_frame, end_frame=end_frame).astype( + np.float32 + ) + + channel_map = infer_map(signal, **infer_map_kwargs) + + # modify electrode positions within SI object + # TODO: eventually figure out a better way to do this + si_probe = recording.get_probe() + assert ( + channel_map.shape == si_probe.contact_positions.shape + ), f"Inferred coordinates dimensions: {channel_map.shape} do not match target dimensions: {si_probe.contact_positions.shape}" + si_probe.set_contacts(positions=channel_map) + recording.set_probe(si_probe, in_place=True) + + return recording diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index a324afa3..0a865fa9 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -5,13 +5,14 @@ """ from datetime import datetime +from pathlib import Path import datajoint as dj -import pandas as pd import spikeinterface as si -from element_array_ephys import probe, ephys, readers from element_interface.utils import find_full_path, memoized_result -from spikeinterface import exporters, extractors, sorters +from spikeinterface import exporters, extractors, sorters # noqa: F401 + +from element_array_ephys import ephys, probe, readers from . import si_preprocessing @@ -61,6 +62,14 @@ class PreProcessing(dj.Imported): execution_duration: float # execution duration in hours """ + class File(dj.Part): + definition = """ + -> master + file_name: varchar(255) + --- + file: filepath@ephys-processed + """ + @property def key_source(self): return ( @@ -81,8 +90,18 @@ def make(self, key): # Get sorter method and create output directory. sorter_name = clustering_method.replace(".", "_") + # Handle alternative preprocessing params structure + if "SI_PREPROCESSING_METHOD" in params: + preprocessing_method = params.pop("SI_PREPROCESSING_METHOD") + params["SI_PREPROCESSING_PARAMS"] = params.get( + "SI_PREPROCESSING_PARAMS", {} + ) + params["SI_PREPROCESSING_PARAMS"][ + "preprocessing_method" + ] = preprocessing_method + for required_key in ( - "SI_PREPROCESSING_METHOD", + "SI_PREPROCESSING_PARAMS", "SI_SORTING_PARAMS", "SI_POSTPROCESSING_PARAMS", ): @@ -91,6 +110,18 @@ def make(self, key): f"{required_key} must be defined in ClusteringParamSet for SpikeInterface execution" ) + # Get probe information to recording object + electrodes_df = ( + ( + ephys.EphysRecording.Channel + * probe.ElectrodeConfig.Electrode + * probe.ProbeType.Electrode + & key + ) + .fetch(format="frame") + .reset_index() + ) + # Set directory to store recording file. if not output_dir: output_dir = ephys.ClusteringTask.infer_output_dir( @@ -107,22 +138,32 @@ def make(self, key): # Create SI recording extractor object if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = readers.spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) - spikeglx_recording.validate_file("ap") - data_dir = spikeglx_meta_filepath.parent - si_extractor = ( si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor ) - stream_names, stream_ids = si.extractors.get_neo_streams( - "spikeglx", folder_path=data_dir - ) - si_recording: si.BaseRecording = si_extractor( - folder_path=data_dir, stream_name=stream_names[0] - ) + spikeglx_meta_filepaths = ( + ephys.EphysRecording.EphysFile & key & 'file_path LIKE "%.ap.meta"' + ).fetch("file_path") + ap_bin_filepaths = [ + find_full_path( + ephys.get_ephys_root_data_dir(), Path(f).with_suffix(".bin") + ) + for f in spikeglx_meta_filepaths + ] + + si_recs = [] + for ap_bin_filepath in ap_bin_filepaths: + data_dir = ap_bin_filepath.parent + spikeglx_recording = readers.spikeglx.SpikeGLX(data_dir) + spikeglx_recording.validate_file("ap") + stream_names, stream_ids = si.extractors.get_neo_streams( + "spikeglx", folder_path=data_dir + ) + si_rec: si.BaseRecording = si_extractor( + folder_path=data_dir, stream_name=stream_names[0] + ) + si_recs.append(si_rec) + si_recording = si.concatenate_recordings(si_recs) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) assert len(oe_probe.recording_info["recording_files"]) == 1 @@ -137,23 +178,34 @@ def make(self, key): si_recording: si.BaseRecording = si_extractor( folder_path=data_dir, stream_name=stream_names[0] ) + elif acq_software == "Trellis": + si_extractor = ( + si.extractors.neoextractors.blackrock.BlackrockRecordingExtractor + ) + + nsx5_relpaths = (ephys.EphysRecording.EphysFile & key).fetch("file_path") + nsx5_fullpaths = [ + find_full_path(ephys.get_ephys_root_data_dir(), f + ".ns5") + for f in nsx5_relpaths + ] + si_recs = [] + for f in nsx5_fullpaths: + si_rec = si_extractor(file_path=f, stream_name="nsx5") + si_recs.append(si_rec) + si_recording = si.concatenate_recordings(si_recs) else: raise NotImplementedError( f"SpikeInterface processing for {acq_software} not yet implemented." ) - # Add probe information to recording object - electrodes_df = ( - ( - ephys.EphysRecording.Channel - * probe.ElectrodeConfig.Electrode - * probe.ProbeType.Electrode - & key - ) - .fetch(format="frame") - .reset_index() - ) - + # Find & remove extra channels + in_use_chn_ids = si_recording.channel_ids[electrodes_df.channel_idx.values] + chn2remove = set(si_recording.channel_ids) - set(in_use_chn_ids) + si_recording = si_recording.remove_channels(list(chn2remove)) + in_use_chn_ind = [ + si_recording.channel_ids.tolist().index(chn_id) for chn_id in in_use_chn_ids + ] + electrodes_df.channel_idx = in_use_chn_ind # Create SI probe object si_probe = readers.probe_geometry.to_probeinterface( electrodes_df[["electrode", "x_coord", "y_coord", "shank"]] @@ -162,8 +214,11 @@ def make(self, key): si_recording.set_probe(probe=si_probe, in_place=True) # Run preprocessing and save results to output folder - si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"]) - si_recording = si_preproc_func(si_recording) + si_preproc_params = params["SI_PREPROCESSING_PARAMS"] + si_preproc_func = getattr( + si_preprocessing, si_preproc_params.pop("preprocessing_method") + ) + si_recording = si_preproc_func(si_recording, **si_preproc_params) si_recording.dump_to_pickle(file_path=recording_file, relative_to=output_dir) self.insert1( @@ -176,6 +231,14 @@ def make(self, key): / 3600, } ) + # Insert result files + self.File.insert( + [ + {**key, "file_name": f.relative_to(recording_dir).as_posix(), "file": f} + for f in recording_dir.rglob("*") + if f.is_file() + ] + ) @schema @@ -189,6 +252,14 @@ class SIClustering(dj.Imported): execution_duration: float # execution duration in hours """ + class File(dj.Part): + definition = """ + -> master + file_name: varchar(255) + --- + file: filepath@ephys-processed + """ + def make(self, key): execution_time = datetime.utcnow() @@ -239,6 +310,16 @@ def _run_sorter(): / 3600, } ) + # Insert result files + for f in sorting_output_dir.rglob("*"): + if f.is_file(): + self.File.insert1( + { + **key, + "file_name": f.relative_to(sorting_output_dir).as_posix(), + "file": f, + } + ) @schema @@ -253,6 +334,14 @@ class PostProcessing(dj.Imported): do_si_export=0: bool # whether to export to phy """ + class File(dj.Part): + definition = """ + -> master + file_name: varchar(255) + --- + file: filepath@ephys-processed + """ + def make(self, key): execution_time = datetime.utcnow() @@ -333,6 +422,15 @@ def _sorting_analyzer_compute(): "do_si_export": do_si_export and has_units, } ) + for f in analyzer_output_dir.rglob("*"): + if f.is_file(): + self.File.insert1( + { + **key, + "file_name": f.relative_to(analyzer_output_dir).as_posix(), + "file": f, + } + ) # Once finished, insert this `key` into ephys.Clustering ephys.Clustering.insert1( @@ -351,6 +449,14 @@ class SIExport(dj.Computed): execution_duration: float """ + class File(dj.Part): + definition = """ + -> master + file_name: varchar(255) + --- + file: filepath@ephys-processed + """ + @property def key_source(self): return PostProcessing & "do_si_export = 1" @@ -416,3 +522,14 @@ def _export_report(): / 3600, } ) + # Insert result files + for report_dirname in ("spikeinterface_report", "phy"): + for f in (analyzer_output_dir / report_dirname).rglob("*"): + if f.is_file(): + self.File.insert1( + { + **key, + "file_name": f.relative_to(analyzer_output_dir).as_posix(), + "file": f, + } + )