diff --git a/doc/api.rst b/doc/api.rst index 844b07ca..be1fc673 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -121,6 +121,7 @@ Analysis modules filtering perievent randomize + signal spectrum tuning_curves wavelets diff --git a/doc/examples/tutorial_phase_preferences.md b/doc/examples/tutorial_phase_preferences.md index 40f3c98e..56b8e3bb 100644 --- a/doc/examples/tutorial_phase_preferences.md +++ b/doc/examples/tutorial_phase_preferences.md @@ -173,12 +173,12 @@ plt.show() Computing phase --------------- -From the filtered signal, it is easy to get the phase using the Hilbert transform. Here we use scipy Hilbert method. +From the filtered signal, it is easy to get the phase using the Hilbert transform, (see the [user guide](/user_guide/13_phases_and_envelopes.md)). +Pynapple provides function [`compute_hilbert_phase`](pynapple.process.signal.compute_hilbert_phase) for this: ```{code-cell} ipython3 -from scipy import signal - -theta_phase = nap.Tsd(t=theta_band.t, d=np.angle(signal.hilbert(theta_band))) +theta_phase = nap.compute_hilbert_phase(theta_band) +theta_phase ``` Let's plot the phase. diff --git a/doc/user_guide.md b/doc/user_guide.md index c5b2ec9e..38f0a5fb 100644 --- a/doc/user_guide.md +++ b/doc/user_guide.md @@ -30,39 +30,53 @@ Metadata :::{card} High-level analysis ```{toctree} +:maxdepth: 2 Correlograms & ISI ``` ```{toctree} +:maxdepth: 2 Tuning curves ``` ```{toctree} +:maxdepth: 2 Decoding ``` ```{toctree} +:maxdepth: 2 Perievent / Spike-triggered average ``` ```{toctree} +:maxdepth: 2 Randomization ``` ```{toctree} +:maxdepth: 2 Power spectral density ``` ```{toctree} +:maxdepth: 2 Wavelet decomposion ``` ```{toctree} +:maxdepth: 2 Filtering time series ``` ```{toctree} -Building trial-based tensors +:maxdepth: 2 +Extracting phases and envelopes +``` + +```{toctree} +:maxdepth: 2 +Building trial-based tensors ``` ::: diff --git a/doc/user_guide/01_introduction_to_pynapple.md b/doc/user_guide/01_introduction_to_pynapple.md index 72fd409c..4c6f9149 100644 --- a/doc/user_guide/01_introduction_to_pynapple.md +++ b/doc/user_guide/01_introduction_to_pynapple.md @@ -464,6 +464,10 @@ Tuning curves of neurons based on spiking or calcium activity can be computed. The wavelets module performs Morlet wavelets decomposition of a time series. -**[Warping](13_warping)** +**[Phases & envelopes](13_phases_and_envelopes)** + +This modules allows for computing analytic signals and extracting the phase and envelope. + +**[Warping](14_warping)** This module provides methods for building trial-based tensors and time-warped trial-based tensors. diff --git a/doc/user_guide/12_filtering.md b/doc/user_guide/12_filtering.md index 4fa053f2..d7d725d3 100644 --- a/doc/user_guide/12_filtering.md +++ b/doc/user_guide/12_filtering.md @@ -380,78 +380,5 @@ for arr, label in zip( plt.legend() plt.xlabel("Number of dimensions") plt.ylabel("Time (s)") -plt.title("Low pass filtering benchmark") -``` - - -*** -Detecting Oscillatory Events ---------------------------- - -The filtering module also provides a method [`detect_oscillatory_events`](pynapple.process.filtering.detect_oscillatory_events) to automatically detect intervals containing oscillatory events (such as ripples or spindles) in a signal. - -To demonstrate, let's create a synthetic signal where a fast oscillation (e.g., 40 Hz) occurs in a noisy signal: - -```{code-cell} ipython3 -# Parameters -fs = 1000 # Sampling frequency (Hz) -duration = 3 # seconds -t = np.linspace(0, duration, int(fs * duration)) - -# 40 Hz oscillation -osc = np.sin(2 * np.pi * 40 * t) -signal = np.zeros_like(t) + 0.2 * np.random.randn(len(t)) -mask = (t > 1) & (t < 1.5) -signal[mask] += osc[mask] - -# Create Tsd object -ts = nap.Tsd(t=t, d=signal) -``` - -```{code-cell} ipython3 -:tags: [hide-input] - -# Plot the signal -plt.figure(figsize=(15, 4)) -plt.plot(ts, label="Signal (40 Hz oscillation)") -plt.xlabel("Time (s)") -plt.ylabel("Amplitude") -plt.title("Signal with oscillatory bursts") -plt.legend() -plt.show() -``` - -Now, let's use [`detect_oscillatory_events`](pynapple.process.filtering.detect_oscillatory_events) to find the oscillation intervals. The function will return the detected intervals as an `IntervalSet` along with metadata containing peak times. - -```{code-cell} ipython3 -# Define detection parameters -freq_band = (35, 45) # Bandpass filter for 40 Hz -thres_band = (1, 10) # Thresholds for normalized squared signal -min_dur = 0.03 # Minimum event duration (s) -max_dur = 1 # Maximum event duration (s) -min_inter = 0.02 # Minimum interval between events (s) -epoch = nap.IntervalSet(start=0, end=duration) - -# Detect oscillatory events -osc_ep = nap.filtering.detect_oscillatory_events( - ts, epoch, freq_band, thres_band, (min_dur, max_dur), min_inter -) - -print("Detected intervals:\n", osc_ep) -``` - -Let's visualize the detected intervals and peaks on the original signal: - -```{code-cell} ipython3 -:tags: [hide-input] - -plt.figure(figsize=(15, 4)) -plt.plot(ts, label="Signal") -for s, e in osc_ep.values: - plt.axvspan(s, e, color="orange", alpha=0.3) -plt.xlabel("Time (s)") -plt.ylabel("Amplitude") -plt.title("Detected oscillatory events") -plt.legend() -plt.show() +plt.title("Low pass filtering benchmark"); ``` diff --git a/doc/user_guide/13_phases_and_envelopes.md b/doc/user_guide/13_phases_and_envelopes.md new file mode 100644 index 00000000..92d38d31 --- /dev/null +++ b/doc/user_guide/13_phases_and_envelopes.md @@ -0,0 +1,236 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +# Phases and envelopes +In this tutorial, we will introduce Pynapple's functionality for computing signal phases +and envelopes. +Most of this functionality is part of the [`pynapple.process.signal`](pynapple.process.signal) module and is built on [the +Hilbert transform](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.hilbert.html). + +```{code-cell} ipython3 +:tags: [hide-input] +# we'll import the packages we're going to use +import numpy as np +import pynapple as nap +import matplotlib.pyplot as plt +import seaborn as sns + +# some configuration, you can ignore this +custom_params = {"axes.spines.right": False, "axes.spines.top": False} +sns.set_theme( + style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params +) +``` + +## Extracting the analytic signal +Let us start by simulating a noisy oscillatory signal containing two frequencies: +```{code-cell} ipython3 +sampling_rate_hz = 1000 +times = np.arange(0, 5, 1 / sampling_rate_hz) + +# Low frequency (8 Hz) +low_freq_hz = 8 +signal = np.cos(2 * np.pi * low_freq_hz * times) + +# High-frequency (40 Hz) +high_freq_hz = 40 +segments = [(1.5, 2.0), (3.0, 3.5)] +for start, end in segments: + mask = (times >= start) & (times <= end) + signal[mask] = np.cos(2 * np.pi * high_freq_hz * times[mask]) + +# Add noise +signal = signal + 0.3 * np.random.normal(size=len(times)) + +# Convert to Tsd +signal = nap.Tsd(t=times, d=signal) +``` + +Let's visualize that: +```{code-cell} ipython3 +:tags: [hide-input] +segment = nap.IntervalSet(1, 4) +plt.figure(figsize=(10,3)) +plt.plot(signal.restrict(segment)) +plt.xlabel("time (s)") +plt.ylabel("amplitude (a.u.)") +plt.tight_layout(); +``` + +Now, imagine that we are interested in the low frequency part of the signal. +We can start by applying a bandpass filter using [`apply_bandpass_filter`](pynapple.process.filtering.apply_bandpass_filter) +to keep only the relevant frequencies (5-10Hz): +```{code-cell} ipython3 +filtered_signal = nap.apply_bandpass_filter( + signal, (5, 10), fs=sampling_rate_hz, mode="butter" +) +filtered_signal +``` + +Let's visualize that together with the original signal: +```{code-cell} ipython3 +:tags: [hide-input] +plt.figure(figsize=(10,3)) +plt.plot(signal.restrict(segment), label="signal") +plt.plot(filtered_signal.restrict(segment), label="filtered signal") +plt.xlabel("time (s)") +plt.ylabel("amplitude (a.u.)") +plt.legend(loc="upper left", bbox_to_anchor=(1, 1)) +plt.tight_layout(); +``` + +We can now use [`apply_hilbert_transform`](pynapple.process.signal.apply_hilbert_transform) +to extract the analytic signal: +```{code-cell} ipython3 +analytic_signal = nap.apply_hilbert_transform(filtered_signal) +analytic_signal +``` + +If we visualize the analytic signal with the input signal, +you will notice that the analytic signal appears identical to the original signal: +```{code-cell} ipython3 +:tags: [hide-input] +plt.figure(figsize=(10,3)) +plt.plot(filtered_signal.restrict(segment), label="filtered signal") +plt.plot(analytic_signal.restrict(segment), label="analytic signal") +plt.xlabel("time (s)") +plt.ylabel("amplitude (a.u.)") +plt.legend(loc="upper left", bbox_to_anchor=(1, 1)) +plt.title("nap.apply_hilbert_transform(filtered_signal)") +plt.tight_layout(); +``` + +This happens because the analytic signal is complex-valued. +The real part is exactly the input signal. +The imaginary part is the Hilbert transform (a 90° phase-shifted version). +When you pass a complex signal to matplotlib, +it automatically plots only the real part (see the warnings above). + +To actually see what’s going on, you can plot the real and imaginary parts separately: +```{code-cell} ipython3 +:tags: [hide-input] +plt.figure(figsize=(10,3)) +plt.plot(np.real(analytic_signal).restrict(segment), label="real part") +plt.plot( + np.imag(analytic_signal).restrict(segment), + label="imaginary part", +) +plt.xlabel("time (s)") +plt.ylabel("amplitude (a.u.)") +plt.legend(loc="upper left", bbox_to_anchor=(1, 1)); +``` + +## Computing the signal envelope +From the analytic signal, it is often the case that we will compute other things. +For one, we can extract the envelope of a signal, by taking the absolute value. +To make things easy, Pynapple provides [`compute_hilbert_envelope`](pynapple.process.signal.compute_hilbert_envelope) +to compute the envelope in one go: +```{code-cell} ipython3 +envelope = nap.compute_hilbert_envelope(filtered_signal) +envelope +``` + +Visualizing the envelope over the signal: +```{code-cell} ipython3 +:tags: [hide-input] +plt.figure(figsize=(10,3)) +plt.plot(filtered_signal.restrict(segment), label="filtered signal") +plt.plot(envelope.restrict(segment), label="envelope", color="red") +plt.xlabel("time (s)") +plt.ylabel("amplitude (a.u.)") +plt.legend(loc="upper left", bbox_to_anchor=(1, 1)) +plt.title("nap.compute_hilbert_envelope(filtered_signal)") +plt.tight_layout(); +``` + +## Computing the signal phase +We can also estimate the signal's phase, by taking angle and wrapping. +To make things easy, Pynapple provides [`compute_hilbert_phase`](pynapple.process.signal.compute_hilbert_phase) +to compute the phase in one go: +```{code-cell} ipython3 +phase = nap.compute_hilbert_phase(filtered_signal) +phase +``` + +Visualizing the phase with the signal: +```{code-cell} ipython3 +:tags: [hide-input] +fig, ax_sig = plt.subplots(figsize=(10, 3)) +signal_line = ax_sig.plot(filtered_signal.restrict(segment)) +ax_sig.set_ylabel("amplitude (a.u.)") +ax_phase = ax_sig.twinx() +ax_phase.spines["right"].set_visible(True) +phase_line = ax_phase.plot(phase.restrict(segment), color="red", linewidth=0.5) +ax_phase.set_ylabel("phase (rad)") +ax_sig.set_xlabel("time (s)") +ax_sig.legend( + signal_line + phase_line, + ["filtered signal", "phase"], + loc="upper left", + bbox_to_anchor=(1.15, 1), +) +plt.title("nap.compute_hilbert_phase(filtered_signal)") +plt.tight_layout(); +``` + +## Detecting oscillatory events +Having looked at the low frequency part of our signal, we might also be interested in the high frequency part. +To start with, we might simply be interested in finding the epochs where the signal is oscillating +at high frequencies. +Pynapple provides the [`detect_oscillatory_events`](pynapple.process.signal.detect_oscillatory_events) +exactly for such a goal. + +To get it to work nicely, you will have the tune the following detection parameters: +- frequency band: the band of frequencies you are interested in. +- threshold band: minimum and maximum thresholds to apply to the z-scored envelope of the squared signal. +- minimum and maximum duration of the events +- minimum interval between events +```{code-cell} ipython3 +# Define detection parameters +freq_band = (35, 45) # Gamma band +thres_band = (1, 5) # Thresholds for normalized squared envelope +min_dur = 0.4 # Minimum event duration +max_dur = 0.6 # Max event duration +min_inter = 0.05 # Minimum interval between events +epoch = signal.time_support + +# Detect oscillatory events +events = nap.detect_oscillatory_events( + signal, + epoch, + freq_band, + thres_band, + (min_dur, max_dur), + min_inter, +) +events +``` + +We can then visualize the found events on top of the original signal as validation: +```{code-cell} ipython3 +:tags: [hide-input] +plt.figure(figsize=(10, 3)) +plt.plot(signal.restrict(segment), label="signal") +first = True +for s, e in events.intersect(segment).values: + if first: + plt.axvspan(s, e, color="orange", alpha=0.3, label="event") + first = False + else: + plt.axvspan(s, e, color="orange", alpha=0.3) +plt.xlabel("time (s)") +plt.ylabel("amplitude (a.u.)") +plt.legend(loc="upper left", bbox_to_anchor=(1, 1)) +plt.title("nap.detect_oscillatory_events(signal, ...)") +plt.tight_layout(); +``` diff --git a/doc/user_guide/13_warping.md b/doc/user_guide/14_warping.md similarity index 100% rename from doc/user_guide/13_warping.md rename to doc/user_guide/14_warping.md diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index bd397fae..fa206d07 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -2764,7 +2764,7 @@ def __repr__(self): n_rows = max_rows // 2 table = [] for i, v in zip(self.index[0:n_rows], self.values[0:n_rows]): - table.append([i, v]) + table.append([i, str(v)]) table.append(["..."]) for i, v in zip( self.index[-n_rows:], @@ -2772,7 +2772,7 @@ def __repr__(self): self.values.shape[0] - n_rows : self.values.shape[0] ], ): - table.append([i, v]) + table.append([i, str(v)]) return ( tabulate(table, headers=headers, colalign=("left",)) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 57c170bf..44c5456d 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -23,6 +23,12 @@ shift_timestamps, shuffle_ts_intervals, ) +from .signal import ( + apply_hilbert_transform, + compute_hilbert_envelope, + compute_hilbert_phase, + detect_oscillatory_events, +) from .spectrum import ( compute_fft, compute_mean_power_spectral_density, diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 119a606a..fb6d9d65 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -515,98 +515,3 @@ def get_filter_frequency_response( ) else: raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") - - -def detect_oscillatory_events( - data, - epoch, - freq_band, - thresh_band, - duration_band, - min_inter_duration, - fs=None, - wsize=51, -): - """ - Simple helper for detecting oscillatory events (e.g. ripples, spindles) - - Parameters - ---------- - data : Tsd - 1-dimensional time series - epoch : IntervalSet - The epoch for restricting the detection - freq_band : tuple - The (low, high) frequency to bandpass the signal - thresh_band : tuple - The (min, max) value for thresholding the normalized squared signal after filtering - duration_band : tuple - The (min, max) duration of an event in second - min_inter_duration : float - The minimum duration between two events otherwise they are merged (in seconds) - fs : float, optional - The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. - wsize : int, optional - The size of the window for digital filtering - - Returns - ------- - IntervalSet - The interval set of detected events with metadata containing - the power, amplitude, and peak_time - """ - from scipy.signal import filtfilt - - data = data.restrict(epoch) - - if fs is None: - fs = data.rate - - signal = apply_bandpass_filter(data, freq_band, fs) - squared_signal = np.square(signal.values) - window = np.ones(wsize) / wsize - - nSS = filtfilt(window, 1, squared_signal) - nSS = (nSS - np.mean(nSS)) / np.std(nSS) - nSS = nap.Tsd(t=signal.index.values, d=nSS, time_support=epoch) - - # Detect oscillation periods by thresholding normalized signal - nSS2 = nSS.threshold(thresh_band[0], method="above") - nSS3 = nSS2.threshold(thresh_band[1], method="below") - - # Exclude oscillation where min_duration < length < max_duration - osc_ep = nSS3.time_support - osc_ep = osc_ep.drop_short_intervals(duration_band[0], time_units="s") - osc_ep = osc_ep.drop_long_intervals(duration_band[1], time_units="s") - - # Merge if inter-oscillation period is too short - osc_ep = osc_ep.merge_close_intervals(min_inter_duration, time_units="s") - - # Compute power, amplitude, and peak_time for each interval - powers = [] - amplitudes = [] - peak_times = [] - - for s, e in osc_ep.values: - seg = signal.get(s, e) - if len(seg) == 0: - powers.append(np.nan) - amplitudes.append(np.nan) - peak_times.append(np.nan) - continue - power = np.mean(np.square(seg)) - power_db = 10 * np.log10(power) - amplitude = np.max(np.abs(seg.values)) - peak_idx = np.argmax(np.abs(seg.values)) - peak_time = seg.index.values[peak_idx] - powers.append(power_db) - amplitudes.append(amplitude) - peak_times.append(peak_time) - - metadata = { - "power": powers, - "amplitude": amplitudes, - "peak_time": peak_times, - } - - return nap.IntervalSet(start=osc_ep.start, end=osc_ep.end, metadata=metadata) diff --git a/pynapple/process/signal.py b/pynapple/process/signal.py new file mode 100644 index 00000000..d957e466 --- /dev/null +++ b/pynapple/process/signal.py @@ -0,0 +1,346 @@ +""" +Functions to compute phases and envelopes +""" + +import numpy as np + +import pynapple as nap + + +def apply_hilbert_transform(data): + """ + Apply the Hilbert transform to a time-series. + + This function wraps :func:`scipy.signal.hilbert` to compute the analytic signal, + which represents the original signal plus its Hilbert transform. + The Hilbert transform is commonly used for phase and envelope computations. + + Parameters + ---------- + data : Tsd, TsdFrame + The time-series to which the Hilbert transform will be applied. + + Returns + ------- + Tsd, TsdFrame + The analytic signal. + + Examples + -------- + >>> import numpy as np + >>> import pynapple as nap + >>> times = np.arange(0, 20, 0.1) + >>> data = nap.Tsd(d=np.sin(times), t=times) + >>> analytic_signal = nap.apply_hilbert_transform(data) + >>> analytic_signal + Time (s) + ---------- --------------------------------------------- + 0.0 (-7.105427357601002e-17+0.16863846755783507j) + 0.1 (0.09983341664682804-0.4242708612183068j) + 0.2 (0.1986693307950611-0.39707635680514186j) + 0.3 (0.2955202066613397-0.5595039695854611j) + 0.4 (0.3894183423086504-0.5111880915563278j) + 0.5 (0.4794255386042027-0.5737652837793169j) + 0.6 (0.5646424733950355-0.507679928825233j) + ... + 19.3 (0.4353653603728936-0.8487963381972523j) + 19.4 (0.5230657651576995-0.8346984246178151j) + 19.5 (0.6055398697196014-0.6962597876075753j) + 19.6 (0.6819636200681357-0.673794027075586j) + 19.7 (0.7515734153521505-0.4510063664570575j) + 19.8 (0.8136737375071058-0.4288655114005536j) + 19.9 (0.8676441006416694+0.17871162129618953j) + dtype: complex128, shape: (200,) + + Can be used for multiple signals in a `TsdFrame`: + + >>> data = nap.TsdFrame(d=np.stack([np.sin(times), np.cos(times)], axis=1), t=times) + >>> analytic_signals = nap.apply_hilbert_transform(data) + >>> analytic_signals + Time (s) 0 1 + ---------- --------------------------------------------- ------------------------------------------ + 0.0 (-7.105427357601002e-17+0.16863846755783507j) (0.9999999999999999-0.10933857636723118j) + 0.1 (0.09983341664682804-0.4242708612183068j) (0.9950041652780255+0.2511675765027083j) + 0.2 (0.1986693307950611-0.39707635680514186j) (0.9800665778412415+0.29226919446209765j) + 0.3 (0.2955202066613397-0.5595039695854611j) (0.9553364891256061+0.45596645091122484j) + 0.4 (0.3894183423086504-0.5111880915563278j) (0.9210609940028853+0.5095457107948864j) + 0.5 (0.4794255386042027-0.5737652837793169j) (0.8775825618903729+0.6332440634658392j) + 0.6 (0.5646424733950355-0.507679928825233j) (0.8253356149096783+0.6873614294191187j) + ... + 19.3 (0.4353653603728936-0.8487963381972523j) (0.9002538547473041+0.09873915660984472j) + 19.4 (0.5230657651576995-0.8346984246178151j) (0.852292323865463+0.18298411058238345j) + 19.5 (0.6055398697196014-0.6962597876075753j) (0.7958149698139438+0.19019044271235275j) + 19.6 (0.6819636200681357-0.673794027075586j) (0.7313860956454965+0.2587502761655236j) + 19.7 (0.7515734153521505-0.4510063664570575j) (0.6596494533734591+0.1992088667212628j) + 19.8 (0.8136737375071058-0.4288655114005536j) (0.5813218118144358+0.24323915864085138j) + 19.9 (0.8676441006416694+0.17871162129618953j) (0.49718579487120196-0.09195658451657955j) + dtype: complex128, shape: (200, 2) + """ + from scipy.signal import hilbert + + if isinstance(data, nap.Tsd): + return nap.Tsd( + d=hilbert(data.values), + t=data.times(), + time_support=data.time_support, + ) + elif isinstance(data, nap.TsdFrame): + return nap.TsdFrame( + d=hilbert(data.values, axis=0), + t=data.times(), + columns=data.columns, + time_support=data.time_support, + ) + else: + raise TypeError("data should be a Tsd or TsdFrame.") + + +def compute_hilbert_envelope(data): + """ + Compute the Hilbert envelope of a time-series. + + This function computes the envelope of the signal, which is the magnitude of the analytic + signal obtained by applying the Hilbert transform. The envelope provides a smooth + representation of the amplitude modulation of the signal. + + Parameters + ---------- + data : Tsd, TsdFrame + The time-series data to compute the Hilbert envelope for. + + Returns + ------- + Tsd, TsdFrame + The Hilbert envelope. + + Examples + -------- + >>> import numpy as np + >>> import pynapple as nap + >>> times = np.arange(0, 20, 0.1) + >>> data = nap.Tsd(d=np.sin(times), t=times) + >>> envelope = nap.compute_hilbert_envelope(data) + >>> envelope + Time (s) + ---------- -------- + 0.0 0.168638 + 0.1 0.435858 + 0.2 0.444004 + 0.3 0.632753 + 0.4 0.64262 + 0.5 0.7477 + 0.6 0.759316 + ... + 19.3 0.953938 + 19.4 0.985048 + 19.5 0.922744 + 19.6 0.958683 + 19.7 0.87651 + 19.8 0.919777 + 19.9 0.885858 + dtype: float64, shape: (200,) + + Can be used for multiple signals in a `TsdFrame`: + + >>> data = nap.TsdFrame(d=np.stack([np.sin(times), np.cos(times)], axis=1), t=times) + >>> envelopes = nap.compute_hilbert_envelope(data) + >>> envelopes + Time (s) 0 1 + ---------- -------- -------- + 0.0 0.168638 1.00596 + 0.1 0.435858 1.02622 + 0.2 0.444004 1.02272 + 0.3 0.632753 1.05857 + 0.4 0.64262 1.05261 + 0.5 0.7477 1.0822 + 0.6 0.759316 1.07408 + ... + 19.3 0.953938 0.905652 + 19.4 0.985048 0.871714 + 19.5 0.922744 0.818226 + 19.6 0.958683 0.775808 + 19.7 0.87651 0.689073 + 19.8 0.919777 0.630159 + 19.9 0.885858 0.505618 + dtype: float64, shape: (200, 2) + """ + analytic_signal = apply_hilbert_transform(data) + return np.abs(analytic_signal) + + +def compute_hilbert_phase(data): + """ + Compute the Hilbert phase of a time-series. + + This function computes the instantaneous phase of the signal using the Hilbert transform. + The phase is unwrapped to provide a continuous representation, and it is then wrapped to + ensure it stays within the range [0, 2π]. + + Parameters + ---------- + data : Tsd, TsdFrame + The time-series data to compute the Hilbert phase for. + + Returns + ------- + Tsd, TsdFrame + The instantaneous phase of the signal, with values wrapped between [0, 2π]. + + Examples + -------- + >>> import numpy as np + >>> import pynapple as nap + >>> times = np.arange(0, 20, 0.1) + >>> data = nap.Tsd(d=np.sin(times), t=times) + >>> phase = nap.compute_hilbert_envelope(data) + >>> phase + Time (s) + ---------- -------- + 0.0 0.168638 + 0.1 0.435858 + 0.2 0.444004 + 0.3 0.632753 + 0.4 0.64262 + 0.5 0.7477 + 0.6 0.759316 + ... + 19.3 0.953938 + 19.4 0.985048 + 19.5 0.922744 + 19.6 0.958683 + 19.7 0.87651 + 19.8 0.919777 + 19.9 0.885858 + dtype: float64, shape: (200,) + + Can be used for multiple signals in a `TsdFrame`: + + >>> data = nap.TsdFrame(d=np.stack([np.sin(times), np.cos(times)], axis=1), t=times) + >>> phases = nap.compute_hilbert_envelope(data) + >>> phases + Time (s) 0 1 + ---------- -------- -------- + 0.0 0.168638 1.00596 + 0.1 0.435858 1.02622 + 0.2 0.444004 1.02272 + 0.3 0.632753 1.05857 + 0.4 0.64262 1.05261 + 0.5 0.7477 1.0822 + 0.6 0.759316 1.07408 + ... + 19.3 0.953938 0.905652 + 19.4 0.985048 0.871714 + 19.5 0.922744 0.818226 + 19.6 0.958683 0.775808 + 19.7 0.87651 0.689073 + 19.8 0.919777 0.630159 + 19.9 0.885858 0.505618 + dtype: float64, shape: (200, 2) + """ + analytic_signal = apply_hilbert_transform(data) + phase = np.angle(analytic_signal) + phase = np.mod(np.unwrap(phase), 2 * np.pi) + return phase + + +def detect_oscillatory_events( + data, + epoch, + freq_band, + thresh_band, + duration_band, + min_inter_duration, + fs=None, + wsize=51, +): + """ + Simple helper for detecting oscillatory events (e.g. ripples, spindles) + + Parameters + ---------- + data : Tsd + 1-dimensional time series + epoch : IntervalSet + The epoch for restricting the detection + freq_band : tuple + The (low, high) frequency to bandpass the signal + thresh_band : tuple + The (min, max) value for thresholding the normalized squared signal after filtering + duration_band : tuple + The (min, max) duration of an event in second + min_inter_duration : float + The minimum duration between two events otherwise they are merged (in seconds) + fs : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + wsize : int, optional + The size of the window for digital filtering + + Returns + ------- + IntervalSet + The interval set of detected events with metadata containing + the power, amplitude, and peak_time + """ + import warnings + + from scipy.signal import filtfilt + + data = data.restrict(epoch) + + if fs is None: + fs = data.rate + + signal = nap.apply_bandpass_filter(data, freq_band, fs) + squared_signal = np.square(signal.values) + window = np.ones(wsize) / wsize + + nSS = filtfilt(window, 1, squared_signal) + nSS = (nSS - np.mean(nSS)) / np.std(nSS) + nSS = nap.Tsd(t=signal.index.values, d=nSS, time_support=epoch) + + # Detect oscillation periods by thresholding normalized signal + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Some epochs have no duration", + category=UserWarning, + ) + nSS2 = nSS.threshold(thresh_band[0], method="above") + nSS3 = nSS2.threshold(thresh_band[1], method="below") + + # Exclude oscillation where min_duration < length < max_duration + osc_ep = nSS3.time_support + osc_ep = osc_ep.drop_short_intervals(duration_band[0], time_units="s") + osc_ep = osc_ep.drop_long_intervals(duration_band[1], time_units="s") + + # Merge if inter-oscillation period is too short + osc_ep = osc_ep.merge_close_intervals(min_inter_duration, time_units="s") + + # Compute power, amplitude, and peak_time for each interval + powers = [] + amplitudes = [] + peak_times = [] + + for s, e in osc_ep.values: + seg = signal.get(s, e) + if len(seg) == 0: + powers.append(np.nan) + amplitudes.append(np.nan) + peak_times.append(np.nan) + continue + power = np.mean(np.square(seg)) + power_db = 10 * np.log10(power) + amplitude = np.max(np.abs(seg.values)) + peak_idx = np.argmax(np.abs(seg.values)) + peak_time = seg.index.values[peak_idx] + powers.append(power_db) + amplitudes.append(amplitude) + peak_times.append(peak_time) + + metadata = { + "power": powers, + "amplitude": amplitudes, + "peak_time": peak_times, + } + + return nap.IntervalSet(start=osc_ep.start, end=osc_ep.end, metadata=metadata) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 499bda87..81e9a4ae 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -523,55 +523,3 @@ def test_get_filter_frequency_response_error(): ValueError, match="Unrecognized filter mode. Choose either 'butter' or 'sinc'" ): nap.get_filter_frequency_response(250, 1000, "lowpass", "a", 4, 0.02) - - -@pytest.mark.parametrize( - "freq_band, thresh_band, num_events, start, end", - [ - ((10, 30), (1, 10), 1, 0, 2), - ((40, 60), (1, 10), 1, 3, 5), - ((100, 150), (1, 10), 0, None, None), - ], -) -def test_detect_oscillatory_events(freq_band, thresh_band, num_events, start, end): - fs = 1000 - duration = 5 - min_dur = 0.1 - max_dur = 2 - min_inter = 0.02 - - t = np.linspace(0, duration, int(fs * duration), endpoint=False) - signal = np.zeros_like(t) - - # 25 Hz oscillation from 0-2s - freq_1 = 25 - mask1 = (t >= 0) & (t < 2) - signal[mask1] = np.sin(2 * np.pi * freq_1 * t[mask1]) - - # 50 Hz oscillation from 3-5s - freq_2 = 50 - mask2 = (t >= 3) & (t < 5) - signal[mask2] = np.sin(2 * np.pi * freq_2 * t[mask2]) - - ts = nap.Tsd(t=t, d=signal) - epoch = nap.IntervalSet(start=0, end=duration) - osc_ep = nap.filtering.detect_oscillatory_events( - ts, epoch, freq_band, thresh_band, (min_dur, max_dur), min_inter - ) - - assert len(osc_ep) == num_events # Only one event in given freq_band - - if num_events > 0: - # Start and end should be close to actuals +/- a small amount - detected_start = osc_ep.start[0] - detected_end = osc_ep.end[0] - assert np.isclose(start, detected_start, atol=0.05) - assert np.isclose(end, detected_end, atol=0.05) - - # Check we store power, amplitude, and peak_time - for key in ["power", "amplitude", "peak_time"]: - assert key in osc_ep._metadata - - # Check peak_time is within the interval - peak_time = osc_ep._metadata["peak_time"][0] - assert start <= peak_time <= end diff --git a/tests/test_signal.py b/tests/test_signal.py new file mode 100644 index 00000000..2a22b9fa --- /dev/null +++ b/tests/test_signal.py @@ -0,0 +1,174 @@ +from contextlib import nullcontext as does_not_raise + +import numpy as np +import pytest +from scipy.signal import hilbert + +import pynapple as nap + + +@pytest.mark.parametrize( + "freq_band, thresh_band, num_events, start, end", + [ + ((10, 30), (1, 10), 1, 0, 2), + ((40, 60), (1, 10), 1, 3, 5), + ((100, 150), (1, 10), 0, None, None), + ], +) +def test_detect_oscillatory_events(freq_band, thresh_band, num_events, start, end): + fs = 1000 + duration = 5 + min_dur = 0.1 + max_dur = 2 + min_inter = 0.02 + + t = np.linspace(0, duration, int(fs * duration), endpoint=False) + signal = np.zeros_like(t) + + # 25 Hz oscillation from 0-2s + freq_1 = 25 + mask1 = (t >= 0) & (t < 2) + signal[mask1] = np.sin(2 * np.pi * freq_1 * t[mask1]) + + # 50 Hz oscillation from 3-5s + freq_2 = 50 + mask2 = (t >= 3) & (t < 5) + signal[mask2] = np.sin(2 * np.pi * freq_2 * t[mask2]) + + ts = nap.Tsd(t=t, d=signal) + epoch = nap.IntervalSet(start=0, end=duration) + osc_ep = nap.detect_oscillatory_events( + ts, epoch, freq_band, thresh_band, (min_dur, max_dur), min_inter + ) + + assert len(osc_ep) == num_events # Only one event in given freq_band + + if num_events > 0: + # Start and end should be close to actuals +/- a small amount + detected_start = osc_ep.start[0] + detected_end = osc_ep.end[0] + assert np.isclose(start, detected_start, atol=0.05) + assert np.isclose(end, detected_end, atol=0.05) + + # Check we store power, amplitude, and peak_time + for key in ["power", "amplitude", "peak_time"]: + assert key in osc_ep._metadata + + # Check peak_time is within the interval + peak_time = osc_ep._metadata["peak_time"][0] + assert start <= peak_time <= end + + +@pytest.mark.parametrize( + "func", + [ + nap.apply_hilbert_transform, + nap.compute_hilbert_envelope, + nap.compute_hilbert_phase, + ], +) +@pytest.mark.parametrize( + "input, expectation", + [ + (nap.Tsd(d=np.ones(3), t=[1, 2, 3]), does_not_raise()), + (nap.TsdFrame(d=np.ones((3, 3)), t=[1, 2, 3]), does_not_raise()), + ( + np.ones((3, 3)), + pytest.raises(TypeError, match="data should be a Tsd or TsdFrame."), + ), + ( + nap.TsdTensor(d=np.ones((3, 3, 3)), t=[1, 2, 3]), + pytest.raises(TypeError, match="data should be a Tsd or TsdFrame."), + ), + ( + [], + pytest.raises(TypeError, match="data should be a Tsd or TsdFrame."), + ), + ( + nap.IntervalSet(1, 2), + pytest.raises(TypeError, match="data should be a Tsd or TsdFrame."), + ), + ( + None, + pytest.raises(TypeError, match="data should be a Tsd or TsdFrame."), + ), + ], +) +def test_hilbert_type_errors(input, func, expectation): + with expectation: + func(input) + + +@pytest.mark.parametrize( + "data", + [ + nap.Tsd(t=np.linspace(0, 1, 500), d=np.sin(np.linspace(0, 1, 500))), + nap.TsdFrame( + t=np.linspace(0, 1, 500), + d=np.stack( + [np.sin(np.linspace(0, 1, 500)), np.cos(np.linspace(0, 1, 500))], + axis=1, + ), + ), + ], +) +def test_apply_hilbert_transform(data): + result = nap.apply_hilbert_transform(data) + expected = hilbert(data.values, axis=0) + + assert isinstance(result, type(data)) + np.testing.assert_array_equal(data.time_support, result.time_support) + np.testing.assert_array_equal(data.times(), result.times()) + + np.testing.assert_array_equal(result.values, expected) + + +@pytest.mark.parametrize( + "data", + [ + nap.Tsd(t=np.linspace(0, 1, 500), d=np.sin(np.linspace(0, 1, 500))), + nap.TsdFrame( + t=np.linspace(0, 1, 500), + d=np.stack( + [np.sin(np.linspace(0, 1, 500)), np.cos(np.linspace(0, 1, 500))], + axis=1, + ), + ), + ], +) +def test_compute_hilbert_phase(data): + result = nap.compute_hilbert_phase(data) + analytic_signal = hilbert(data.values, axis=0) + phase = np.angle(analytic_signal) + expected = np.mod(np.unwrap(phase), 2 * np.pi) + + assert isinstance(result, type(data)) + np.testing.assert_array_equal(data.time_support, result.time_support) + np.testing.assert_array_equal(data.times(), result.times()) + + np.testing.assert_array_equal(result.values, expected) + + +@pytest.mark.parametrize( + "data", + [ + nap.Tsd(t=np.linspace(0, 1, 500), d=np.sin(np.linspace(0, 1, 500))), + nap.TsdFrame( + t=np.linspace(0, 1, 500), + d=np.stack( + [np.sin(np.linspace(0, 1, 500)), np.cos(np.linspace(0, 1, 500))], + axis=1, + ), + ), + ], +) +def test_compute_hilbert_envelope(data): + result = nap.compute_hilbert_envelope(data) + analytic_signal = hilbert(data.values, axis=0) + expected = np.abs(analytic_signal) + + assert isinstance(result, type(data)) + np.testing.assert_array_equal(data.time_support, result.time_support) + np.testing.assert_array_equal(data.times(), result.times()) + + np.testing.assert_array_equal(result.values, expected)