diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 6fe1848ff..c2d6a6e6e 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -8,20 +8,25 @@ jobs: name: Build and test package runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - name: Checkout with full history and tags + uses: actions/checkout@v4 # this is necessary for setuptools_scm to work properly with github # actions, see https://github.com/pypa/setuptools_scm/issues/480 and # https://stackoverflow.com/a/68959339 with: fetch-depth: 0 + tags: true + - name: Set up Python uses: actions/setup-python@v5 with: python-version: 3.x + - name: Build package run: | pip install build python -m build --outdir dist/ --sdist --wheel + - name: Check there's only one sdist and one whl file created shell: bash # because the following two tests will be weird otherwise. see @@ -31,6 +36,7 @@ jobs: run: | [[ $(find dist/ -type f -name "*whl" -printf x | wc -c) == 1 ]] || exit 1 [[ $(find dist/ -type f -name "*tar.gz" -printf x | wc -c) == 1 ]] || exit 1 + - name: Check setuptools_scm version against git tag shell: bash run: | @@ -40,6 +46,7 @@ jobs: # ends in the most recent git tag, fail if it does not. TAG=$(git describe --tags) [[ "$(ls dist/*tar.gz)" =~ "-${TAG:1}.tar.gz" ]] + - name: Check we can install from wheel # note that this is how this works in bash (different shells might be # slightly different). we've checked there's only one .whl file in an @@ -50,10 +57,12 @@ jobs: shell: bash run: | pip install "$(ls dist/*whl)[dev]" + - name: Run some tests # modify the following as necessary to e.g., run notebooks run: | pytest tests/ + - uses: actions/upload-artifact@v4 with: path: dist/* diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index d2ea86667..c5ff7e986 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -22,6 +22,10 @@ jobs: # https://stackoverflow.com/a/68959339 with: fetch-depth: 0 + tags: true + # Ensure all tags are present + - run: git fetch --tags + - uses: actions/setup-python@v5 - name: Install dependencies run: | diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b88cff142..e21964c57 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -20,7 +20,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.11" - name: Install dependencies run: | echo "testing: ${{github.ref}}" @@ -33,8 +33,22 @@ jobs: flake8 pynapple --max-complexity 10 black --check tests isort --check tests --profile black + check-param-naming: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install pynapple + run: | + python -m pip install --upgrade pip + pip install . 
+ - name: Check parameter name consistency + run: python scripts/check_parameter_naming.py test: - needs: lint + needs: [lint, check-param-naming] runs-on: ${{ matrix.os }} strategy: matrix: @@ -90,7 +104,13 @@ jobs: with: directory: "doc/_build/html" # The directory to scan - arguments: --checks Links,Scripts --ignore-urls "https://fonts.gstatic.com,https://mkdocs-gallery.github.io,./doc/_build/html/_static/,https://www.nature.com/articles/s41593-022-01020-w" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/html\/_static\/.+/" + arguments: + --checks Links,Scripts + --ignore-urls "https://fonts.gstatic.com,https://mkdocs-gallery.github.io,./doc/_build/html/_static/,https://www.nature.com/articles/s41593-022-01020-w,https://elifesciences.org/reviewed-preprints/85786" + --assume-extension + --check-external-hash + --ignore-status-codes 403 + --ignore-files "/.+\/html\/_static\/.+/" # The arguments to pass to HTMLProofer check: @@ -99,6 +119,7 @@ jobs: - lint - test - documentation + - check-param-naming runs-on: ubuntu-latest steps: - name: Decide whether all tests and notebooks succeeded diff --git a/.gitignore b/.gitignore index 66781c0f5..081e5c05d 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,5 @@ your # Ignore npz files from testing: tests/*.npz .vscode/settings.json +doc/user_guide/MyProject/sub-A2929/A2929-200711/stimulus-fish.json +doc/user_guide/memmap.dat diff --git a/README.md b/README.md index d6486bc02..7a8a837f3 100644 --- a/README.md +++ b/README.md @@ -19,9 +19,25 @@ pynapple is a light-weight python library for neurophysiological data analysis. ------------------------------------------------------------------------ + +Learning pynapple +----------------- + +Workshops are regularly organized by the [center for Computational Neuroscience ](https://www.simonsfoundation.org/flatiron/center-for-computational-neuroscience/) of the Flatiron institute +to teach pynapple & [NeMos](https://nemos.readthedocs.io/en/latest/) to new users. + +**The next workshop will take place in New York City on February 2 - 5, 2026. Register [here](https://www.simonsfoundation.org/event/flatiron-ccn-neural-data-analysis-workshop/).** + + New release :fire: ------------------ +### pynapple >= 0.10.0 + +Tuning curves computation have been generalized to n-dimensions with the function `compute_tuning_curves`. +It can now return a [xarray DataArray](https://docs.xarray.dev/en/stable/) instead of a Pandas DataFrame. + + ### pynapple >= 0.8.2 The objects `IntervalSet`, `TsdFrame` and `TsGroup` inherits a new metadata class. It is now possible to add labels for @@ -38,27 +54,6 @@ nap.apply_bandpass_filter(signal, (10, 20), fs=1250) ``` New functions includes power spectral density and Morlet wavelet decomposition. See the [documentation](https://pynapple-org.github.io/pynapple/reference/process/) for more details. -### pynapple >= 0.6 - -Starting with 0.6, [`IntervalSet`](https://pynapple-org.github.io/pynapple/reference/core/interval_set/) objects are behaving as immutable numpy ndarray. 
Before 0.6, you could select an interval within an `IntervalSet` object with: - -```python -new_intervalset = intervalset.loc[[0]] # Selecting first interval -``` - -With pynapple>=0.6, the slicing is similar to numpy and it returns an `IntervalSet` - -```python -new_intervalset = intervalset[0] -``` - -### pynapple >= 0.4 - -Starting with 0.4, pynapple rely on the [numpy array container](https://numpy.org/doc/stable/user/basics.dispatch.html) approach instead of Pandas for the time series. Pynapple builtin functions will remain the same except for functions inherited from Pandas. - -This allows for a better handling of returned objects. - -Additionaly, it is now possible to define time series objects with more than 2 dimensions with `TsdTensor`. You can also look at this [notebook](https://pynapple-org.github.io/pynapple/generated/gallery/tutorial_pynapple_numpy/) for a demonstration of numpy compatibilities. Community --------- @@ -73,13 +68,11 @@ Getting Started The best way to install pynapple is with pip inside a new [conda](https://docs.conda.io/en/latest/) environment: ``` {.sourceCode .shell} -$ conda create --name pynapple pip python=3.8 +$ conda create --name pynapple pip python=3.11 $ conda activate pynapple $ pip install pynapple ``` -> **Note** -> The package uses a pyproject.toml file for installation and dependencies management. Running `pip install pynapple` will install all the dependencies, including: @@ -90,13 +83,14 @@ Running `pip install pynapple` will install all the dependencies, including: - pynwb 2.0 - tabulate - h5py +- xarray For development, see the [contributor guide](CONTRIBUTING.md) for steps to install from source code. +The decoded HD (dashed grey line) closely matches the actual HD (solid white line), and thus the population activity in ADn is a reliable estimate of the heading direction of the animal. I hope this tutorial was helpful. If you have any questions, comments or suggestions, please feel free to reach out to the Pynapple Team! @@ -292,8 +300,11 @@ I hope this tutorial was helpful. If you have any questions, comments or suggest :::{card} Authors ^^^ +Wolf de Wulf + Dhruv Mehrotra Guillaume Viejo -::: \ No newline at end of file +::: + diff --git a/doc/examples/tutorial_calcium_imaging.md b/doc/examples/tutorial_calcium_imaging.md index bb594bd8c..396a8bcf1 100644 --- a/doc/examples/tutorial_calcium_imaging.md +++ b/doc/examples/tutorial_calcium_imaging.md @@ -11,44 +11,37 @@ kernelspec: name: python3 --- - Calcium Imaging ============ Working with calcium data. -For the example dataset, we will be working with a recording of a freely-moving mouse imaged with a Miniscope (1-photon imaging). The area recorded for this experiment is the postsubiculum - a region that is known to contain head-direction cells, or cells that fire when the animal's head is pointing in a specific direction. +As example dataset, we will be working with a recording of a freely-moving mouse imaged with a Miniscope (1-photon imaging). +The area recorded for this experiment is the postsubiculum - a region that is known to contain head-direction cells, or cells that fire when the animal's head is pointing in a specific direction. The NWB file for the example is hosted on [OSF](https://osf.io/sbnaw). We show below how to stream it. 
- ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- -import numpy as pd +:tags: [hide-output] +import numpy as np import pynapple as nap import matplotlib.pyplot as plt import seaborn as sns -import sys, os -import requests, math +import os +import requests +import xarray as xr custom_params = {"axes.spines.right": False, "axes.spines.top": False} sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params) +xr.set_options(display_expand_attrs=False) ``` *** Downloading the data ------------------ -First things first: Let's find our file - +First things first: let's find our file. ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- path = "A0670-221213.nwb" if path not in os.listdir("."): r = requests.get(f"https://osf.io/sbnaw/download", stream=True) @@ -61,42 +54,26 @@ if path not in os.listdir("."): *** Parsing the data ------------------ -Now that we have the file, let's load the data - +Now that we have the file, let's load the data: ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- -data = nap.load_file(path) -print(data) +data = nap.load_file(path, lazy_loading=False) +data ``` -Let's save the RoiResponseSeries as a variable called 'transients' and print it - +Let's save the RoiResponseSeries as a variable called 'transients' and print it: ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- transients = data['RoiResponseSeries'] -print(transients) +transients ``` *** Plotting the activity of one neuron ----------------------------------- -Our transients are saved as a (35757, 65) TsdFrame. Looking at the printed object, you can see that we have 35757 data points for each of our 65 regions of interest. We want to see which of these are head-direction cells, so we need to plot a tuning curve of fluorescence vs head-direction of the animal. - - +Our transients are saved as a (35757, 65) TsdFrame. Looking at the printed object, you can see that we have 35757 data points for each of our 65 regions of interest (ROIs). We want to see which of these are head-direction cells, so we need to plot a tuning curve of fluorescence vs head-direction of the animal. ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- plt.figure(figsize=(6, 2)) plt.plot(transients[0:2000,0], linewidth=5) plt.xlabel("Time (s)") @@ -104,58 +81,47 @@ plt.ylabel("Fluorescence") plt.show() ``` -Here we extract the head-direction as a variable called angle - +Here, we extract the head-direction as a variable called angle. ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- angle = data['ry'] -print(angle) +angle ``` As you can see, we have a longer recording for our tracking of the animal's head than we do for our calcium imaging - something to keep in mind. - ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- -print(transients.time_support) -print(angle.time_support) +transients.time_support ``` *** Calcium tuning curves --------------------- -Here we compute the tuning curves of all the neurons - +Here, we compute the tuning curves of all the ROIs. ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- -tcurves = nap.compute_1d_tuning_curves_continuous(transients, angle, nb_bins = 120) - -print(tcurves) +tuning_curves = nap.compute_tuning_curves(transients, angle, bins=120) +tuning_curves ``` -We now have a DataFrame, where our index is the angle of the animal's head in radians, and each column represents the tuning curve of each region of interest. 
We can plot one neuron. +This yields an `xarray.DataFrame`, which we can beautify by setting feature names and units: + +```{code-cell} ipython3 +def set_metadata(tuning_curves): + _tuning_curves=tuning_curves.rename({"0": "Angle", "unit": "ROI"}) + _tuning_curves.name="Fluorescence" + _tuning_curves.attrs["units"]="a.u." + _tuning_curves.coords["Angle"].attrs["units"]="rad" + return _tuning_curves + +annotated_tuning_curves = set_metadata(tuning_curves) +annotated_tuning_curves +``` +Having set some metadata, we can easily plot one ROI: ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- -plt.figure() -plt.plot(tcurves[4]) -plt.xlabel("Angle") -plt.ylabel("Fluorescence") +annotated_tuning_curves[4].plot() plt.show() ``` @@ -163,47 +129,190 @@ It looks like this could be a head-direction cell. One important property of hea We start by finding the midpoint of the recording, using the function [`get_intervals_center`](pynapple.IntervalSet.get_intervals_center). Using this, then create one new IntervalSet with two rows, one for each half of the recording. - ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- center = transients.time_support.get_intervals_center() halves = nap.IntervalSet( - start = [transients.time_support.start[0], center.t[0]], + start = [transients.time_support.start[0], center.t[0]], end = [center.t[0], transients.time_support.end[0]] +) +``` + +Now, we can compute the tuning curves for each half of the recording and plot the tuning curves again. + +```{code-cell} ipython3 +tuning_curves_half1 = nap.compute_tuning_curves(transients, angle, bins = 120, epochs = halves.loc[[0]]) +tuning_curves_half2 = nap.compute_tuning_curves(transients, angle, bins = 120, epochs = halves.loc[[1]]) + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) +set_metadata(tuning_curves_half1[4]).plot(ax=ax1) +ax1.set_title("First half") +set_metadata(tuning_curves_half2[4]).plot(ax=ax2) +ax2.set_title("Second half") +plt.show() +``` + +*** +Calcium decoding +--------------------- + +Given some tuning curves, we can also try to decode head direction from the population activity. +For calcium imaging data, Pynapple has `decode_template`, which implements a template matching algorithm. 
+ +```{code-cell} ipython3 +epochs = nap.IntervalSet([50, 150]) +transients = transients.bin_average(0.1) +decoded, dist = nap.decode_template( + tuning_curves=tuning_curves, + data=transients, + epochs=epochs, + bin_size=0.1, + metric="correlation", +) +``` + +```{code-cell} ipython3 +:tags: [hide-input] +# normalize distance for better visualization +dist_norm = (dist - np.min(dist.values, axis=1, keepdims=True)) / np.ptp( + dist.values, axis=1, keepdims=True +) + +fig, (ax1, ax2, ax3) = plt.subplots(figsize=(8, 8), nrows=3, ncols=1, sharex=True) +ax1.plot(angle.restrict(epochs), label="True") +ax1.scatter(decoded.times(), decoded.values, label="Decoded", c="orange") +ax1.legend(frameon=False, bbox_to_anchor=(1.0, 1.0)) +ax1.set_ylabel("Angle [rad]") + +im = ax2.imshow( + dist.values.T, + aspect="auto", + origin="lower", + cmap="inferno_r", + extent=(epochs.start[0], epochs.end[0], 0.0, 2*np.pi) +) +ax2.set_ylabel("Angle [rad]") +cbar_ax2 = fig.add_axes([0.95, ax2.get_position().y0, 0.015, ax2.get_position().height]) +fig.colorbar(im, cax=cbar_ax2, label="Distance") + +im = ax3.imshow( + dist_norm.values.T, + aspect="auto", + origin="lower", + cmap="inferno_r", + extent=(epochs.start[0], epochs.end[0], 0.0, 2*np.pi) +) +cbar_ax3 = fig.add_axes([0.95, ax3.get_position().y0, 0.015, ax3.get_position().height]) +fig.colorbar(im, cax=cbar_ax3, label="Norm. distance") +ax3.set_xlabel("Time (s)") +ax3.set_ylabel("Angle [rad]") +plt.show() +``` + +The distance metric you choose can influence how well we decode. +Internally, ``decode_template`` uses `scipy.spatial.distance.cdist` to compute the distance matrix; +you can take a look at [its documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html) +to see which metrics are supported. Here are a couple examples: + +```{code-cell} ipython3 +:tags: [hide-input] +metrics = [ + "chebyshev", + "dice", + "canberra", + "sqeuclidean", + "minkowski", + "euclidean", + "cityblock", + "mahalanobis", + "correlation", + "cosine", + "seuclidean", + "braycurtis", + "jensenshannon", +] + +fig, axs = plt.subplots(5, 1, figsize=(8,12), sharex=True, sharey=True) +for metric, ax in zip(metrics[-5:], axs.flatten()): + decoded, dist = nap.decode_template( + tuning_curves=tuning_curves, + data=transients, + bin_size=0.1, + metric=metric, + epochs=epochs, ) + # normalize distance for better visualization + dist_norm = (dist - np.min(dist.values, axis=1, keepdims=True)) / np.ptp( + dist.values, axis=1, keepdims=True + ) + ax.plot(angle.restrict(epochs), label="True") + im = ax.imshow( + dist_norm.values.T, + aspect="auto", + origin="lower", + cmap="inferno_r", + extent=(epochs.start[0], epochs.end[0], 0.0, 2*np.pi) + ) + if metric != metrics[-1]: + ax.spines['bottom'].set_visible(False) + ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False) + ax.set_yticks([]) + ax.spines['left'].set_visible(False) + ax.set_ylabel(metric) +cbar_ax = fig.add_axes([0.92, ax.get_position().y0, 0.015, ax.get_position().height]) +cbar=fig.colorbar(im, cax=cbar_ax) +cbar.set_label("Norm. distance") +ax.set_xlabel("Time (s)") +plt.show() ``` -Now we can compute the tuning curves for each half of the recording and plot the tuning curves for the fifth region of interest. +We recommend trying a bunch to see which one works best for you. +In the case of head direction, we can quantify how well we decode using the absolute angular error. 
+To get a fair estimate of error, we will compute the tuning curves on the first half of the data +and compute the error for predictions of the second half. +```{code-cell} ipython3 +def absolute_angular_error(x, y): + return np.abs(np.angle(np.exp(1j * (x - y)))) + +# Compute errors +errors = {} +for metric in metrics: + decoded, dist = nap.decode_template( + tuning_curves=tuning_curves_half1, + data=transients, + bin_size=0.1, + metric=metric, + epochs=halves.loc[[1]], + ) + errors[metric] = absolute_angular_error( + angle.interpolate(decoded).values, decoded.values + ) +``` ```{code-cell} ipython3 ---- -jupyter: - outputs_hidden: false ---- -half1 = nap.compute_1d_tuning_curves_continuous(transients, angle, nb_bins = 120, ep = halves.loc[[0]]) -half2 = nap.compute_1d_tuning_curves_continuous(transients, angle, nb_bins = 120, ep = halves.loc[[1]]) - -plt.figure(figsize=(12, 5)) -plt.subplot(1,2,1) -plt.plot(half1[4]) -plt.title("First half") -plt.xlabel("Angle") -plt.ylabel("Fluorescence") -plt.subplot(1,2,2) -plt.plot(half2[4]) -plt.title("Second half") -plt.xlabel("Angle") +:tags: [hide-input] +sorted_items = sorted(errors.items(), key=lambda item: np.median(item[1])) +sorted_labels, sorted_values = zip(*sorted_items) + +fig, ax = plt.subplots(figsize=(8, 8)) +bp = ax.boxplot( + x=sorted_values, + tick_labels=sorted_labels, + vert=False, + showfliers=False +) +ax.set_xlabel("Angular error [rad]") plt.show() ``` +In this case, `jensenshannon` yields the lowest angular error. + :::{card} Authors ^^^ Sofia Skromne Carrasco +Wolf De Wulf + ::: diff --git a/doc/examples/tutorial_phase_preferences.md b/doc/examples/tutorial_phase_preferences.md index 35e9e0599..40f3c98e0 100644 --- a/doc/examples/tutorial_phase_preferences.md +++ b/doc/examples/tutorial_phase_preferences.md @@ -11,7 +11,6 @@ kernelspec: name: python3 --- - Spikes-phase coupling ===================== @@ -20,7 +19,6 @@ with spiking data, to find phase preferences of spiking units. Specifically, we will examine LFP and spiking data from a period of REM sleep, after traversal of a linear track. - ```{code-cell} ipython3 import math import os @@ -41,8 +39,7 @@ sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_par *** Downloading the data ------------------ -Let's download the data and save it locally - +Let's download the data and save it locally. ```{code-cell} ipython3 path = "Achilles_10252013_EEG.nwb" @@ -59,11 +56,10 @@ Loading the data ------------------ Let's load and print the full dataset. - ```{code-cell} ipython3 data = nap.load_file(path) FS = 1250 # We know from the methods of the paper -print(data) +data ``` *** @@ -71,7 +67,6 @@ Selecting slices ----------------------------------- For later visualization, we define an interval of 3 seconds of data during REM sleep. - ```{code-cell} ipython3 ep_ex_rem = nap.IntervalSet( data["rem"]["start"][0] + 97.0, @@ -81,7 +76,6 @@ ep_ex_rem = nap.IntervalSet( Here we restrict the lfp to the REM epochs. - ```{code-cell} ipython3 tsd_rem = data["eeg"][:,0].restrict(data["rem"]) @@ -95,13 +89,13 @@ Plotting the LFP Activity ----------------------------------- We should first plot our REM Local Field Potential data. - ```{code-cell} ipython3 fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3)) ax.plot(tsd_rem.restrict(ep_ex_rem)) ax.set_title("REM Local Field Potential") ax.set_ylabel("LFP (a.u.)") ax.set_xlabel("time (s)") +plt.show() ``` *** @@ -114,14 +108,12 @@ frequencies present in the data. 
We must define the frequency set that we'd like to use for our decomposition. - ```{code-cell} ipython3 freqs = np.geomspace(5, 200, 25) ``` We compute the wavelet transform on our LFP data (only during the example interval). - ```{code-cell} ipython3 cwt_rem = nap.compute_wavelet_transform(tsd_rem.restrict(ep_ex_rem), fs=FS, freqs=freqs) ``` @@ -129,7 +121,6 @@ cwt_rem = nap.compute_wavelet_transform(tsd_rem.restrict(ep_ex_rem), fs=FS, freq *** Now let's plot the calculated wavelet scalogram. - ```{code-cell} ipython3 # Define wavelet decomposition plotting function def plot_timefrequency(freqs, powers, ax=None): @@ -155,6 +146,7 @@ ax1.plot(tsd_rem.restrict(ep_ex_rem)) ax1.set_ylabel("LFP (a.u.)") ax1.set_xlabel("Time (s)") ax1.margins(0) +plt.show() ``` *** @@ -163,14 +155,12 @@ Filtering Theta As expected, there is a strong 8Hz component during REM sleep. We can filter it using the function [`nap.apply_bandpass_filter`](pynapple.process.filtering.apply_bandpass_filter). - ```{code-cell} ipython3 theta_band = nap.apply_bandpass_filter(tsd_rem, cutoff=(6.0, 10.0), fs=FS) ``` We can plot the original signal and the filtered signal. - ```{code-cell} ipython3 plt.figure(constrained_layout=True, figsize=(12, 3)) plt.plot(tsd_rem.restrict(ep_ex_rem), alpha=0.5) @@ -185,7 +175,6 @@ Computing phase From the filtered signal, it is easy to get the phase using the Hilbert transform. Here we use scipy Hilbert method. - ```{code-cell} ipython3 from scipy import signal @@ -194,7 +183,6 @@ theta_phase = nap.Tsd(t=theta_band.t, d=np.angle(signal.hilbert(theta_band))) Let's plot the phase. - ```{code-cell} ipython3 plt.figure(constrained_layout=True, figsize=(12, 3)) plt.subplot(211) @@ -211,34 +199,33 @@ plt.show() Finding Phase of Spikes ----------------------- Now that we have the phase of our theta wavelet, and our spike times, we can find the phase firing preferences -of each of the units using the [`compute_1d_tuning_curves`](pynapple.process.tuning_curves.compute_1d_tuning_curves) function. +of each of the units using the [`compute_tuning_curves`](pynapple.process.tuning_curves.compute_tuning_curves) function. We will start by throwing away cells which do not have a high enough firing rate during our interval. - ```{code-cell} ipython3 spikes = spikes[spikes.rate > 5.0] ``` The feature is the theta phase during REM sleep. - ```{code-cell} ipython3 -phase_modulation = nap.compute_1d_tuning_curves( - group=spikes, feature=theta_phase, nb_bins=61, minmax=(-np.pi, np.pi) +phase_modulation = nap.compute_tuning_curves( + data=spikes, + features=theta_phase, + bins=61, + range=(-np.pi, np.pi), + feature_names=["Phase"] ) ``` Let's plot the first 3 neurons. - ```{code-cell} ipython3 -plt.figure(constrained_layout=True, figsize = (12, 3)) -for i in range(3): - plt.subplot(1,3,i+1) - plt.plot(phase_modulation.iloc[:,i]) - plt.xlabel("Phase (rad)") - plt.ylabel("Firing rate (Hz)") +phase_modulation.name="Firing Rate" +phase_modulation.attrs["units"]="Hz" +phase_modulation.coords["Phase"].attrs["unit"]="rad" +phase_modulation[:3].plot(row="unit", col_wrap=3, sharey=False) plt.show() ``` @@ -246,14 +233,12 @@ There is clearly a strong modulation for the third neuron. Finally, we can use the function [`value_from`](pynapple.Ts.value_from) to align each spikes to the corresponding phase position and overlay it with the LFP. - ```{code-cell} ipython3 spike_phase = spikes[spikes.index[3]].value_from(theta_phase) ``` Let's plot it. 
- ```{code-cell} ipython3 plt.figure(constrained_layout=True, figsize=(12, 3)) plt.subplot(211) @@ -274,4 +259,4 @@ Kipp Freud (https://kippfreud.com/) Guillaume Viejo -::: \ No newline at end of file +::: diff --git a/doc/examples/tutorial_pynapple_dandi.md b/doc/examples/tutorial_pynapple_dandi.md index 2efd85aeb..a0061a6ca 100644 --- a/doc/examples/tutorial_pynapple_dandi.md +++ b/doc/examples/tutorial_pynapple_dandi.md @@ -11,7 +11,6 @@ kernelspec: name: python3 --- - Streaming data from DANDI ========================= @@ -30,7 +29,6 @@ DANDI ----- DANDI allows you to stream data without downloading all the files. In this case the data extracted from the NWB file are stored in the nwb-cache folder. - ```{code-cell} ipython3 from pynwb import NWBHDF5IO @@ -63,15 +61,13 @@ fs = CachingFileSystem( # next, open the file file = h5py.File(fs.open(s3_url, "rb")) io = NWBHDF5IO(file=file, load_namespaces=True) - -print(io) +io ``` Pynapple -------- If opening the NWB works, you can start streaming data straight into pynapple with the `NWBFile` class. - ```{code-cell} ipython3 import pynapple as nap import matplotlib.pyplot as plt @@ -82,49 +78,39 @@ custom_params = {"axes.spines.right": False, "axes.spines.top": False} sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params) nwb = nap.NWBFile(io.read()) - -print(nwb) +nwb ``` We can load the spikes as a TsGroup for inspection. - ```{code-cell} ipython3 units = nwb["units"] - -print(units) +units ``` -As well as the position - +As well as the position: ```{code-cell} ipython3 position = nwb["SpatialSeriesLED1"] ``` -Here we compute the 2d tuning curves - +Here, we compute the 2d tuning curves: ```{code-cell} ipython3 -tc, binsxy = nap.compute_2d_tuning_curves(units, position, 20) +tuning_curves = nap.compute_tuning_curves(units, position, 20) ``` -Let's plot the tuning curves - +Let's plot the tuning curves: ```{code-cell} ipython3 -plt.figure(figsize=(15, 7)) -for i in tc.keys(): - plt.subplot(2, 4, i + 1) - plt.imshow(tc[i], origin="lower", aspect="auto") - plt.title("Unit {}".format(i)) -plt.tight_layout() +tuning_curves.name="Firing Rate" +tuning_curves.attrs["units"] = "Hz" +tuning_curves.plot(row="unit", col_wrap=4, figsize=(15, 7)) plt.show() ``` -Let's plot the spikes of unit 1 who has a nice grid -Here I use the function [`value_from`](pynapple.Ts.value_from) to assign to each spike the closest position in time. - +Let's plot the spikes of unit 1, which has a nice grid. +Here, I use the [`value_from`](pynapple.Ts.value_from) function to assign to each spike the closest position in time. ```{code-cell} ipython3 plt.figure(figsize=(15, 6)) @@ -135,7 +121,7 @@ extent = ( np.min(position["y"]), np.max(position["y"]), ) -plt.imshow(tc[1], origin="lower", extent=extent, aspect="auto") +plt.imshow(tuning_curves[1], origin="lower", extent=extent, aspect="auto") plt.xlabel("x") plt.ylabel("y") diff --git a/doc/releases.md b/doc/releases.md index 1458675f7..4a07b9bb8 100644 --- a/doc/releases.md +++ b/doc/releases.md @@ -1,5 +1,27 @@ # Releases + +### 0.10.1 (2025-10-30) + +- Fixing smoothing for `nap.decode_bayes`. +- Fixing `np.einsum`. + +### 0.10.0 (2025-10-27) + +- Generalizing `nap.compute_tuning_curves`. It can take any time series object (Tsd, TsdFrame, TsGroup, TsdTensor) as input and + work for any dimension of data. 
+- `nap.compute_1d_tuning_curve`, `nap.compute_2d_tuning_curve`, `nap.compute_1d_tuning_curve_continuous`, `nap.compute_2d_tuning_curve_continuous` + are being deprecated in favor of the general `nap.compute_tuning_curves`. +- Generalization of `nap.decode_1d` and `nap.decode_2d` to `nap.decode_bayes` for bayesian decoding of any dimension of data. +- New function `nap.decode_template` for template matching decoding of any dimension of data. +- Metadata can be restricted with `restrict_info`. +- New function `detect_oscillatory_events` to detect oscillatory events in a Tsd object. +- Fix TsdFrame `__repr__` for boolean data type. +- Refactoring of `nap.compute_mutual_information` to take as input xarray tuning curves object. +- `in_interval` method for IntervalSet to check if time points are within intervals. +- Refactoring `nap.compute_discrete_tuning_curves` to `compute_response_per_epoch`. +- Tuning curves function can return spike counts and occupancy separately. + ### 0.9.2 (2025-06-16) - Implement `time_diff` method for time series objects diff --git a/doc/user_guide/01_introduction_to_pynapple.md b/doc/user_guide/01_introduction_to_pynapple.md index 7140859ea..d3e9bdc8f 100644 --- a/doc/user_guide/01_introduction_to_pynapple.md +++ b/doc/user_guide/01_introduction_to_pynapple.md @@ -427,10 +427,6 @@ Overview of advanced analysis The `process` module of pynapple contains submodules that group methods that can be applied for high level analysis. All of the method are directly available from the `nap` namespace. -:::{important} -Some functions have been doubled given the nature of the data. For instance, computing a 1d tuning curves from spiking activity requires the [`nap.compute_1d_tuning_curves`](pynapple.process.tuning_curves.compute_1d_tuning_curves). The same function for calcium imaging data which is a continuous time series is available with [`nap.compute_1d_tuning_curves_continuous`](pynapple.process.tuning_curves.compute_1d_tuning_curves_continuous). -::: - **[Discrete correlograms & ISI](05_correlograms_isi)** This module analyses discrete events, specifically correlograms (for example by computing the cross-correlograms of a population of neurons) and interspike interval (ISI) distributions. diff --git a/doc/user_guide/02_input_output.md b/doc/user_guide/02_input_output.md index 695d37a11..9588ff9bb 100644 --- a/doc/user_guide/02_input_output.md +++ b/doc/user_guide/02_input_output.md @@ -113,7 +113,7 @@ You can still apply any high level function of pynapple. For example here, we co ```{code-cell} ipython3 -tc = nap.compute_1d_tuning_curves(data['units'], data['y'], 10) +tc = nap.compute_tuning_curves(data['units'], data['y'], 10) ``` @@ -212,7 +212,7 @@ It is possible to use Higher level library like [zarr](https://zarr.readthedocs. 
```{code-cell} ipython3 import zarr zarr_array = zarr.zeros((10000, 5), chunks=(1000, 5), dtype='i4') -timestep = np.arange(len(zarr_array)) +timestep = np.arange(zarr_array.shape[0]) tsdframe = nap.TsdFrame(t=timestep, d=zarr_array) ``` diff --git a/doc/user_guide/03_core_methods.md b/doc/user_guide/03_core_methods.md index da93f9352..750a97809 100644 --- a/doc/user_guide/03_core_methods.md +++ b/doc/user_guide/03_core_methods.md @@ -63,6 +63,30 @@ print(epochs) print(tsdframe.restrict(epochs).time_support) ``` +### `in_interval` + +[`in_interval`](pynapple.Tsd.in_interval) is similar to [`restrict`](pynapple.Tsd.restrict), except instead of returning the restricted time series, it returns a `Tsd` of booleans for each time point indicating whether or not it falls within the intervals of an `IntervalSet`. + +```{code-cell} ipython3 +tsdframe.in_interval(epochs) +``` +```{code-cell} ipython3 +:tags: [hide-input] +plt.figure() +plt.subplot(2,1,1) +plt.plot(tsdframe) +[plt.axvspan(s, e, alpha=0.2) for s, e in epochs.values] +plt.xlim(0, 100) +plt.subplot(2,1,2) +plt.plot(tsdframe.in_interval(epochs)) +plt.xlabel("Time (s)") +plt.title("tsdframe.in_interval(epochs)") +plt.xlim(0, 100) +plt.tight_layout() +plt.show() +``` + + ### `count` [`count`](pynapple.Tsd.count) returns the number of timestamps within bins or epochs of an `IntervalSet` object. diff --git a/doc/user_guide/03_metadata.md b/doc/user_guide/03_metadata.md index 6c1116ce9..2ac1a9102 100644 --- a/doc/user_guide/03_metadata.md +++ b/doc/user_guide/03_metadata.md @@ -11,6 +11,7 @@ kernelspec: name: python3 --- + # Metadata Metadata can be added to `TsGroup`, `IntervalSet`, and `TsdFrame` objects at initialization or after an object has been created. - `TsGroup` metadata is information associated with each Ts/Tsd object, such as brain region or unit type. @@ -26,7 +27,8 @@ At initialization, metadata can be passed via a dictionary or pandas DataFrame u - The `rate` attribute for `TsGroup` is stored with the metadata and cannot be overwritten. ``` -The length of the metadata must match the length of the object it describes (see class examples below for more detail). +The length of the metadata must match the length of the object it describes (see class examples below for more detail). + ```{code-cell} ipython3 :tags: [hide-cell] @@ -54,6 +56,7 @@ columns = ["a", "b", "c"] ### `TsGroup` Metadata added to `TsGroup` must match the number of `Ts`/`Tsd` objects, or the length of its `index` property. + ```{code-cell} ipython3 metadata = {"region": ["pfc", "pfc", "hpc", "hpc"]} @@ -62,6 +65,7 @@ print(tsgroup) ``` When initializing with a DataFrame, the index must align with the input dictionary keys (only when a dictionary is used to create the `TsGroup`). + ```{code-cell} ipython3 metadata = pd.DataFrame( index=group.keys(), @@ -73,9 +77,8 @@ tsgroup = nap.TsGroup(group, metadata=metadata) print(tsgroup) ``` - ### `IntervalSet` -Metadata added to `IntervalSet` must match the number of intervals, or the length of its `index` property. +Metadata added to `IntervalSet` must match the number of intervals, or the length of its `index` property. ```{code-cell} ipython3 metadata = { @@ -87,6 +90,7 @@ print(intervalset) ``` Metadata can be initialized as a DataFrame using the metadata argument, or it can be inferred when initializing an `IntervalSet` with a DataFrame. 
+ ```{code-cell} ipython3 df = pd.DataFrame( data=[[0, 30, 1, "left"], [35, 65, 0, "right"], [70, 100, 1, "left"]], @@ -98,7 +102,8 @@ print(intervalset) ``` ### `TsdFrame` -Metadata added to `TsdFrame` must match the number of data columns, or the length of its `columns` property. +Metadata added to `TsdFrame` must match the number of data columns, or the length of its `columns` property. + ```{code-cell} ipython3 metadata = { "color": ["red", "blue", "green"], @@ -111,6 +116,7 @@ print(tsdframe) ``` When initializing with a DataFrame, the DataFrame index must match the `TsdFrame` columns. + ```{code-cell} ipython3 metadata = pd.DataFrame( index=["a", "b", "c"], @@ -131,6 +137,7 @@ The remaining metadata examples will be shown on a `TsGroup` object; however, al ### `set_info` Metadata can be passed as a dictionary or pandas DataFrame as the first positional argument, or metadata can be passed as name-value keyword arguments. + ```{code-cell} ipython3 tsgroup.set_info(unit_type=["multi", "single", "single", "single"]) print(tsgroup) @@ -138,6 +145,7 @@ print(tsgroup) ### Using dictionary-like keys (square brackets) Most metadata names can set as a dictionary-like key (i.e. using square brackets). The only exceptions are for `IntervalSet`, where the names "start" and "end" are reserved for class properties. + ```{code-cell} ipython3 tsgroup["depth"] = [0, 1, 2, 3] print(tsgroup) @@ -145,6 +153,7 @@ print(tsgroup) ### Using attribute assignment If the metadata name is unique from other class attributes and methods, and it is formatted properly (i.e. only alpha-numeric characters and underscores), it can be set as an attribute (i.e. using a `.` followed by the metadata name). + ```{code-cell} ipython3 tsgroup.label=["MUA", "good", "good", "good"] print(tsgroup) @@ -152,6 +161,7 @@ print(tsgroup) ## Allowed data types As long as the length of the metadata container matches the length of the object (number of columns for `TsdFrame` and number of indices for `IntervalSet` and `TsGroup`), elements of the metadata can be any data type. + ```{code-cell} ipython3 tsgroup.coords = [[1,0],[0,1],[1,1],[2,1]] print(tsgroup) @@ -159,16 +169,19 @@ print(tsgroup) ## Accessing metadata Metadata is stored as a pandas DataFrame, which can be previewed using the `metadata` attribute. + ```{code-cell} ipython3 print(tsgroup.metadata) ``` Single metadata columns (or lists of columns) can be retrieved using the [`get_info()`](pynapple.TsGroup.get_info) class method: + ```{code-cell} ipython3 print(tsgroup.get_info("region")) ``` Similarly, metadata can be accessed using key indexing (i.e. square brakets) + ```{code-cell} ipython3 print(tsgroup["region"]) ``` @@ -178,12 +191,14 @@ Metadata names must be strings. Key indexing with an integer will produce differ ``` Finally, metadata that can be set as an attribute can also be accessed as an attribute. + ```{code-cell} ipython3 print(tsgroup.region) ``` ## Overwriting metadata User-set metadata is mutable and can be overwritten. + ```{code-cell} ipython3 print(tsgroup, "\n") tsgroup.set_info(label=["A", "B", "C", "D"]) @@ -192,75 +207,97 @@ print(tsgroup) ## Dropping metadata To drop metadata, use the [`drop_info()`](pynapple.TsGroup.drop_info) method. Multiple metadata columns can be dropped by passing a list of metadata names. + ```{code-cell} ipython3 print(tsgroup, "\n") tsgroup.drop_info("coords") print(tsgroup) ``` +## Restricting metadata +Instead of dropping multiple metadata fields, you may want to restrict to a set of specified fields, i.e. 
select which columns to keep. For this operation, use the [`restrict_info()`](pynapple.TsGroup.restrict_info) method. Multiple metadata columns can be kept by passing a list of metadata names. +```{code-cell} ipython3 +import copy +tsgroup2 = copy.deepcopy(tsgroup) +tsgroup2.restrict_info("region") +print(tsgroup2) +``` +```{admonition} Note +The `rate` column will always be kept for a `TsGroup`. +``` + ## Using metadata to slice objects Metadata can be used to slice or filter objects based on metadata values. + ```{code-cell} ipython3 print(tsgroup[tsgroup.label == "A"]) ``` ## `groupby`: Using metadata to group objects Similar to pandas, metadata can be used to group objects based on one or more metadata columns using the object method [`groupby`](pynapple.TsGroup.groupby), where the first argument is the metadata columns name(s) to group by. This function returns a dictionary with keys corresponding to unique groups and values corresponding to object indices belonging to each group. + ```{code-cell} ipython3 print(tsgroup,"\n") print(tsgroup.groupby("region")) ``` Grouping by multiple metadata columns should be passed as a list. + ```{code-cell} ipython3 tsgroup.groupby(["region","unit_type"]) ``` The optional argument `get_group` can be provided to return a new object corresponding to a specific group. + ```{code-cell} ipython3 tsgroup.groupby("region", get_group="hpc") ``` ## `groupby_apply`: Applying functions to object groups The `groupby_apply` object method allows a specific function to be applied to object groups. The first argument, same as `groupby`, is the metadata column(s) used to group the object. The second argument is the function to apply to each group. If only these two arguments are supplied, it is assumed that the grouped object is the first and only input to the applied function. This function returns a dictionary, where keys correspond to each unique group, and values correspond to the function output on each group. + ```{code-cell} ipython3 print(tsdframe,"\n") print(tsdframe.groupby_apply("label", np.mean)) ``` If the applied function requires additional inputs, these can be passed as additional keyword arguments into `groupby_apply`. + ```{code-cell} ipython3 feature = nap.Tsd(t=np.arange(100), d=np.repeat([0,1], 50)) tsgroup.groupby_apply( "region", - nap.compute_1d_tuning_curves, - feature=feature, - nb_bins=2) + nap.compute_tuning_curves, + features=feature, + bins=2) ``` Alternatively, an anonymous function can be passed instead that defines additional arguments. + ```{code-cell} ipython3 -func = lambda x: nap.compute_1d_tuning_curves(x, feature=feature, nb_bins=2) +func = lambda x: nap.compute_tuning_curves(x, features=feature, bins=2) tsgroup.groupby_apply("region", func) ``` An anonymous function can also be used to apply a function where the grouped object is not the first input. + ```{code-cell} ipython3 -func = lambda x: nap.compute_1d_tuning_curves( - group=tsgroup, - feature=feature, - nb_bins=2, - ep=x) +func = lambda x: nap.compute_tuning_curves( + data=tsgroup, + features=feature, + bins=2, + epochs=x) intervalset.groupby_apply("choice", func) ``` Alternatively, the optional parameter `input_key` can be passed to specify which keyword argument the grouped object corresponds to. Other required arguments of the applied function need to be passed as keyword arguments. 
+ ```{code-cell} ipython3 intervalset.groupby_apply( "choice", - nap.compute_1d_tuning_curves, - input_key="ep", - group=tsgroup, - feature=feature, - nb_bins=2) + nap.compute_tuning_curves, + input_key="epochs", + data=tsgroup, + features=feature, + bins=2) ``` diff --git a/doc/user_guide/06_tuning_curves.md b/doc/user_guide/06_tuning_curves.md index 3b86e235c..a52ff4a74 100644 --- a/doc/user_guide/06_tuning_curves.md +++ b/doc/user_guide/06_tuning_curves.md @@ -13,16 +13,10 @@ kernelspec: # Tuning curves -Pynapple can compute 1-dimensional tuning curves -(for example firing rate as a function of angular direction) -and 2-dimensional tuning curves (for example firing rate as a function -of position). It can also compute average firing rate for different -epochs (for example firing rate for different epochs of stimulus presentation). - -:::{important} -If you are using calcium imaging data with the activity of the cell as a continuous transient, the function to call ends with `_continuous` for continuous time series (e.g. [`compute_1d_tuning_curves_continuous`](pynapple.process.tuning_curves.compute_1d_tuning_curves_continuous)). -::: - +With Pynapple you can easily compute n-dimensional tuning curves +(for example, firing rate as a function of 1D angular direction or firing rate as a function of 2D position). +It is also possible to compute average firing rate for different epochs +(for example, firing rate for different epochs of stimulus presentation). ```{code-cell} ipython3 :tags: [hide-cell] @@ -30,53 +24,17 @@ import pynapple as nap import numpy as np import matplotlib.pyplot as plt import seaborn as sns +import xarray as xr from pprint import pprint custom_params = {"axes.spines.right": False, "axes.spines.top": False} sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params) +xr.set_options(display_expand_attrs=False) ``` -```{code-cell} -:tags: [hide-cell] - -group = { - 0: nap.Ts(t=np.sort(np.random.uniform(0, 100, 10))), - 1: nap.Ts(t=np.sort(np.random.uniform(0, 100, 20))), - 2: nap.Ts(t=np.sort(np.random.uniform(0, 100, 30))), -} - -tsgroup = nap.TsGroup(group) -``` - -## From epochs - - -The epochs should be stored in a dictionnary: -```{code-cell} ipython3 -dict_ep = { - "stim0": nap.IntervalSet(start=0, end=20), - "stim1":nap.IntervalSet(start=30, end=70) - } -``` - -[`nap.compute_discrete_tuning_curves`](pynapple.process.tuning_curves.compute_discrete_tuning_curves) takes a `TsGroup` for spiking activity and a dictionary of epochs. -The output is a pandas DataFrame where each column is a unit in the `TsGroup` and each row is one `IntervalSet` type. -The value is the mean firing rate of the neuron during this set of intervals. - -```{code-cell} ipython3 -mean_fr = nap.compute_discrete_tuning_curves(tsgroup, dict_ep) - -pprint(mean_fr) -``` - - -## From timestamps activity +## From timestamps or continuous activity -### 1-dimension tuning curves - - ```{code-cell} ipython3 :tags: [hide-cell] - from scipy.ndimage import gaussian_filter1d # Fake Tuning curves @@ -103,37 +61,80 @@ tsgroup = nap.TsGroup( time_support = nap.IntervalSet(0, 100) ) ``` +Computing tuning curves is done using [`compute_tuning_curves`](pynapple.process.tuning_curves.compute_tuning_curves). + +When computing from general time-series, mandatory arguments are: +* `data`: a `TsGroup` (or single `Ts`) or `TsdFrame` (or single `Tsd`) containing the neural activity of one or more units. +* `features`: a `Tsd` or `TsdFrame` containing one or more features. 
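+
+In its simplest form, the call only needs these two arguments; everything else falls back to the defaults described below. A minimal sketch, reusing the `tsgroup` and `feature` objects created in the hidden cell above:
+
+```{code-cell} ipython3
+nap.compute_tuning_curves(data=tsgroup, features=feature)
+```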
+ +By default, 10 bins are used for all features, but you can specify the number of bins, +or the bin edges explicitly, using the `bins` argument. + +The min and max of the tuning curves are by default the minima and maxima of the features. +This can be tweaked with the `range` argument. + +If an `IntervalSet` is passed with `epochs`, everything is restricted to `epochs`, +otherwise the time support of the features is used. + +If you do not want the sampling rate of the features to be estimated from the timestamps, +you can pass it explicitly using the `fs` argument. -Mandatory arguments are `TsGroup`, `Tsd` (or `TsdFrame` with 1 column only) -and `nb_bins` for number of bins of the tuning curves. +You can further also pass a list of strings to label each dimension via `feature_names` +(by default the columns of the features are used). -If an `IntervalSet` is passed with `ep`, everything is restricted to `ep` -otherwise the time support of the feature is used. +The output is an `xarray.DataArray` in which the first dimension represents the units and further dimensions represent the features. +The occupancy and bin edges are stored as attributes. -The min and max of the tuning curve is by default the min and max of the feature. This can be tweaked with the argument `minmax`. +If you explicitly want a `pd.DataFrame` as output (which is only possible when you have just the one feature), +you can set `return_pandas=True`. Note that this will not return the occupancy and bin edges. -The output is a pandas DataFrame. Each column is a unit from `TsGroup` argument. The index of the DataFrame carries the center of the bin in feature space. +### 1D tuning curves from spikes ```{code-cell} ipython3 -tuning_curve = nap.compute_1d_tuning_curves( - group=tsgroup, - feature=feature, - nb_bins=120, - minmax=(0, 2*np.pi) +tuning_curves_1d = nap.compute_tuning_curves( + data=tsgroup, + features=feature, + bins=120, + range=(0, 2*np.pi), + feature_names=["feature"] ) +tuning_curves_1d +``` + +The `xarray.DataArray` can be treated like a `numpy` array. -print(tuning_curve) +It has a shape: +```{code-cell} ipython3 +tuning_curves_1d.shape +``` +It can be sliced: +```{code-cell} ipython3 +tuning_curves_1d[1, 2:8] +``` +It can also be indexed using the coordinates: +```{code-cell} ipython3 +tuning_curves_1d.sel(unit=1) ``` +`xarray` further has `matplotlib` support, allowing for easy visualization: + ```{code-cell} ipython3 -plt.figure() -plt.plot(tuning_curve) -plt.xlabel("Feature space") +tuning_curves_1d.plot.line(x="feature", add_legend=False) plt.ylabel("Firing rate (Hz)") plt.show() ``` -Internally, the function is calling the method [`value_from`](pynapple.Tsd.value_from) which maps timestamps to their closest values in time from a `Tsd` object. It is then possible to validate the tuning curves by displaying the timestamps as well as their associated values. +You can either customize the plot labels yourself using `matplotlib`, or you can set them in the tuning curve object: +```{code-cell} ipython3 +tuning_curves_1d.name = "Firing rate" +tuning_curves_1d.attrs["unit"] = "Hz" +tuning_curves_1d.coords["feature"].attrs["unit"] = "rad" +tuning_curves_1d.plot.line(x="feature", add_legend=False) +plt.show() +``` + +Internally, the `compute_tuning_curves` calls the [`value_from`](pynapple.Tsd.value_from) method which maps timestamps to their closest values in time from a `Tsd` object. +It is then possible to validate the tuning curves by displaying the timestamps as well as their associated values. 
```{code-cell} ipython3 :tags: [hide-input] @@ -145,13 +146,47 @@ plt.ylabel("Feature") plt.xlim(0, 2) plt.xlabel("Time (s)") plt.subplot(122) -plt.plot(tuning_curve[3].values, tuning_curve[3].index.values, label="Tuning curve (unit=3)") +plt.plot(tuning_curves_1d[3].values, tuning_curves_1d.coords["feature"], label="Tuning curve (unit=3)") plt.xlabel("Firing rate (Hz)") plt.legend() plt.show() ``` -### 2-dimensional tuning curves +It is also possible to just get the spike counts per bins. This can be done by setting the argument `return_counts=True`. +The output is also a `xarray.DataArray` with the same dimensions as the tuning curves. + +```{code-cell} ipython3 +spike_counts = nap.compute_tuning_curves( + data=tsgroup, + features=feature, + bins=30, + range=(0, 2*np.pi), + feature_names=["feature"], + return_counts=True + ) +``` + +```{code-cell} ipython3 +:tags: [hide-input] +plt.figure() +plt.subplot(131) +plt.plot(tsgroup[3].value_from(feature), 'o') +plt.plot(feature, label="feature") +plt.ylabel("Feature") +plt.xlim(0, 2) +plt.xlabel("Time (s)") +plt.subplot(132) +plt.plot(tuning_curves_1d[3].values, tuning_curves_1d.coords["feature"]) +plt.xlabel("Firing rate (Hz)") +plt.subplot(133) +plt.barh(spike_counts.coords["feature"], width=spike_counts[3].values, height=np.mean(np.diff(spike_counts.coords["feature"]))) +plt.xlabel("Spike count") +plt.tight_layout() +plt.show() +``` + + +### 2D tuning curves from spikes ```{code-cell} ipython3 :tags: [hide-cell] @@ -173,73 +208,66 @@ tsgroup = nap.TsGroup({ }, time_support=epoch) ``` -The `group` argument must be a `TsGroup` object. -The `features` argument must be a 2-columns `TsdFrame` object. -`nb_bins` can be an int or a tuple of 2 ints. - +If you pass more than 1 feature, a multi-dimensional tuning curve is computed: ```{code-cell} ipython3 -tcurves2d, binsxy = nap.compute_2d_tuning_curves( - group=tsgroup, +tuning_curves_2d = nap.compute_tuning_curves( + data=tsgroup, features=features, - nb_bins=(5,5), - minmax=(-1, 1, -1, 1) + bins=(5,5), + range=[(-1, 1), (-1, 1)], + feature_names=["a", "b"] ) -pprint(tcurves2d) +tuning_curves_2d ``` -`tcurves2d` is a dictionnary with each key a unit in `TsGroup`. `binsxy` is a numpy array representing the centers of the bins and is useful for plotting tuning curves. Bins that have never been visited by the feature have been assigned a NaN value. +`tuning_curve_2d` is a again an `xarray.DataArray` but now with three dimensions: +one for the units of `TsGroup` and 2 for the features, the coordinates contain the centers of the bins. +Bins that have never been visited by the feature have been assigned a NaN value. + +Two-dimensional tuning curves can also easily be visualized: + +```{code-cell} ipython3 +tuning_curves_2d.name="Firing rate" +tuning_curves_2d.attrs["unit"]="Hz" +tuning_curves_2d.plot(col="unit") +plt.show() +``` -Checking the accuracy of the tuning curves can be bone by displaying the spikes aligned to the features with the function `value_from` which assign to each spikes the corresponding features value for unit 0. +Verifying the accuracy of the tuning curves can once more be done by displaying the spikes aligned +to the features with the function `value_from` which assign to each spikes the corresponding features value for unit 0. ```{code-cell} ipython3 ts_to_features = tsgroup[0].value_from(features) print(ts_to_features) ``` + `tsgroup[0]` which is a `Ts` object has been transformed to a `TsdFrame` object with each timestamps (spike times) being associated with a features value. 
```{code-cell} ipython3 :tags: [hide-input] - - -plt.figure() -plt.subplot(121) -plt.plot(features["b"], features["a"], label="features") -plt.plot(ts_to_features["b"], ts_to_features["a"], "o", color="red", markersize=4, label="spikes") -plt.xlabel("feature b") -plt.ylabel("feature a") -[plt.axvline(b, linewidth=0.5, color='grey') for b in np.linspace(-1, 1, 6)] -[plt.axhline(b, linewidth=0.5, color='grey') for b in np.linspace(-1, 1, 6)] -plt.subplot(122) +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8,4), sharey=True) +ax1.plot(features["b"], features["a"], label="features") +ax1.plot(ts_to_features["b"], ts_to_features["a"], "o", color="red", markersize=4, label="spikes") +ax1.set_xlabel("b") +ax1.set_ylabel("a") +[ax1.axvline(b, linewidth=0.5, color='grey') for b in np.linspace(-1, 1, 6)] +[ax1.axhline(b, linewidth=0.5, color='grey') for b in np.linspace(-1, 1, 6)] extents = ( np.min(features["a"]), np.max(features["a"]), np.min(features["b"]), np.max(features["b"]), ) -plt.imshow(tcurves2d[0], - origin="lower", extent=extents, cmap="viridis", - aspect='auto' - ) -plt.title("Tuning curve unit 0") -plt.xlabel("feature b") -plt.ylabel("feature a") -plt.grid(False) -plt.colorbar() +tuning_curves_2d[0].plot(ax=ax2) +ax2.set_ylabel("") plt.tight_layout() plt.show() ``` - -## From continuous activity - -Tuning curves compute with the following functions are usually made with -data from calcium imaging activities. - -### 1-dimensional tuning curves +### 1D tuning curves from continuous activity ```{code-cell} ipython3 :tags: [hide-cell] - from scipy.ndimage import gaussian_filter1d # Fake Tuning curves @@ -267,34 +295,33 @@ tsdframe = nap.TsdFrame( ) ``` -Arguments are `TsdFrame` (for example continuous calcium data), `Tsd` or `TsdFrame` for the 1-d feature and `nb_bins` for the number of bins. +We do not always have spikes. Sometimes we are analysing continuous firing rates or calcium intensities. +In that case, we can simply pass a `Tsd` or `TsdFrame` as group: ```{code-cell} ipython3 - -tuning_curves = nap.compute_1d_tuning_curves_continuous( - tsdframe=tsdframe, - feature=feature, - nb_bins=120, - minmax=(0, 2*np.pi) +tuning_curves_1d = nap.compute_tuning_curves( + data=tsdframe, + features=feature, + bins=120, + range=(0, 2*np.pi), + feature_names=["feature"] ) - -print(tuning_curves) +tuning_curves_1d ``` ```{code-cell} ipython3 -plt.figure() -plt.plot(tuning_curves) -plt.xlabel("Feature space") -plt.ylabel("Mean activity") +tuning_curves_1d.name="ΔF/F" +tuning_curves_1d.attrs["unit"]="a.u." 
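+# the DataArray name ("ΔF/F") set above is what xarray uses to label the y-axis in the plot below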
+tuning_curves_1d.plot.line(x="feature", add_legend=False) plt.show() ``` -### 2-dimensional tuning curves +### 2D tuning curves from continuous activity +This also works with more than one feature: ```{code-cell} ipython3 :tags: [hide-cell] - dt = 0.01 T = 10 epoch = nap.IntervalSet(start=0, end=T, time_units="s") @@ -309,52 +336,67 @@ features = nap.TsdFrame( # Calcium activity -tsdframe = nap.TsdFrame( - t=timestep, - d=np.random.randn(len(timestep), 2) +ft = features.values +alpha = np.arctan2(ft[:, 1], ft[:, 0]) +bin_centers = np.linspace(-np.pi, np.pi, 6) +kappa = 4.0 +units=[] +for i, mu in enumerate(bin_centers): + units.append(np.exp(kappa * np.cos(alpha - mu))) # wrapped Gaussian +units = np.stack(units, axis=1) +tsdframe = nap.TsdFrame(t=features.times(), d=units) +``` + +```{code-cell} ipython3 +tuning_curves_2d = nap.compute_tuning_curves( + data=tsdframe, + features=features, + bins=5, + feature_names=["a", "b"] ) +tuning_curves_2d ``` -Arguments are `TsdFrame` (for example continuous calcium data), `Tsd` or `TsdFrame` for the 1-d feature and `nb_bins` for the number of bins. +```{code-cell} ipython3 +tuning_curves_2d.name="ΔF/F" +tuning_curves_2d.attrs["unit"]="a.u." +tuning_curves_2d.plot(col="unit", col_wrap=3) +plt.show() +``` + +## From epochs + +When computing from epochs, you should store them in a dictionary: ```{code-cell} ipython3 +epochs_dict = { + "stim0": nap.IntervalSet(start=0, end=20), + "stim1":nap.IntervalSet(start=30, end=70) +} +``` +You can then compute the tuning curves using [`nap.compute_response_per_epoch`](pynapple.process.tuning_curves.compute_response_per_epoch). +You can pass either a `TsGroup` for spikes, or a `TsdFrame` for rates/calcium activity. -tuning_curves, xy = nap.compute_2d_tuning_curves_continuous( - tsdframe=tsdframe, - features=features, - nb_bins=5, - ) +The output is an `xarray.DataArray` with labeled dimensions: -print(tuning_curves) +```{code-cell} ipython3 +tuning_curves = nap.compute_response_per_epoch(tsgroup, epochs_dict) +tuning_curves ``` +# Mutual information +Given a set of tuning curves, you can use [`compute_mutual_information`](pynapple.process.tuning_curves.compute_mutual_information) to compute the mutual information between the activity of the neurons and the features, no matter what dimension. +See the [Skaggs et al. (1992)](https://proceedings.neurips.cc/paper/1992/hash/5dd9db5e033da9c6fb5ba83c7a7ebea9-Abstract.html) paper for more information on what mutual information computes. + ```{code-cell} ipython3 -plt.figure() -plt.subplot(121) -plt.plot(features["b"], features["a"], label="features") -plt.xlabel("feature b") -plt.ylabel("feature a") -[plt.axvline(b, linewidth=0.5, color='grey') for b in np.linspace(-1, 1, 6)] -[plt.axhline(b, linewidth=0.5, color='grey') for b in np.linspace(-1, 1, 6)] -plt.subplot(122) -extents = ( - np.min(features["a"]), - np.max(features["a"]), - np.min(features["b"]), - np.max(features["b"]), -) -plt.imshow(tuning_curves[0], - origin="lower", extent=extents, cmap="viridis", - aspect='auto' - ) -plt.title("Tuning curve unit 0") -plt.xlabel("feature b") -plt.ylabel("feature a") -plt.grid(False) -plt.colorbar() -plt.tight_layout() -plt.show() +MI = nap.compute_mutual_information(tuning_curves_1d) +MI ``` +```{code-cell} ipython3 +MI = nap.compute_mutual_information(tuning_curves_2d) +MI +``` +Take a look at the tutorial on [head direction cells](../examples/tutorial_HD_dataset.md) for a realistic example. 
diff --git a/doc/user_guide/07_decoding.md b/doc/user_guide/07_decoding.md index 310ba660f..b69cdd145 100644 --- a/doc/user_guide/07_decoding.md +++ b/doc/user_guide/07_decoding.md @@ -23,18 +23,32 @@ import seaborn as sns custom_params = {"axes.spines.right": False, "axes.spines.top": False} sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params) ``` -Pynapple supports 1 dimensional and 2 dimensional bayesian decoding. The function returns the decoded feature as well as the probabilities for each timestamps. +Pynapple supports n-dimensional decoding from any neural modality. +For spike data, you can use [`decode_bayes`](pynapple.process.decoding.decode_bayes), which implements Bayesian decoding using a Poisson distribution. +For any other type of data (and also for spike data), you can use [`decode_template`](pynapple.process.decoding.decode_template), which implements a template matching algorithm. -:::{hint} -Input to the bayesian decoding functions always include the tuning curves computed from [`nap.compute_1d_tuning_curves`](pynapple.process.tuning_curves.compute_1d_tuning_curves) or [`nap.compute_2d_tuning_curves`](pynapple.process.tuning_curves.compute_2d_tuning_curves). +Input to both decoding functions always includes: + - `tuning_curves`, computed using [`compute_tuning_curves`](pynapple.process.tuning_curves.compute_tuning_curves). + - `data`, neural activity as a `TsGroup` (spikes) or `TsdFrame` (smoothed counts or calcium activity or any other time series). + - `epochs`, to restrict decoding to certain intervals. + - `sliding_window_size`, uniform convolution window size (in number of bins) to smooth spike counts, only used if a `TsGroup` is passed (default is `None`, for no smoothing). This is equivalent to using a sliding window with overlapping bins. + - `bin_size`, the size of the bins in which to count timestamps when data is a `TsGroup` object. + - `time_units`, the units of `bin_size`, defaulting to seconds. + +## Bayesian decoding +When using Bayesian decoding, users can additionally set `uniform_prior=False` to use the occupancy as a prior over the feature distribution. +By default `uniform_prior=True`, and a uniform prior is used. + +:::{important} +Bayesian decoding should only be used with spike (`TsGroup`) or spike count (`TsdFrame`) data, as these can be assumed to follow a Poisson distribution! ::: -## 1-dimensional decoding + +### 1-dimensional Bayesian decoding ```{code-cell} ipython3 :tags: [hide-cell] - from scipy.ndimage import gaussian_filter1d # Fake Tuning curves @@ -60,162 +74,381 @@ index = np.digitize(feature, bins)-1 count = np.random.poisson(tc[index])>0 tsgroup = nap.TsGroup({i:nap.Ts(timestep[count[:,i]]) for i in range(N)}) -epoch = nap.IntervalSet(0, 10) +epochs = nap.IntervalSet(0, 10) ``` -To decode, we need to compute tuning curves in 1D. 
+First, we compute the tuning curves: ```{code-cell} ipython3 -tcurves_1d = nap.compute_1d_tuning_curves( - tsgroup, feature, nb_bins=61, minmax=(0, 2 * np.pi) +tuning_curves_1d = nap.compute_tuning_curves( + tsgroup, + feature, + bins=60, + range=(0, 2 * np.pi), + feature_names=["Circular feature"] ) ``` -We can display the tuning curves of each neurons - -```{code-cell} +```{code-cell} ipython3 :tags: [hide-input] - -plt.figure() -plt.plot(tcurves_1d) -plt.xlabel("Feature position") -plt.ylabel("Rate (Hz)") +tuning_curves_1d.name = "Firing rate" +tuning_curves_1d.attrs["unit"] = "Hz" +tuning_curves_1d.plot.line(x="Circular feature", add_legend=False) plt.show() ``` -`nap.decode_1d` performs bayesian decoding: - +We can then use [`decode_bayes`](pynapple.process.decoding.decode_bayes) for Bayesian decoding. +We will use the `sliding_window_size` argument to additionally smooth the +spike counts with a uniform convolution window (i.e. use a sliding window), which often helps with decoding. ```{code-cell} ipython3 -decoded, proba_feature = nap.decode_1d( - tuning_curves=tcurves_1d , # 2D tuning curves - group=tsgroup, # Spiking activity - ep=epoch, # Small epoch - bin_size=0.06, # How to bin the spike trains - feature=feature, # Features +decoded, proba_feature = nap.decode_bayes( + tuning_curves=tuning_curves_1d, + data=tsgroup, + epochs=epochs, + sliding_window_size=4, + bin_size=0.02, ) ``` -`decoded` is `Tsd` object containing the decoded feature value. `proba_feature` is a `TsdFrame` containing the probabilities of being in a particular feature bin over time. +`decoded` is a `Tsd` object containing the decoded feature value. +`proba_feature` is a `TsdFrame` containing the probabilities of being in a particular feature bin over time. ```{code-cell} ipython3 :tags: [hide-input] -plt.figure(figsize=(12, 6)) -plt.subplot(211) -plt.plot(feature.restrict(epoch), label="True") -plt.plot(decoded, label="Decoded") -plt.legend() -plt.xlim(epoch[0,0], epoch[0,1]) -plt.subplot(212) -plt.imshow(proba_feature.values.T, aspect="auto", origin="lower", cmap="viridis") -plt.xticks([0, len(decoded)], epoch.values[0]) -plt.xlabel("Time (s)") +fig, (ax1, ax2) = plt.subplots(figsize=(8, 5), nrows=2, ncols=1, sharex=True) +feature=feature.restrict(epochs) +ax1.plot( + feature.times(), + feature.values, + label="True", +) +ax1.scatter( + decoded.times(), + decoded.values, + label="Decoded", + c="orange", +) +ax1.legend( + frameon=False, + bbox_to_anchor=(1.0, 1.0), +) +ax1.set_ylabel("Circular\nfeature") +ax1.set_yticks([0, 2*np.pi], ["0", "2π"]) +im = ax2.imshow(proba_feature.values.T, aspect="auto", origin="lower", cmap="viridis", extent=(0, 10.0, 0, 2*np.pi)) +cbar_ax = fig.add_axes([0.93, 0.1, 0.015, 0.36]) +fig.colorbar(im, cax=cbar_ax, label="Probability") +ax2.set_xlabel("Time (s)", labelpad=-20) +ax2.set_ylabel("Circular\nfeature") +ax2.set_yticks([0, 2*np.pi], ["0", "2π"]) plt.show() - ``` - -## 2-dimensional decoding +### N-dimensional Bayesian decoding ```{code-cell} ipython3 :tags: [hide-cell] - dt = 0.1 -epoch = nap.IntervalSet(start=0, end=1000, time_units="s") +epochs = nap.IntervalSet(start=0, end=1000, time_units="s") features = np.vstack((np.cos(np.arange(0, 1000, dt)), np.sin(np.arange(0, 1000, dt)))).T -features = nap.TsdFrame(t=np.arange(0, 1000, dt), +features = nap.TsdFrame( + t=np.arange(0, 1000, dt), d=features, time_units="s", - time_support=epoch, + time_support=epochs, columns=["a", "b"], ) times = features.as_units("us").index.values ft = features.values alpha = np.arctan2(ft[:, 
1], ft[:, 0]) -bins = np.repeat(np.linspace(-np.pi, np.pi, 13)[::, np.newaxis], 2, 1) -bins += np.array([-2 * np.pi / 24, 2 * np.pi / 24]) +bin_centers = np.linspace(-np.pi, np.pi, 12) +kappa = 4.0 ts_group = {} -for i in range(12): - ts = times[(alpha >= bins[i, 0]) & (alpha <= bins[i + 1, 1])] +for i, mu in enumerate(bin_centers): + weights = np.exp(kappa * np.cos(alpha - mu)) # wrapped Gaussian + weights /= np.max(weights) # normalize to 0–1 + mask = weights > 0.5 + ts = times[mask] ts_group[i] = nap.Ts(ts, time_units="us") +ts_group = nap.TsGroup(ts_group) +``` -ts_group = nap.TsGroup(ts_group, time_support=epoch) +Decoding also works with multiple dimensions (here we show a 2D example). +First, we compute the tuning curves: + +```{code-cell} ipython3 +tuning_curves_2d = nap.compute_tuning_curves( + data=ts_group, + features=features, # containing 2 features + bins=10, + epochs=epochs, + range=[(-1.0, 1.0), (-1.0, 1.0)], # range can be specified for each feature +) ``` -To decode, we need to compute tuning curves in 2D. +```{code-cell} ipython3 +:tags: [hide-input] +tuning_curves_2d.name = "Firing rate" +tuning_curves_2d.attrs["unit"] = "Hz" +tuning_curves_2d.plot(row="unit", col_wrap=6) +plt.show() +``` +and then, [`decode_bayes`](pynapple.process.decoding.decode_bayes) again performs bayesian decoding: ```{code-cell} ipython3 -tcurves2d, binsxy = nap.compute_2d_tuning_curves( - group=ts_group, # Spiking activity of 12 neurons - features=features, # 2-dimensional features - nb_bins=10, - ep=epoch, - minmax=(-1.0, 1.0, -1.0, 1.0), # Minmax of the features +decoded, proba_feature = nap.decode_bayes( + tuning_curves=tuning_curves_2d, + data=ts_group, + epochs=epochs, + sliding_window_size=2, + bin_size=0.05, ) ``` -We can display the tuning curves of each neuron - -```{code-cell} +```{code-cell} ipython3 :tags: [hide-input] +fig, (ax1, ax2, ax3) = plt.subplots(figsize=(8, 3), nrows=1, ncols=3, sharey=True) +ax1.plot(features["a"].get(0, 20), label="True") +ax1.scatter( + decoded["a"].get(0, 20).times(), + decoded["a"].get(0, 20), + label="Decoded", + c="orange", +) +ax1.set_title("Feature a") +ax1.set_xlabel("Time (s)") + +ax2.plot(features["b"].get(0, 20), label="True") +ax2.scatter( + decoded["b"].get(0, 20).times(), + decoded["b"].get(0, 20), + label="Decoded", + c="orange", +) +ax2.set_xlabel("Time (s)") +ax2.set_title("Feature b") + +ax3.plot( + features["a"].get(0, 20), + features["b"].get(0, 20), + label="True", +) +ax3.scatter( + decoded["a"].get(0, 20), + decoded["b"].get(0, 20), + label="Decoded", + c="orange", +) +ax3.set_title("Combined") +plt.show() +``` -plt.figure(figsize=(20, 9)) -for i in ts_group.keys(): - plt.subplot(2, 6, i + 1) - plt.imshow( - tcurves2d[i], extent=(binsxy[1][0], binsxy[1][-1], binsxy[0][0], binsxy[0][-1]) +## Template matching +If you do not have spike data, or if you do not want to use the Poisson assumption, Pynapple also supports decoding using template matching, which makes no assumption on the modality of your data. +Instead of computing a probability distribution, `compute_template` computes a distance matrix between the samples and the tuning curves (smaller is better). +In addition to the default arguments, users can set `metric` to choose the used distance metric. By default `metric="correlation"`. 
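+
+To make this concrete, the cell below shows what such a distance computation
+looks like on made-up numbers, using `scipy.spatial.distance.cdist` with the
+`correlation` metric (the arrays here are purely illustrative):
+
+```{code-cell} ipython3
+from scipy.spatial.distance import cdist
+
+# Each row of `templates` is the population rate vector expected in one feature bin,
+# each row of `activity` is an observed population vector in one time bin.
+templates = np.array([[5.0, 1.0, 0.5],
+                      [1.0, 5.0, 1.0],
+                      [0.5, 1.0, 5.0]])   # (feature bins x units)
+activity = np.array([[4.0, 2.0, 0.0],
+                     [0.0, 2.0, 6.0]])    # (time bins x units)
+
+dist = cdist(activity, templates, metric="correlation")  # smaller = better match
+decoded_bin = dist.argmin(axis=1)  # best-matching feature bin for each time bin
+decoded_bin
+```
+
+The decoded feature is simply the bin whose template has the smallest distance
+to the observed population vector.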
+ +### 1-dimensional template matching + +```{code-cell} ipython3 +:tags: [hide-cell] +from scipy.ndimage import gaussian_filter1d + +# Fake Tuning curves +N = 6 # Number of neurons +bins = np.linspace(0, 2*np.pi, 61) +x = np.linspace(-np.pi, np.pi, len(bins)-1) +tmp = np.roll(np.exp(-(1.5*x)**2), (len(bins)-1)//2) +tc = np.array([np.roll(tmp, i*(len(bins)-1)//N) for i in range(N)]).T + +tc_1d = pd.DataFrame(index=bins[0:-1], data=tc) + +# Feature +T = 10000 +dt = 0.01 +timestep = np.arange(0, T)*dt +feature = nap.Tsd( + t=timestep, + d=gaussian_filter1d(np.cumsum(np.random.randn(T)*0.5), 20)%(2*np.pi) ) - plt.xticks() +index = np.digitize(feature, bins)-1 + +# Spiking activity + +count = np.random.poisson(tc[index])>0 +tsgroup = nap.TsGroup({i:nap.Ts(timestep[count[:,i]]) for i in range(N)}) +epochs = nap.IntervalSet(0, 10) +``` + +First, we compute the tuning curves (here we'll use spikes as neural data): + +```{code-cell} ipython3 +tuning_curves_1d = nap.compute_tuning_curves( + tsgroup, + feature, + bins=61, + range=(0, 2 * np.pi), + feature_names=["Circular feature"] +) +``` + +```{code-cell} ipython3 +:tags: [hide-input] +tuning_curves_1d.name = "Firing rate" +tuning_curves_1d.attrs["unit"] = "Hz" +tuning_curves_1d.plot.line(x="Circular feature", add_legend=False) plt.show() ``` -`nap.decode_2d` performs bayesian decoding: +We can then use [`decode_template`](pynapple.process.decoding.decode_template) for template matching: + +```{code-cell} ipython3 +decoded, dist = nap.decode_template( + tuning_curves=tuning_curves_1d, + data=tsgroup, + epochs=epochs, + sliding_window_size=4, + bin_size=0.05, + metric="correlation" +) +``` + +`decoded` is a `Tsd` object containing the decoded feature value. +`dist` is a `TsdFrame` containing the distance matrix of every time bin with respect to the tuning curves. + +```{code-cell} ipython3 +:tags: [hide-input] +fig, (ax1, ax2) = plt.subplots(figsize=(8, 5), nrows=2, ncols=1, sharex=True) +feature=feature.restrict(epochs) +ax1.plot( + feature.times(), + feature.values, + label="True", +) +ax1.scatter( + decoded.times(), + decoded.values, + label="Decoded", + c="orange", +) +ax1.legend( + frameon=False, + bbox_to_anchor=(1.0, 1.0), +) +ax1.set_ylabel("Circular\nfeature") +ax1.set_yticks([0, 2*np.pi], ["0", "2π"]) +im = ax2.imshow(dist.values.T, aspect="auto", origin="lower", cmap="inferno_r", extent=(0, 10.0, 0, 2*np.pi)) +cbar_ax = fig.add_axes([0.93, 0.1, 0.015, 0.36]) +fig.colorbar(im, cax=cbar_ax, label="Distance") +ax2.set_xlabel("Time (s)", labelpad=-20) +ax2.set_ylabel("Circular\nfeature") +ax2.set_yticks([0, 2*np.pi], ["0", "2π"]) +plt.show() +``` + +### N-dimensional template matching + +```{code-cell} ipython3 +:tags: [hide-cell] +dt = 0.01 +T = 10 +epoch = nap.IntervalSet(start=0, end=T, time_units="s") +features = np.vstack((np.cos(np.arange(0, T, dt)), np.sin(np.arange(0, T, dt)))).T +features = nap.TsdFrame( + t=np.arange(0, T, dt), + d=features, + time_units="s", + time_support=epoch, + columns=["a", "b"], +) + + +# Calcium activity +ft = features.values +alpha = np.arctan2(ft[:, 1], ft[:, 0]) +bin_centers = np.linspace(-np.pi, np.pi, 12) +kappa = 4.0 +units=[] +for i, mu in enumerate(bin_centers): + units.append(np.exp(kappa * np.cos(alpha - mu))) # wrapped Gaussian +units = np.stack(units, axis=1) +tsdframe = nap.TsdFrame(t=features.times(), d=units) +``` + +Template matching also works with multiple dimensions. 
+First, we compute the tuning curves (now let's simulate calcium imaging in a `TsdFrame`): ```{code-cell} ipython3 -decoded, proba_feature = nap.decode_2d( - tuning_curves=tcurves2d, # 2D tuning curves - group=ts_group, # Spiking activity - ep=epoch, # Epoch - bin_size=0.1, # How to bin the spike trains - xy=binsxy, # Features position - features=features, # Features +tuning_curves_2d = nap.compute_tuning_curves( + data=tsdframe, + features=features, # containing 2 features + bins=10, + epochs=epochs, + range=[(-1.0, 1.0), (-1.0, 1.0)], # range can be specified for each feature ) ``` ```{code-cell} ipython3 :tags: [hide-input] +tuning_curves_2d.name = "ΔF/F" +tuning_curves_2d.attrs["unit"] = "a.u." +tuning_curves_2d.plot(row="unit", col_wrap=6) +plt.show() +``` -plt.figure(figsize=(15, 5)) -plt.subplot(131) -plt.plot(features["a"].get(0,20), label="True") -plt.plot(decoded["a"].get(0,20), label="Decoded") -plt.legend() -plt.xlabel("Time (s)") -plt.ylabel("Feature a") -plt.subplot(132) -plt.plot(features["b"].get(0,20), label="True") -plt.plot(decoded["b"].get(0,20), label="Decoded") -plt.legend() -plt.xlabel("Time (s)") -plt.title("Feature b") -plt.subplot(133) -plt.plot( - features["a"].get(0,20), - features["b"].get(0,20), +and then, [`decode_template`](pynapple.process.decoding.decode_template) again performs template matching: + +```{code-cell} ipython3 +decoded, dist = nap.decode_template( + tuning_curves=tuning_curves_2d, + data=tsdframe, + epochs=epochs, + bin_size=0.01, + metric="correlation" +) +``` + +```{code-cell} ipython3 +:tags: [hide-input] +fig, (ax1, ax2, ax3) = plt.subplots(figsize=(8, 3), nrows=1, ncols=3, sharey=True) +ax1.plot(features["a"].get(0, 20), label="True") +ax1.scatter( + decoded["a"].get(0, 20).times(), + decoded["a"].get(0, 20), + label="Decoded", + c="orange", +) +ax1.set_title("Feature a") +ax1.set_xlabel("Time (s)") + +ax2.plot(features["b"].get(0, 20), label="True") +ax2.scatter( + decoded["b"].get(0, 20).times(), + decoded["b"].get(0, 20), + label="Decoded", + c="orange", +) +ax2.set_xlabel("Time (s)") +ax2.set_title("Feature b") + +ax3.plot( + features["a"].get(0, 20), + features["b"].get(0, 20), label="True", ) -plt.plot( - decoded["a"].get(0,20), - decoded["b"].get(0,20), +ax3.scatter( + decoded["a"].get(0, 20), + decoded["b"].get(0, 20), label="Decoded", + c="orange", ) -plt.xlabel("Feature a") -plt.title("Feature b") -plt.legend() -plt.tight_layout() +ax3.set_title("Combined") plt.show() -``` \ No newline at end of file +``` + +Take a look at the [tutorial on calcium imaging](../examples/tutorial_calcium_imaging.md) +for an application of template matching with real data and a comparison of various distance metrics! 
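+
+As a quick sanity check on the decoding above (a rough sketch, reusing the
+variables defined in this section), we can compare the decoded angle on the
+ring with the true one:
+
+```{code-cell} ipython3
+true_angle = np.arctan2(features["b"].values, features["a"].values)
+decoded_angle = np.arctan2(decoded["b"].values, decoded["a"].values)
+# wrap the difference to [-pi, pi] before summarising
+error = np.angle(np.exp(1j * (decoded_angle - true_angle)))
+print(f"Median absolute angular error: {np.nanmedian(np.abs(error)):.3f} rad")
+```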
diff --git a/doc/user_guide/12_filtering.md b/doc/user_guide/12_filtering.md index c187bfbc9..fb6a4e1ac 100644 --- a/doc/user_guide/12_filtering.md +++ b/doc/user_guide/12_filtering.md @@ -14,7 +14,7 @@ kernelspec: Filtering ========= -The filtering module holds the functions for frequency manipulation : +The filtering module holds the functions for frequency manipulation: - [`nap.apply_bandstop_filter`](pynapple.process.filtering.apply_bandstop_filter) - [`nap.apply_lowpass_filter`](pynapple.process.filtering.apply_lowpass_filter) @@ -377,3 +377,76 @@ plt.xlabel("Number of dimensions") plt.ylabel("Time (s)") plt.title("Low pass filtering benchmark") ``` + + +*** +Detecting Oscillatory Events +--------------------------- + +The filtering module also provides a method [`detect_oscillatory_events`](pynapple.process.filtering.detect_oscillatory_events) to automatically detect intervals containing oscillatory events (such as ripples or spindles) in a signal. + +To demonstrate, let's create a synthetic signal where a fast oscillation (e.g., 40 Hz) occurs in a noisy signal: + +```{code-cell} ipython3 +# Parameters +fs = 1000 # Sampling frequency (Hz) +duration = 3 # seconds +t = np.linspace(0, duration, int(fs * duration)) + +# 40 Hz oscillation +osc = np.sin(2 * np.pi * 40 * t) +signal = np.zeros_like(t) + 0.2 * np.random.randn(len(t)) +mask = (t > 1) & (t < 1.5) +signal[mask] += osc[mask] + +# Create Tsd object +ts = nap.Tsd(t=t, d=signal) +``` + +```{code-cell} ipython3 +:tags: [hide-input] + +# Plot the signal +plt.figure(figsize=(15, 4)) +plt.plot(ts, label="Signal (40 Hz oscillation)") +plt.xlabel("Time (s)") +plt.ylabel("Amplitude") +plt.title("Signal with oscillatory bursts") +plt.legend() +plt.show() +``` + +Now, let's use [`detect_oscillatory_events`](pynapple.process.filtering.detect_oscillatory_events) to find the oscillation intervals. The function will return the detected intervals as an `IntervalSet` along with metadata containing peak times. 
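+
+Conceptually, this kind of detection boils down to band-pass filtering the
+signal, squaring it, normalising, thresholding, and cleaning up the resulting
+intervals. The cell below is only a rough manual sketch of those steps using
+generic pynapple operations; it is not the actual implementation of
+[`detect_oscillatory_events`](pynapple.process.filtering.detect_oscillatory_events):
+
+```{code-cell} ipython3
+# Rough manual sketch: bandpass, square, z-score, threshold, clean up intervals.
+filtered = nap.apply_bandpass_filter(ts, (35, 45), fs=fs)
+squared = filtered.values ** 2
+zscored = nap.Tsd(t=filtered.times(), d=(squared - squared.mean()) / squared.std())
+above = zscored.threshold(1.0, method="above")   # keep samples above 1 s.d.
+manual_ep = (
+    above.time_support
+    .merge_close_intervals(0.02)   # merge events separated by less than 20 ms
+    .drop_short_intervals(0.03)    # drop events shorter than 30 ms
+    .drop_long_intervals(1.0)      # drop events longer than 1 s
+)
+print(manual_ep)
+```
+
+With that picture in mind, let's call the function itself: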
+ +```{code-cell} ipython3 +# Define detection parameters +freq_band = (35, 45) # Bandpass filter for 40 Hz +thres_band = (1, 10) # Thresholds for normalized squared signal +min_dur = 0.03 # Minimum event duration (s) +max_dur = 1 # Maximum event duration (s) +min_inter = 0.02 # Minimum interval between events (s) +epoch = nap.IntervalSet(start=0, end=duration) + +# Detect oscillatory events +osc_ep = nap.filtering.detect_oscillatory_events( + ts, epoch, freq_band, thres_band, (min_dur, max_dur), min_inter +) + +print("Detected intervals:\n", osc_ep) +``` + +Let's visualize the detected intervals and peaks on the original signal: + +```{code-cell} ipython3 +:tags: [hide-input] + +plt.figure(figsize=(15, 4)) +plt.plot(ts, label="Signal") +for s, e in osc_ep.values: + plt.axvspan(s, e, color="orange", alpha=0.3) +plt.xlabel("Time (s)") +plt.ylabel("Amplitude") +plt.title("Detected oscillatory events") +plt.legend() +plt.show() +``` diff --git a/main.py b/main.py index 62118ab79..a5a2abf7a 100644 --- a/main.py +++ b/main.py @@ -16,16 +16,18 @@ wake_ep = data["position_time_support"] # COMPUTING TUNING CURVES -tuning_curves = nap.compute_1d_tuning_curves( - spikes, head_direction, 120, minmax=(0, 2 * np.pi) +tuning_curves = nap.compute_tuning_curves( + spikes, + head_direction, + 120, + epochs=wake_ep, + range=(0, 2 * np.pi), + feature_names=["head direction"], ) - # PLOT -plt.figure() -for i in spikes: - plt.subplot(3, 5, i + 1, projection="polar") - plt.plot(tuning_curves[i]) - plt.xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2]) - +g = tuning_curves.plot( + row="unit", col_wrap=5, subplot_kws={"projection": "polar"}, sharey=False +) +plt.xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2]) plt.show() diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index 33057b3cf..a74233a8a 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -15,8 +15,10 @@ def jitrestrict(time_array, starts, ends): t = 0 x = 0 - while ends[k] < time_array[t]: + while k < m and ends[k] < time_array[t]: k += 1 + if k == m: + return np.empty(0, dtype=np.int64) while k < m: # Outside diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 303fc5c3a..0cceb5461 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -362,6 +362,32 @@ def restrict(self, iset): data = None if not hasattr(self, "values") else self.values[idx] return self._define_instance(time_array[idx], iset, values=data) + def in_interval(self, iset): + """ + Check which timestamps of the time series are within the intervals defined by an IntervalSet object + + Parameters + ---------- + iset : IntervalSet + the IntervalSet object + + Returns + ------- + Tsd + A Tsd of indicating which timestamps are within the intervals + """ + if not isinstance(iset, IntervalSet): + raise TypeError("Argument should be IntervalSet") + + time_array = self.index.values + starts = iset.start + ends = iset.end + + idx = _restrict(time_array, starts, ends) + mask = np.zeros_like(time_array, dtype=bool) + mask[idx] = True + return self._define_instance(time_array, self.time_support, values=mask) + def copy(self): """Copy the data, index and time support""" data = getattr(self, "values", None) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 59e51b835..9ff001743 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -1198,6 +1198,45 @@ def drop_info(self, key): """ return _MetadataMixin.drop_info(self, key) + 
@add_meta_docstring("restrict_info") + def restrict_info(self, key): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> times = np.array([[0, 5], [10, 12], [20, 33]]) + >>> metadata = {"l1": [1, 2, 3], "l2": ["x", "x", "y"], "l3": [4, 5, 6]} + >>> ep = nap.IntervalSet(tmp,metadata=metadata) + >>> ep + index start end l1 l2 l3 + 0 0 5 1 x 4 + 1 10 12 2 x 5 + 2 20 33 3 y 6 + shape: (3, 2), time unit: sec. + + To restrict to multiple metadata columns: + + >>> ep.restrict_info(["l2", "l3"]) + >>> ep + index start end l2 l3 + 0 0 5 x 4 + 1 10 12 x 5 + 2 20 33 y 6 + shape: (3, 2), time unit: sec. + + To restrict to a single metadata column: + + >>> ep.restrict_info("l2") + >>> ep + index start end l2 + 0 0 5 x + 1 10 12 x + 2 20 33 y + shape: (3, 2), time unit: sec. + """ + return _MetadataMixin.restrict_info(self, key) + @add_or_convert_metadata @add_meta_docstring("groupby") def groupby(self, by, get_group=None): @@ -1266,7 +1305,7 @@ def groupby_apply(self, by, func, input_key=None, **func_kwargs): Apply a numpy function: >>> ep.groupby_apply("l2", np.mean) - {'x': 6.75, 'y': 26.5} + {'x': np.float64(6.75), 'y': np.float64(26.5)} Apply a custom function: @@ -1289,16 +1328,29 @@ def groupby_apply(self, by, func, input_key=None, **func_kwargs): ... ) >>> feature = nap.Tsd(t=np.arange(40), d=np.concatenate([np.zeros(20), np.ones(20)])) >>> func_kwargs = { - >>> "group": tsg, - >>> "feature": feature, - >>> "nb_bins": 2, - >>> } - >>> ep.groupby_apply("l2", nap.compute_1d_tuning_curves, input_key="ep", **func_kwargs) - {'x': 1 2 3 - 0.25 1.025641 1.823362 4.216524 - 0.75 NaN NaN NaN, - 'y': 1 2 3 - 0.25 NaN NaN NaN - 0.75 1.025641 1.978022 4.835165} + ... "data": tsg, + ... "features": feature, + ... "bins": 2, + ... } + >>> ep.groupby_apply("l2", nap.compute_tuning_curves, input_key="epochs", **func_kwargs) + {'x': Size: 48B + array([[ nan, 1. ], + [ nan, 1.77777778], + [ nan, 4.11111111]]) + Coordinates: + * unit (unit) int64 24B 1 2 3 + * 0 (0) float64 16B -0.25 0.25 + Attributes: + occupancy: [nan 9.] + bin_edges: [array([-0.5, 0. , 0.5])], 'y': Size: 48B + array([[ nan, 1. ], + [ nan, 1.92857143], + [ nan, 4.71428571]]) + Coordinates: + * unit (unit) int64 24B 1 2 3 + * 0 (0) float64 16B 0.75 1.25 + Attributes: + occupancy: [nan 14.] + bin_edges: [array([0.5, 1. , 1.5])]} """ return _MetadataMixin.groupby_apply(self, by, func, input_key, **func_kwargs) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index e00413d13..e284be43a 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -3,6 +3,7 @@ import itertools import warnings from collections import UserDict +from functools import wraps from numbers import Number from typing import Union @@ -28,6 +29,7 @@ def add_or_convert_metadata(func): Decorator for backwards compatibility of objects picked with older versions of pynapple. """ + @wraps(func) def _decorator(self, *args, **kwargs): if ( (len(args) == 1) @@ -407,6 +409,39 @@ def drop_info(self, key): f"Invalid metadata column {key}. Metadata columns are {self.metadata_columns}" ) + def restrict_info(self, key): + """ + Restrict metadata columns to a key or list of keys. + + Parameters + ---------- + key : str or list of str + Metadata column name(s) to restrict to. + + Returns + ------- + None + """ + if isinstance(key, Number): + raise TypeError( + f"Invalid metadata column {key}. 
Metadata columns are {self.metadata_columns}" + ) + if isinstance(key, str): + key = [key] + + no_keep = [k for k in key if k not in self.metadata_columns] + if no_keep: + raise KeyError( + f"Metadata column(s) {no_keep} not found. Metadata columns are {self.metadata_columns}" + ) + + drop_keys = set(self.metadata_columns) - set(key) + for k in drop_keys: + if (self.nap_class == "TsGroup") and (k == "rate"): + continue # cannot drop TsGroup 'rate' + else: + del self._metadata[k] + def groupby(self, by, get_group=None): """ Group pynapple object by metadata name(s). diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 071276f5a..56b4e21f2 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -46,6 +46,7 @@ add_docstring, convert_to_array, is_array_like, + modifies_time_axis, ) @@ -71,7 +72,12 @@ def _get_class(data): def _initialize_tsd_output( - input_object, values, time_index=None, time_support=None, kwargs=None + input_object, + values, + time_index=None, + time_support=None, + drop_metadata=False, + kwargs=None, ): """ Initialize the output object for time series data, ensuring proper alignment of time indices @@ -127,7 +133,7 @@ def _initialize_tsd_output( cls = _get_class(values) # if out will be a tsdframe implement kwargs logic - if cls is TsdFrame: + if (cls is TsdFrame) and (not drop_metadata): # get eventual setting cols = kwargs.get("columns", None) metadata = kwargs.get("metadata", None) @@ -178,10 +184,10 @@ def __init__(self, t, d, time_units="s", time_support=None, load_array=True): ) self.values = d - assert len(self.index) == len( - self.values + assert ( + len(self.index) == self.values.shape[0] ), "Length of values {} does not match length of index {}".format( - len(self.values), len(self.index) + self.values.shape[0], len(self.index) ) if isinstance(time_support, IntervalSet) and len(self.index): @@ -286,13 +292,32 @@ def __array_function__(self, func, types, args, kwargs): ]: return NotImplemented + # This should be implemented at some point + if func in [ + np.take, + np.take_along_axis, + np.extract, + np.compress, + np.choose, + np.select, + np.delete, + ]: + return NotImplemented + if hasattr(np.fft, func.__name__): return NotImplemented if func in [np.split, np.array_split, np.dsplit, np.hsplit, np.vsplit]: return _split_tsd(func, *args, **kwargs) - if func in [np.concatenate, np.vstack, np.hstack, np.dstack]: + if func in [ + np.concatenate, + np.vstack, + np.hstack, + np.dstack, + np.column_stack, + np.stack, + ]: return _concatenate_tsd(func, *args, **kwargs) new_args = [] @@ -303,7 +328,11 @@ def __array_function__(self, func, types, args, kwargs): new_args.append(a) out = func._implementation(*new_args, **kwargs) - return _initialize_tsd_output(self, out) + + if modifies_time_axis(func, new_args, kwargs): + return out + else: + return _initialize_tsd_output(self, out, drop_metadata=True) def as_array(self): """ @@ -1007,6 +1036,8 @@ def __getitem__(self, key): key = tuple(k.values if isinstance(k, Tsd) else k for k in key) output = self.values.__getitem__(key) index = self.index.__getitem__(key[0]) + if index.ndim > 1: + index = np.squeeze(index) else: output = self.values.__getitem__(key) index = self.index.__getitem__(key) @@ -1167,6 +1198,38 @@ def restrict(self, iset): """ return _Base.restrict(self, iset) + @add_docstring("in_interval", _Base) + def in_interval(self, iset): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> t = np.arange(100) + >>> ep = 
nap.IntervalSet(start=0, end=50) + >>> tsdtensor = nap.TsdTensor(t=t, d=np.random.randn(len(t), 4, 4)) + >>> tsdtensor.in_interval(ep) + Time (s) + ---------- -- + 0.0 1 + 1.0 1 + 2.0 1 + 3.0 1 + 4.0 1 + 5.0 1 + 6.0 1 + ... + 93.0 0 + 94.0 0 + 95.0 0 + 96.0 0 + 97.0 0 + 98.0 0 + 99.0 0 + dtype: bool, shape: (100,) + """ + return _Base.in_interval(self, iset) + @add_docstring("value_from", _Base) def value_from(self, data, ep=None, mode="closest"): """ @@ -1535,7 +1598,7 @@ def __repr__(self): np.hstack( ( self.index[0:n_rows, None], - np.round(self.values[0:n_rows, 0:max_cols], 5), + self.values[0:n_rows, 0:max_cols], ends, ), dtype=object, @@ -1543,7 +1606,7 @@ def __repr__(self): np.array( [ ["..."] - + ["..."] * np.minimum(max_cols, self.shape[1]) + + [None] * np.minimum(max_cols, self.shape[1]) + end ], dtype=object, @@ -1551,7 +1614,7 @@ def __repr__(self): np.hstack( ( self.index[-n_rows:, None], - np.round(self.values[-n_rows:, 0:max_cols], 5), + self.values[-n_rows:, 0:max_cols], ends, ), dtype=object, @@ -1563,7 +1626,7 @@ def __repr__(self): table = np.hstack( ( self.index[:, None], - np.round(self.values[:, 0:max_cols], 5), + self.values[:, 0:max_cols], ends, ), dtype=object, @@ -1705,15 +1768,18 @@ def __setitem__(self, key, value): def __getitem__(self, key, *args, **kwargs): if isinstance(key, tuple): key = tuple(k.values if hasattr(k, "values") else k for k in key) - if isinstance(key, Tsd): + if isinstance(key, (Tsd, TsdFrame)): try: assert np.issubdtype(key.dtype, np.bool_) except AssertionError: raise ValueError( - "When indexing with a Tsd, it must contain boolean values" + "When indexing with a Tsd or TsdFrame, it must contain boolean values" ) - key = key.d - elif isinstance(key, str): + if isinstance(key, TsdFrame): + return self.values.__getitem__(key.d) + else: + key = key.d + if isinstance(key, str): if key in self.columns: with warnings.catch_warnings(): # ignore deprecated warning for loc @@ -1735,7 +1801,9 @@ def __getitem__(self, key, *args, **kwargs): if isinstance(key, tuple): index = self.index.__getitem__(key[0]) - if len(key) == 2: + if index.ndim > 1: + index = np.squeeze(index) + if len(key) == 2 and key[1] is not None: columns = self.columns.__getitem__(key[1]) else: index = self.index.__getitem__(key) @@ -1958,6 +2026,38 @@ def restrict(self, iset): """ return _Base.restrict(self, iset) + @add_docstring("in_interval", _Base) + def in_interval(self, iset): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> t = np.arange(100) + >>> ep = nap.IntervalSet(start=0, end=50) + >>> tsdframe = nap.TsdFrame(t=t, d=np.random.randn(len(t), 4)) + >>> tsdframe.in_interval(ep) + Time (s) + ---------- -- + 0.0 1 + 1.0 1 + 2.0 1 + 3.0 1 + 4.0 1 + 5.0 1 + 6.0 1 + ... 
+ 93.0 0 + 94.0 0 + 95.0 0 + 96.0 0 + 97.0 0 + 98.0 0 + 99.0 0 + dtype: bool, shape: (100,) + """ + return _Base.in_interval(self, iset) + @add_docstring("value_from", _Base) def value_from(self, data, ep=None, mode="closest"): """ @@ -2371,6 +2471,65 @@ def drop_info(self, key): """ return _MetadataMixin.drop_info(self, key) + @add_meta_docstring("restrict_info") + def restrict_info(self, key): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> metadata = {"l1": [1, 2, 3], "l2": ["x", "x", "y"], "l3": [4, 5, 6]} + >>> tsdframe = nap.TsdFrame(t=np.arange(5), d=np.ones((5, 3)), metadata=metadata) + >>> print(tsdframe) + Time (s) 0 1 2 + ---------- --- --- --- + 0.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 + 2.0 1.0 1.0 1.0 + 3.0 1.0 1.0 1.0 + 4.0 1.0 1.0 1.0 + Metadata + ---------- --- --- --- + l1 1 2 3 + l2 x x y + l3 4 5 6 + dtype: float64, shape: (5, 3) + + To restrict to multiple metadata rows: + + >>> tsdframe.restrict_info(["l2", "l3"]) + >>> tsdframe + Time (s) 0 1 2 + ---------- --- --- --- + 0.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 + 2.0 1.0 1.0 1.0 + 3.0 1.0 1.0 1.0 + 4.0 1.0 1.0 1.0 + Metadata + ---------- --- --- --- + l2 x x y + l3 4 5 6 + dtype: float64, shape: (5, 3) + + To restrict to a single metadata row: + + >>> tsdframe.restrict_info("l2") + >>> tsdframe + Time (s) 0 1 2 + ---------- --- --- --- + 0 1 1 1 + 1 1 1 1 + 2 1 1 1 + 3 1 1 1 + 4 1 1 1 + Metadata + ---------- --- --- --- + l2 x x y + dtype: float64, shape: (5, 3) + """ + return _MetadataMixin.restrict_info(self, key) + @add_or_convert_metadata @add_meta_docstring("groupby") def groupby(self, by, get_group=None): @@ -2660,6 +2819,8 @@ def __getitem__(self, key, *args, **kwargs): if isinstance(key, tuple): index = self.index.__getitem__(key[0]) + if index.ndim > 1: + index = np.squeeze(index) elif isinstance(key, Number): index = np.array([key]) else: @@ -2911,6 +3072,38 @@ def restrict(self, iset): """ return _Base.restrict(self, iset) + @add_docstring("in_interval", _Base) + def in_interval(self, iset): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> t = np.arange(100) + >>> ep = nap.IntervalSet(start=0, end=50) + >>> tsd = nap.Tsd(t=t, d=np.random.randn(len(t))) + >>> tsd.in_interval(ep) + Time (s) + ---------- -- + 0.0 1 + 1.0 1 + 2.0 1 + 3.0 1 + 4.0 1 + 5.0 1 + 6.0 1 + ... + 93.0 0 + 94.0 0 + 95.0 0 + 96.0 0 + 97.0 0 + 98.0 0 + 99.0 0 + dtype: bool, shape: (100,) + """ + return _Base.in_interval(self, iset) + @add_docstring("value_from", _Base) def value_from(self, data, ep=None, mode="closest"): """ @@ -3206,6 +3399,8 @@ def __repr__(self): def __getitem__(self, key): if isinstance(key, tuple): index = self.index.__getitem__(key[0]) + if index.ndim > 1: + index = np.squeeze(index) else: index = self.index.__getitem__(key) @@ -3536,6 +3731,38 @@ def restrict(self, iset): """ return _Base.restrict(self, iset) + @add_docstring("in_interval", _Base) + def in_interval(self, iset): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> t = np.arange(100) + >>> ep = nap.IntervalSet(start=0, end=50) + >>> ts = nap.Ts(t) + >>> ts.in_interval(ep) + Time (s) + ---------- -- + 0.0 1 + 1.0 1 + 2.0 1 + 3.0 1 + 4.0 1 + 5.0 1 + 6.0 1 + ... 
+ 93.0 0 + 94.0 0 + 95.0 0 + 96.0 0 + 97.0 0 + 98.0 0 + 99.0 0 + dtype: bool, shape: (100,) + """ + return _Base.in_interval(self, iset) + @add_docstring("value_from", _Base) def value_from(self, data, ep=None, mode="closest"): """ diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 4f3a3667b..fc2c4c44b 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1620,7 +1620,7 @@ def _from_npz_reader(cls, file): time_support = IntervalSet(file["start"], file["end"]) if has_data: - data = file["data"] + data = file["d"] if "keys" in file.keys(): keys = file["keys"] @@ -1843,6 +1843,52 @@ def drop_info(self, key): """ return _MetadataMixin.drop_info(self, key) + @add_meta_docstring("restrict_info") + def restrict_info(self, key): + """ + Note + ---- + The `rate` column is always kept in the metadata, even if it is not specified in `key`. + + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> tmp = {0:nap.Ts(t=np.arange(0,200), time_units='s'), + ... 1:nap.Ts(t=np.arange(0,200,0.5), time_units='s'), + ... 2:nap.Ts(t=np.arange(0,300,0.25), time_units='s'), + ... } + >>> metadata = {"l1": [1, 2, 3], "l2": ["x", "x", "y"], "l3": [4, 5, 6]} + >>> tsgroup = nap.TsGroup(tmp,metadata=metadata) + >>> print(tsgroup) + Index rate l1 l2 l3 + ------- ------- ---- ---- ---- + 0 0.66722 1 x 4 + 1 1.33445 2 x 5 + 2 4.00334 3 y 6 + + To restrict to multiple metadata columns: + + >>> tsgroup.restrict_info(["l2", "l3"]) + >>> tsgroup + Index rate l2 l3 + ------- ------- ---- ---- + 0 0.66722 x 4 + 1 1.33445 x 5 + 2 4.00334 y 6 + + To restrict to a single metadata column: + + >>> tsgroup.drop_info("l2") + >>> tsgroup + Index rate l2 + ------- ------- ---- + 0 0.66722 x + 1 1.33445 x + 2 4.00334 y + """ + return _MetadataMixin.restrict_info(self, key) + @add_or_convert_metadata @add_meta_docstring("groupby") def groupby(self, by, get_group=None): @@ -1925,12 +1971,22 @@ def groupby_apply(self, by, func, input_key=None, **func_kwargs): ... d=np.concatenate([np.zeros(20), np.ones(20)]), ... time_support=nap.IntervalSet(np.array([[0, 5], [10, 12], [20, 33]])), ... ) - >>> tsgroup.groupby_apply("l2", nap.compute_1d_tuning_curves, feature=feature, nb_bins=2) - {'x': 0 1 - 0.25 1.15 2.044444 - 0.75 1.15 2.217857, - 'y': 2 - 0.25 3.833333 - 0.75 4.353571} + >>> print(tsgroup.groupby_apply("l2", nap.compute_tuning_curves, features=feature, bins=2)) + {'x': Size: 32B + array([[1. , 1. ], + [1.77777778, 1.92857143]]) + Coordinates: + * unit (unit) int64 16B 0 1 + * 0 (0) float64 16B 0.25 0.75 + Attributes: + occupancy: [ 9. 14.] + bin_edges: [array([0. , 0.5, 1. ])], 'y': Size: 16B + array([[3.33333333, 3.78571429]]) + Coordinates: + * unit (unit) int64 8B 2 + * 0 (0) float64 16B 0.25 0.75 + Attributes: + occupancy: [ 9. 14.] + bin_edges: [array([0. , 0.5, 1. 
])]} """ return _MetadataMixin.groupby_apply(self, by, func, input_key, **func_kwargs) diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 9b24315b7..81d94ae86 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -2,8 +2,10 @@ Utility functions """ +import inspect import os import warnings +from collections.abc import Sequence from itertools import combinations from numbers import Number from pathlib import Path @@ -310,9 +312,17 @@ def _concatenate_tsd(func, *args, **kwargs): support_equal = _check_time_equals([x.values for x in time_supports]) if time_equal and support_equal: - return nap_class( - t=time_indexes[0], d=output, time_support=time_supports[0] - ) + new_kwargs = {} + if len(columns): + new_kwargs = {"columns": np.hstack([c for c in columns])} + if len(new_kwargs["columns"]) != output.shape[1]: + new_kwargs = {} + return args[0][0]._define_instance( + time_index=time_indexes[0], + time_support=time_supports[0], + values=output, + **new_kwargs, + ) # Dropping metadata in this case else: if not time_equal and not support_equal: msg = "Time indexes and time supports are not all equals up to pynapple precision. Returning numpy array!" @@ -484,3 +494,105 @@ def wrapper(func): return func return wrapper + + +def _arg_as_sequence(x): + return isinstance(x, Sequence) and not isinstance(x, (str, bytes)) + + +def modifies_time_axis(func, new_args, kwargs): + """ + Return True if calling func(*new_args, **kwargs) would modify/move axis 0. + Uses inspect.signature(bind_partial + apply_defaults) to get effective args. + Conservative: if we can't determine array ndim, assume it *may* modify axis 0. + """ + if func is np.flipud: + return True + if func is np.squeeze: + return False # This one should be handled by _initialize_tsd_output + + try: + sig = inspect.signature(func) + except (TypeError, ValueError): + return False + + bound = sig.bind_partial(*new_args, **kwargs) + bound.apply_defaults() + + # Helper to get first array-like from positional args (conservative) + arr = None + if new_args: + arr = new_args[0] + else: + # try common kw names + for name in ("a", "arr", "array", "x", "m"): + if name in bound.arguments: + arr = bound.arguments[name] + break + + ndim = getattr(arr, "ndim", None) + if ndim is None: + return False + + ### 1) single-axis arguments ### + axis = bound.arguments.get("axis", inspect._empty) + if axis is not inspect._empty: + # axis=None usually means "all axes" for reductions => affects axis 0 + if (axis is None) or (axis == 0): + return True + if isinstance(axis, tuple) and (0 in axis): + return True + # axis might be negative; normalize if ndim known + if axis < 0: + normalized_axis = axis + ndim + if func is np.expand_dims: + if normalized_axis == -1: + # normalized_axis will be -1 when expanding first dimension + # normalized_axis = 0 will expand in the second dimension + return True + else: + if normalized_axis == 0: + return True + + # Special case for np.rollaxis + if func is np.rollaxis: + if bound.arguments.get("start", 0) == 0: + return True + # special case for np.rot90 + if func is np.rot90: + if 0 in bound.arguments.get("axes", (0, 1)): + return True + + ### 2) multi-axis permutation (e.g., transpose) ### + axes = bound.arguments.get("axes", inspect._empty) + if axes is not inspect._empty: + if axes is None: + return True # all axes permuted => affects axis 0 + if _arg_as_sequence(axes): + # if axis 0 is not at position 0 after permutation, it's moved + idx = list(axes).index(0) + # idx is new position of original axis 0 
+ if idx != 0: + return True + + ### 3) moveaxis: source/destination can be ints or sequences ### + for name in ("source", "destination"): + val = bound.arguments.get(name, inspect._empty) + if val is not inspect._empty: + if val is None: + continue + elif (_arg_as_sequence(val)) and (0 in val): + return True + elif val == 0: + return True + + ### 4) swapaxes / similar ### + axis1 = bound.arguments.get("axis1", inspect._empty) + axis2 = bound.arguments.get("axis2", inspect._empty) + if (axis1 is not inspect._empty) and (axis1 == 0): + return True + if (axis2 is not inspect._empty) and (axis2 == 0): + return True + + # If none of the checks triggered, assume axis 0 is not modified. + return False diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 3893822c5..3cbb079a4 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -4,7 +4,7 @@ compute_eventcorrelogram, compute_isi_distribution, ) -from .decoding import decode_1d, decode_2d +from .decoding import decode_1d, decode_2d, decode_bayes, decode_template from .filtering import ( apply_bandpass_filter, apply_bandstop_filter, @@ -36,6 +36,9 @@ compute_2d_tuning_curves, compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, + compute_mutual_information, + compute_response_per_epoch, + compute_tuning_curves, ) from .warping import build_tensor, warp_tensor from .wavelets import compute_wavelet_transform, generate_morlet_filterbank diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index 94e07e75f..86d491e81 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -246,18 +246,13 @@ def _perievent_continuous( w_sizes = slice_idx[:, 1] - slice_idx[:, 0] # Different sizes - all_w_sizes = np.unique(w_sizes) - all_w_start = np.unique(w_starts) - - for w_size in all_w_sizes: - for w_start in all_w_start: - col_idx = w_sizes == w_size - new_idx = np.zeros((w_size, np.sum(col_idx)), dtype=int) - for i, tmp in enumerate(slice_idx[col_idx]): - new_idx[:, i] = np.arange(tmp[0], tmp[1]) - - new_data_array[w_start : w_start + w_size, col_idx] = data_array[ - new_idx - ] + unique_pairs = np.unique(np.column_stack([w_sizes, w_starts]), axis=0) + for w_size, w_start in unique_pairs: + col_idx = (w_sizes == w_size) & (w_starts == w_start) + new_idx = np.zeros((w_size, np.sum(col_idx)), dtype=int) + for i, slc in enumerate(slice_idx[col_idx]): + new_idx[:, i] = np.arange(slc[0], slc[1]) + + new_data_array[w_start : w_start + w_size, col_idx] = data_array[new_idx] return new_data_array diff --git a/pynapple/process/decoding.py b/pynapple/process/decoding.py index e730138b5..0675c3f32 100644 --- a/pynapple/process/decoding.py +++ b/pynapple/process/decoding.py @@ -1,212 +1,696 @@ """ -Decoding functions for timestamps data (spike times). The first argument is always a tuning curves object. +Functions to decode n-dimensional features. """ +import inspect +import warnings +from functools import wraps + import numpy as np +import xarray as xr +from scipy.spatial.distance import cdist from .. import core as nap -def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None): - """ - Perform Bayesian decoding over a one dimensional feature. - See: - Zhang, K., Ginzburg, I., McNaughton, B. L., & Sejnowski, T. J. - (1998). Interpreting neuronal population activity by - reconstruction: unified framework with application to - hippocampal place cells. Journal of neurophysiology, 79(2), - 1017-1044. 
+def _format_decoding_inputs(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Validate each positional argument + sig = inspect.signature(func) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + kwargs = bound.arguments + + # check tuning curves + tuning_curves = kwargs["tuning_curves"] + if not isinstance(tuning_curves, xr.DataArray): + raise TypeError( + "tuning_curves should be an xarray.DataArray as computed by compute_tuning_curves." + ) + + # check data + data = kwargs["data"] + was_continuous = True + if isinstance(data, nap.TsdFrame): + # check match bin_size + actual_bin_size = np.mean(data.time_diff().values) + passed_bin_size = kwargs["bin_size"] + if not isinstance(passed_bin_size, (int, float)): + raise ValueError("bin_size should be a number.") + if not np.isclose( + actual_bin_size, + nap.TsIndex.format_timestamps( + np.array([passed_bin_size], dtype=np.float64), + units=kwargs["time_units"], + ), + )[0]: + warnings.warn("passed bin_size is different from actual data bin size.") + elif isinstance(data, nap.TsGroup): + data = data.count( + kwargs["bin_size"], kwargs["epochs"], time_units=kwargs["time_units"] + ) + was_continuous = False + else: + raise TypeError("Unknown format for data.") + + # check match tuning curves and data + if tuning_curves.sizes["unit"] != data.shape[1]: + raise ValueError("Different shapes for tuning_curves and data.") + if not np.all(tuning_curves.coords["unit"] == data.columns.values): + raise ValueError("Different indices for tuning curves and data keys.") + + if ( + "uniform_prior" in kwargs + and not kwargs["uniform_prior"] + and "occupancy" not in tuning_curves.attrs + ): + raise ValueError( + "uniform_prior set to False but no occupancy found in tuning curves." + ) + + # smooth + sliding_window_size = kwargs["sliding_window_size"] + if sliding_window_size is not None: + if not isinstance(sliding_window_size, int): + raise ValueError("sliding_window_size should be a integer.") + if sliding_window_size < 1: + raise ValueError("sliding_window_size should be >= 1.") + data = data.convolve( + np.ones(sliding_window_size), + ep=kwargs["epochs"], + ) + if was_continuous: + data = data / sliding_window_size + else: + bin_size = sliding_window_size * kwargs["bin_size"] + kwargs["bin_size"] = bin_size + + kwargs["data"] = data + + # Call the original function with validated inputs + return func(**kwargs) + + return wrapper + + +def _format_decoding_outputs(dist, tuning_curves, data, epochs, greater_is_better): + # Get the index of the decoded class + filler = -np.inf if greater_is_better else np.inf + filled = np.where(np.isnan(dist), filler, dist) + idx = getattr(np, "argmax" if greater_is_better else "argmin")(filled, axis=1) + + # Replace with -1 where all values were NaN + all_nan = np.isnan(dist).all(axis=1) + idx[all_nan] = -1 + + # Format probability/distance distribution + dist = dist.reshape(dist.shape[0], *tuning_curves.shape[1:]) + if dist.ndim > 2: + dist = nap.TsdTensor( + t=data.index, + d=dist, + time_support=epochs, + ) + else: + dist = nap.TsdFrame( + t=data.index, + d=dist, + time_support=epochs, + columns=tuning_curves.coords[tuning_curves.dims[1]].values, + ) - Parameters - ---------- - tuning_curves : pandas.DataFrame - Each column is the tuning curve of one neuron relative to the feature. - Index should be the center of the bin. - group : TsGroup, TsdFrame or dict of Ts/Tsd object. - A group of neurons with the same index as tuning curves column names. 
- You may also pass a TsdFrame with smoothed rates (recommended). - ep : IntervalSet - The epoch on which decoding is computed - bin_size : float - Bin size. Default is second. Use the parameter time_units to change it. - time_units : str, optional - Time unit of the bin size ('s' [default], 'ms', 'us'). - feature : Tsd, optional - The 1d feature used to compute the tuning curves. Used to correct for occupancy. - If feature is not passed, the occupancy is uniform. + # Format decoded index + shape = tuning_curves.shape[1:] + valid = idx != -1 + + if tuning_curves.ndim == 2: + decoded_values = np.full(len(idx), np.nan) + decoded_values[valid] = tuning_curves.coords[tuning_curves.dims[1]].values[ + idx[valid] + ] + decoded = nap.Tsd( + t=data.index, + d=decoded_values, + time_support=epochs, + ) + else: + # unravel valid indices only + unraveled = [np.full(len(idx), np.nan) for _ in shape] + unraveled_indices = np.unravel_index(idx[valid], shape) + for i in range(len(shape)): + unraveled[i][valid] = tuning_curves.coords[ + tuning_curves.dims[1 + i] + ].values[unraveled_indices[i]] + + decoded = nap.TsdFrame( + t=data.index, + d=np.stack(unraveled, axis=1), + time_support=epochs, + columns=tuning_curves.dims[1:], + ) - Returns - ------- - Tsd - The decoded feature - TsdFrame - The probability distribution of the decoded feature for each time bin - - Raises - ------ - RuntimeError - If group is not a dict of Ts/Tsd or TsGroup. - If different size of neurons for tuning_curves and group. - If indexes don't match between tuning_curves and group. + return decoded, dist + + +@_format_decoding_inputs +def decode_bayes( + tuning_curves, + data, + epochs, + bin_size, + sliding_window_size=None, + time_units="s", + uniform_prior=True, +): """ - if isinstance(group, nap.TsdFrame): - newgroup = group.restrict(ep) + Performs Bayesian decoding over n-dimensional features. - if tuning_curves.shape[1] != newgroup.shape[1]: - raise RuntimeError("Different shapes for tuning_curves and group") + The algorithm is based on Bayes' rule: - if not np.all(tuning_curves.columns.values == np.array(newgroup.columns)): - raise RuntimeError("Different indices for tuning curves and group keys") + .. math:: - count = group + P(x|n) \\propto P(n|x) P(x) - elif isinstance(group, nap.TsGroup): - newgroup = group.restrict(ep) + where: - if tuning_curves.shape[1] != len(newgroup): - raise RuntimeError("Different shapes for tuning_curves and group") + - :math:`P(x|n)` is the **posterior probability** of the feature value given the observed neural activity. + - :math:`P(n|x)` is the **likelihood** of the neural activity given the feature value. + - :math:`P(x)` is the **prior** probability of the feature value. - if not np.all(tuning_curves.columns.values == np.array(newgroup.keys())): - raise RuntimeError("Different indices for tuning curves and group keys") + Mapping this to the function: - # Bin spikes - count = newgroup.count(bin_size, ep, time_units) + - :math:`P(x|n)` is the estimated probability distribution over the decoded feature for each time bin. + This is the output of the function. The decoded value is the one with the maximum posterior probability. + - :math:`P(n|x)` is determined by the tuning curves. Assuming spikes follow a Poisson distribution and + neurons are conditionally independent: - elif isinstance(group, dict): - newgroup = nap.TsGroup(group, time_support=ep) - count = newgroup.count(bin_size, ep, time_units) + .. 
math:: - else: - raise RuntimeError("Unknown format for group") + P(n|x) = \\prod_{i=1}^{N} P(n_i|x) = \\prod_{i=1}^{N} \\frac{\\lambda_i^{n_i} e^{-\\lambda_i}}{n_i!} - # Occupancy - if feature is None: - occupancy = np.ones(tuning_curves.shape[0]) - elif isinstance(feature, nap.Tsd): - diff = np.diff(tuning_curves.index.values) - bins = tuning_curves.index.values[:-1] - diff / 2 - bins = np.hstack( - (bins, [bins[-1] + diff[-1], bins[-1] + 2 * diff[-1]]) - ) # assuming the size of the last 2 bins is equal - occupancy, _ = np.histogram(feature.values, bins) - else: - raise RuntimeError("Unknown format for feature in decode_1d") + where :math:`\\lambda_i` is the expected firing rate of neuron :math:`i` at feature value :math:`x`, + and :math:`n_i` is the spike count of neuron :math:`i`. + + - :math:`P(x)` depends on the value of the ``uniform_prior`` argument. + If ``uniform_prior=True``, it is a uniform distribution over feature values. + If ``uniform_prior=False``, it is based on the occupancy (i.e. the time spent in each feature bin during tuning curve estimation). + + References + ---------- + .. [1] Zhang, K., Ginzburg, I., McNaughton, B. L., & Sejnowski, T. J. + (1998). Interpreting neuronal population activity by + reconstruction: unified framework with application to + hippocampal place cells. Journal of neurophysiology, 79(2), + 1017-1044. + + Parameters + ---------- + tuning_curves : xarray.DataArray + Tuning curves as computed by :func:`~pynapple.process.tuning_curves.compute_tuning_curves`. + data : TsGroup or TsdFrame + Neural activity with the same keys as the tuning curves. + You may also pass a TsdFrame with smoothed counts. + epochs : IntervalSet + The epochs on which decoding is computed + bin_size : float + Bin size. Default in seconds. Use ``time_units`` to change it. + sliding_window_size : int, optional + The size, in number of bins, for a uniform window to be convolved with the counts array for each neuron. Value should be >= 1. + If None (default), no smoothing is applied. + time_units : str, optional + Time unit of the bin size (``s`` [default], ``ms``, ``us``). + uniform_prior : bool, optional + If True (default), uses a uniform distribution as a prior. + If False, uses the occupancy from the tuning curves as a prior over the feature + probability distribution. - # Transforming to pure numpy array - tc = tuning_curves.values - ct = count.values + Returns + ------- + Tsd + The decoded feature. + TsdFrame, TsdTensor + The probability distribution of the decoded feature for each time bin. + + Examples + -------- + In the simplest case, we can decode a single feature (e.g., position) from a group of neurons: + + >>> import pynapple as nap + >>> import numpy as np + >>> data = nap.TsGroup({i: nap.Ts(t=np.arange(0, 50) + 50 * i) for i in range(2)}) + >>> feature = nap.Tsd(t=np.arange(0, 100, 1), d=np.repeat(np.arange(0, 2), 50)) + >>> tuning_curves = nap.compute_tuning_curves(data, feature, bins=2, range=(-.5, 1.5)) + >>> epochs = nap.IntervalSet([0, 100]) + >>> decoded, p = nap.decode_bayes(tuning_curves, data, epochs=epochs, bin_size=1) + >>> decoded + Time (s) + ---------- -- + 0.5 0 + 1.5 0 + 2.5 0 + 3.5 0 + 4.5 0 + 5.5 0 + 6.5 0 + ... + 93.5 1 + 94.5 1 + 95.5 1 + 96.5 1 + 97.5 1 + 98.5 1 + 99.5 1 + dtype: float64, shape: (100,) + + decode is a `Tsd` object containing the decoded feature for each time bin. + + >>> p + Time (s) 0.0 1.0 + ---------- ----- ----- + 0.5 1.0 0.0 + 1.5 1.0 0.0 + 2.5 1.0 0.0 + 3.5 1.0 0.0 + 4.5 1.0 0.0 + 5.5 1.0 0.0 + 6.5 1.0 0.0 + ... ... 
... + 93.5 0.0 1.0 + 94.5 0.0 1.0 + 95.5 0.0 1.0 + 96.5 0.0 1.0 + 97.5 0.0 1.0 + 98.5 0.0 1.0 + 99.5 0.0 1.0 + dtype: float64, shape: (100, 2) + + p is a `TsdFrame` object containing the probability distribution for each time bin. + + The function also works for multiple features, in which case it does n-dimensional decoding: + + >>> features = nap.TsdFrame( + ... t=np.arange(0, 100, 1), + ... d=np.vstack((np.repeat(np.arange(0, 2), 50), np.tile(np.arange(0, 2), 50))).T, + ... ) + >>> data = nap.TsGroup( + ... { + ... 0: nap.Ts(np.arange(0, 50, 2)), + ... 1: nap.Ts(np.arange(1, 51, 2)), + ... 2: nap.Ts(np.arange(50, 100, 2)), + ... 3: nap.Ts(np.arange(51, 101, 2)), + ... } + ... ) + >>> tuning_curves = nap.compute_tuning_curves(data, features, bins=2, range=[(-.5, 1.5)]*2) + >>> decoded, p = nap.decode_bayes(tuning_curves, data, epochs=epochs, bin_size=1) + >>> decoded + Time (s) 0 1 + ---------- --- --- + 0.5 0.0 0.0 + 1.5 0.0 1.0 + 2.5 0.0 0.0 + 3.5 0.0 1.0 + 4.5 0.0 0.0 + 5.5 0.0 1.0 + 6.5 0.0 0.0 + ... ... ... + 93.5 1.0 1.0 + 94.5 1.0 0.0 + 95.5 1.0 1.0 + 96.5 1.0 0.0 + 97.5 1.0 1.0 + 98.5 1.0 0.0 + 99.5 1.0 1.0 + dtype: float64, shape: (100, 2) + + decoded is now a `TsdFrame` object containing the decoded features for each time bin. + + >>> p + Time (s) + ---------- -------------- + 0.5 [[1., 0.] ...] + 1.5 [[0., 1.] ...] + 2.5 [[1., 0.] ...] + 3.5 [[0., 1.] ...] + 4.5 [[1., 0.] ...] + 5.5 [[0., 1.] ...] + 6.5 [[1., 0.] ...] + ... + 93.5 [[0., 0.] ...] + 94.5 [[0., 0.] ...] + 95.5 [[0., 0.] ...] + 96.5 [[0., 0.] ...] + 97.5 [[0., 0.] ...] + 98.5 [[0., 0.] ...] + 99.5 [[0., 0.] ...] + dtype: float64, shape: (100, 2, 2) + + and p is a `TsdTensor` object containing the probability distribution for each time bin. + + It is also possible to pass continuous values instead of spikes (e.g. smoothed spike counts): + + >>> data = data.count(1).smooth(2) + >>> tuning_curves = nap.compute_tuning_curves(data, features, bins=2, range=[(-.5, 1.5)]*2) + >>> decoded, p = nap.decode_bayes(tuning_curves, data, epochs=epochs, bin_size=1) + >>> decoded + Time (s) 0 1 + ---------- --- --- + 0.5 0.0 1.0 + 1.5 0.0 1.0 + 2.5 0.0 1.0 + 3.5 0.0 1.0 + 4.5 0.0 0.0 + 5.5 0.0 0.0 + 6.5 0.0 0.0 + ... ... ... 
+ 92.5 1.0 0.0 + 93.5 1.0 0.0 + 94.5 1.0 0.0 + 95.5 1.0 1.0 + 96.5 1.0 1.0 + 97.5 1.0 1.0 + 98.5 1.0 1.0 + dtype: float64, shape: (98, 2) + """ + prior = ( + np.ones_like(tuning_curves[0]).flatten() + if uniform_prior + else tuning_curves.attrs["occupancy"].flatten() + ) + prior = prior.astype(np.float64) + prior /= prior.sum() + rate_map = tuning_curves.values.reshape(tuning_curves.sizes["unit"], -1).T + observed_counts = data.values bin_size_s = nap.TsIndex.format_timestamps( np.array([bin_size], dtype=np.float64), time_units )[0] + observed_counts_expanded = np.tile( + observed_counts[:, np.newaxis, :], (1, rate_map.shape[0], 1) + ) - p1 = np.exp(-bin_size_s * tc.sum(1)) - p2 = occupancy / occupancy.sum() + EPS = 1e-12 + log_likelihood = np.nansum( + observed_counts_expanded * np.log(rate_map + EPS) - bin_size_s * rate_map, + axis=-1, + ) - ct2 = np.tile(ct[:, np.newaxis, :], (1, tc.shape[0], 1)) + log_posterior = log_likelihood + np.log(prior) + posterior = np.exp(log_posterior - log_posterior.max(axis=1, keepdims=True)) + posterior /= posterior.sum(axis=1, keepdims=True) - p3 = np.prod(tc**ct2, -1) + return _format_decoding_outputs( + posterior, tuning_curves, data, epochs, greater_is_better=True + ) - p = p1 * p2 * p3 - p = p / p.sum(1)[:, np.newaxis] - idxmax = np.argmax(p, 1) +@_format_decoding_inputs +def decode_template( + tuning_curves, + data, + epochs, + bin_size, + metric="correlation", + sliding_window_size=None, + time_units="s", +): + """ + Performs template matching decoding over n-dimensional features. - p = nap.TsdFrame( - t=count.index, d=p, time_support=ep, columns=tuning_curves.index.values - ) + The algorithm decodes as follow: - decoded = nap.Tsd( - t=count.index, d=tuning_curves.index.values[idxmax], time_support=ep - ) + .. math:: - return decoded, p + \\hat{x}(t) = \\arg\\min\\limits_{x} [dist(f(x), n(t))] + where: -def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=None): - """ - Performs Bayesian decoding over 2 dimensional features. + - :math:`f(x)` is the the tuning curve function. + - :math:`n(t)` is input neural activity at time :math:`t`. + - :math:`dist` is a distance metric. + + The algorithm computes the distance between the observed neural activity and the tuning curves for every time bin. + The decoded feature at each time bin corresponds to the tuning curve bin with the smallest distance. - See: - Zhang, K., Ginzburg, I., McNaughton, B. L., & Sejnowski, T. J. - (1998). Interpreting neuronal population activity by - reconstruction: unified framework with application to - hippocampal place cells. Journal of neurophysiology, 79(2), - 1017-1044. + See :func:`scipy.spatial.distance.cdist` for available distance metrics and how they are computed. + + References + ---------- + .. [1] Zhang, K., Ginzburg, I., McNaughton, B. L., & Sejnowski, T. J. + (1998). Interpreting neuronal population activity by + reconstruction: unified framework with application to + hippocampal place cells. Journal of neurophysiology, 79(2), + 1017-1044. Parameters ---------- - tuning_curves : dict - Dictionary of 2d tuning curves (one for each neuron). - group : TsGroup, TsdFrame or dict of Ts/Tsd object. - A group of neurons with the same keys as tuning_curves dictionary. - You may also pass a TsdFrame with smoothed rates (recommended). - ep : IntervalSet - The epoch on which decoding is computed + tuning_curves : xarray.DataArray + Tuning curves as computed by :func:`~pynapple.process.tuning_curves.compute_tuning_curves`. 
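# A minimal NumPy-only sketch (not part of pynapple) of the Poisson log-posterior that
# decode_bayes evaluates above; the rates, counts, and prior below are made-up toy
# values chosen so the result is easy to check by hand.
import numpy as np

rates = np.array([[20.0, 1.0],
                  [1.0, 20.0]])          # (n_neurons, n_feature_bins) expected rates in Hz
counts = np.array([[18, 2],
                   [0, 19]])             # (n_time_bins, n_neurons) observed spike counts
bin_size = 1.0                           # seconds
prior = np.array([0.5, 0.5])             # uniform prior over the feature bins

eps = 1e-12
# log P(n|x) up to an additive constant: sum_i n_i*log(rate_i(x)) - bin_size*rate_i(x)
log_like = counts @ np.log(rates + eps) - bin_size * rates.sum(axis=0)
log_post = log_like + np.log(prior)
posterior = np.exp(log_post - log_post.max(axis=1, keepdims=True))
posterior /= posterior.sum(axis=1, keepdims=True)
decoded_bin = posterior.argmax(axis=1)   # array([0, 1]): each time bin decodes to its own bin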
+ data : TsGroup or TsdFrame + Neural activity with the same keys as the tuning curves. + You may also pass a TsdFrame with smoothed counts. + epochs : IntervalSet + The epochs on which decoding is computed bin_size : float - Bin size. Default is second. Use the parameter time_units to change it. - xy : tuple - A tuple of bin positions for the tuning curves i.e. xy=(x,y) + Bin size. Default is second. Use ``time_units`` to change it. + metric : str or callable, optional + The distance metric to use for template matching. + + If a string, passed to :func:`scipy.spatial.distance.cdist`, must be one of: + ``braycurtis``, ``canberra``, ``chebyshev``, ``cityblock``, ``correlation``, + ``cosine``, ``dice``, ``euclidean``, ``hamming``, ``jaccard``, ``jensenshannon``, + ``kulczynski1``, ``mahalanobis``, ``matching``, ``minkowski``, ``rogerstanimoto``, + ``russellrao``, ``seuclidean``, ``sokalmichener``, ``sokalsneath``, + ``sqeuclidean`` or ``yule``. + + Default is ``correlation``. + + .. note:: + Some metrics may not be suitable for all types of data. + For example, metrics such as ``hamming`` do not handle NaN values. + + If a callable, it must have the signature ``metric(u, v) -> float`` and + return the distance between two 1D arrays. + sliding_window_size : int, optional + The size, in number of bins, for a uniform window to be convolved with the counts array for each neuron. Value should be >= 1. + If None (default), no smoothing is applied. time_units : str, optional - Time unit of the bin size ('s' [default], 'ms', 'us'). - features : TsdFrame - The 2 columns features used to compute the tuning curves. Used to correct for occupancy. - If feature is not passed, the occupancy is uniform. + Time unit of the bin size (``s`` [default], ``ms``, ``us``). Returns ------- Tsd - The decoded feature in 2d - numpy.ndarray - The probability distribution of the decoded trajectory for each time bin - - Raises - ------ - RuntimeError - If group is not a dict of Ts/Tsd or TsGroup. - If different size of neurons for tuning_curves and group. - If indexes don't match between tuning_curves and group. - + The decoded feature + TsdFrame or TsdTensor + The distance matrix between the neural activity and the tuning curves for each time bin. + + Examples + -------- + In the simplest case, we can decode a single feature (e.g., position) from a group of neurons: + + >>> import pynapple as nap + >>> import numpy as np + >>> group = nap.TsGroup({i: nap.Ts(t=np.arange(0, 50) + 50 * i) for i in range(2)}) + >>> feature = nap.Tsd(t=np.arange(0, 100, 1), d=np.repeat(np.arange(0, 2), 50)) + >>> tuning_curves = nap.compute_tuning_curves(group, feature, bins=2, range=(-.5, 1.5)) + >>> epochs = nap.IntervalSet([0, 100]) + >>> decoded, dist = nap.decode_template(tuning_curves, group, epochs=epochs, bin_size=1) + >>> decoded + Time (s) + ---------- -- + 0.5 0 + 1.5 0 + 2.5 0 + 3.5 0 + 4.5 0 + 5.5 0 + 6.5 0 + ... + 93.5 1 + 94.5 1 + 95.5 1 + 96.5 1 + 97.5 1 + 98.5 1 + 99.5 1 + dtype: float64, shape: (100,) + + decode is a `Tsd` object containing the decoded feature for each time bin. + + >>> p + Time (s) 0.0 1.0 + ---------- ----- ----- + 0.5 0.0 2.0 + 1.5 0.0 2.0 + 2.5 0.0 2.0 + 3.5 0.0 2.0 + 4.5 0.0 2.0 + 5.5 0.0 2.0 + ... ... ... + 94.5 2.0 0.0 + 95.5 2.0 0.0 + 96.5 2.0 0.0 + 97.5 2.0 0.0 + 98.5 2.0 0.0 + 99.5 2.0 0.0 + dtype: float64, shape: (100, 2) + + dist is a `TsdFrame` object containing the distances for each time bin. 
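# A standalone sketch of the template-matching step described above, calling
# scipy.spatial.distance.cdist directly (the routine decode_template relies on).
# The templates, activity values, and the "correlation" metric are illustrative only.
import numpy as np
from scipy.spatial.distance import cdist

templates = np.array([[10.0, 0.0],
                      [0.0, 10.0]])      # (n_feature_bins, n_neurons) tuning curves
activity = np.array([[9.0, 1.0],
                     [2.0, 8.0]])        # (n_time_bins, n_neurons) binned activity

dist = cdist(activity, templates, metric="correlation")  # (n_time_bins, n_feature_bins)
decoded_bin = dist.argmin(axis=1)        # the closest template wins: array([0, 1])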
+ + The function also works for multiple features, in which case it does n-dimensional decoding: + + >>> features = nap.TsdFrame( + ... t=np.arange(0, 100, 1), + ... d=np.vstack((np.repeat(np.arange(0, 2), 50), np.tile(np.arange(0, 2), 50))).T, + ... ) + >>> group = nap.TsGroup( + ... { + ... 0: nap.Ts(np.arange(0, 50, 2)), + ... 1: nap.Ts(np.arange(1, 51, 2)), + ... 2: nap.Ts(np.arange(50, 100, 2)), + ... 3: nap.Ts(np.arange(51, 101, 2)), + ... } + ... ) + >>> tuning_curves = nap.compute_tuning_curves(group, features, bins=2, range=[(-.5, 1.5)]*2) + >>> decoded, dist = nap.decode_template(tuning_curves, group, epochs=epochs, bin_size=1) + >>> decoded + Time (s) 0 1 + ---------- --- --- + 0.5 0.0 0.0 + 1.5 0.0 1.0 + 2.5 0.0 0.0 + 3.5 0.0 1.0 + 4.5 0.0 0.0 + 5.5 0.0 1.0 + 6.5 0.0 0.0 + ... ... ... + 93.5 1.0 1.0 + 94.5 1.0 0.0 + 95.5 1.0 1.0 + 96.5 1.0 0.0 + 97.5 1.0 1.0 + 98.5 1.0 0.0 + 99.5 1.0 1.0 + dtype: float64, shape: (100, 2) + + decoded is now a `TsdFrame` object containing the decoded features for each time bin. + + >>> dist + Time (s) + ---------- -------------------------- + 0.5 [[0. , 1.333333] ...] + 1.5 [[1.333333, 0. ] ...] + 2.5 [[0. , 1.333333] ...] + 3.5 [[1.333333, 0. ] ...] + 4.5 [[0. , 1.333333] ...] + 5.5 [[1.333333, 0. ] ...] + ... + 94.5 [[1.333333, 1.333333] ...] + 95.5 [[1.333333, 1.333333] ...] + 96.5 [[1.333333, 1.333333] ...] + 97.5 [[1.333333, 1.333333] ...] + 98.5 [[1.333333, 1.333333] ...] + 99.5 [[1.333333, 1.333333] ...] + dtype: float64, shape: (100, 2, 2) + + and dist is a `TsdTensor` object containing the distances for each time bin. + + It is also possible to pass continuous values instead of spikes (e.g. calcium imaging): + + >>> time = np.arange(0,100, 0.1) + >>> group = nap.TsdFrame(t=time, d=np.stack([time % 0.5, time %1], axis=1)) + >>> tuning_curves = nap.compute_tuning_curves(group, features, bins=2, range=[(-.5, 1.5)]*2) + >>> decoded, dist = nap.decode_template(tuning_curves, group, epochs=epochs, bin_size=1) + >>> decoded + Time (s) 0 1 + ---------- --- --- + 0.0 0.0 0.0 + 0.1 0.0 0.0 + 0.2 0.0 0.0 + 0.3 0.0 0.0 + 0.4 0.0 0.0 + 0.5 1.0 1.0 + 0.6 1.0 1.0 + ... ... ... 
+ 99.3 0.0 0.0 + 99.4 0.0 0.0 + 99.5 1.0 1.0 + 99.6 1.0 1.0 + 99.7 1.0 1.0 + 99.8 1.0 1.0 + 99.9 1.0 1.0 + dtype: float64, shape: (1000, 2) """ + tc = tuning_curves.values.reshape(tuning_curves.sizes["unit"], -1) + ct = data.values + + return _format_decoding_outputs( + cdist(ct, tc.T, metric=metric), + tuning_curves, + data, + epochs, + greater_is_better=False, + ) - if type(group) is nap.TsdFrame: - newgroup = group.restrict(ep) - numcells = newgroup.shape[1] - - if len(tuning_curves) != numcells: - raise RuntimeError("Different shapes for tuning_curves and group") - - if not np.all( - np.array(list(tuning_curves.keys())) == np.array(newgroup.columns) - ): - raise RuntimeError("Different indices for tuning curves and group keys") - - count = group - - elif type(group) is nap.TsGroup: - newgroup = group.restrict(ep) - numcells = len(newgroup) - - if len(tuning_curves) != numcells: - raise RuntimeError("Different shapes for tuning_curves and group") - - if not np.all( - np.array(list(tuning_curves.keys())) == np.array(newgroup.keys()) - ): - raise RuntimeError("Different indices for tuning curves and group keys") - count = newgroup.count(bin_size, ep, time_units) +# ------------------------------------------------------------------------------------- +# Deprecated functions for backward compatibility +# ------------------------------------------------------------------------------------- - elif type(group) is dict: - newgroup = nap.TsGroup(group, time_support=ep) - count = newgroup.count(bin_size, ep, time_units) +def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None): + """ + .. deprecated:: 0.9.2 + `decode_1d` will be removed in Pynapple 1.0.0, it is replaced by + `decode_bayes` because the latter works for N dimensions. + """ + warnings.warn( + "decode_1d is deprecated and will be removed in a future version; use decode_bayes instead.", + FutureWarning, + stacklevel=2, + ) + # Occupancy + if feature is None: + occupancy = np.ones(tuning_curves.shape[0]) + elif isinstance(feature, nap.Tsd): + diff = np.diff(tuning_curves.index.values) + bins = tuning_curves.index.values[:-1] - diff / 2 + bins = np.hstack( + (bins, [bins[-1] + diff[-1], bins[-1] + 2 * diff[-1]]) + ) # assuming the size of the last 2 bins is equal + occupancy, _ = np.histogram(feature.values, bins) else: - raise RuntimeError("Unknown format for group") + raise RuntimeError("Unknown format for feature in decode_1d") + return decode_bayes( + xr.DataArray( + data=tuning_curves.values.T, + coords={ + "unit": tuning_curves.columns.values, + "0": tuning_curves.index.values, + }, + attrs={"occupancy": occupancy}, + ), + nap.TsGroup(group) if isinstance(group, dict) else group, + ep, + bin_size, + time_units=time_units, + uniform_prior=feature is None, + ) - indexes = list(tuning_curves.keys()) +def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=None): + """ + .. deprecated:: 0.9.2 + `decode_2d` will be removed in Pynapple 1.0.0, it is replaced by + `decode_bayes` because the latter works for N dimensions. 
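# Migration sketch: the deprecated decode_1d/decode_2d wrappers above simply forward
# to decode_bayes. The equivalent call with the current API builds the tuning curves
# with compute_tuning_curves and passes them straight to decode_bayes; the toy data
# below mirror the decode_bayes docstring example and are illustrative only.
import numpy as np
import pynapple as nap

group = nap.TsGroup({i: nap.Ts(t=np.arange(0, 50) + 50 * i) for i in range(2)})
feature = nap.Tsd(t=np.arange(0, 100, 1), d=np.repeat(np.arange(0, 2), 50))
epochs = nap.IntervalSet([0, 100])

# Before (deprecated):
#   tc = nap.compute_1d_tuning_curves(group, feature, nb_bins=2)
#   decoded, p = nap.decode_1d(tc, group, epochs, bin_size=1)
# Now:
tuning_curves = nap.compute_tuning_curves(group, feature, bins=2, range=(-0.5, 1.5))
decoded, p = nap.decode_bayes(tuning_curves, group, epochs=epochs, bin_size=1)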
+ """ + warnings.warn( + "decode_2d is deprecated and will be removed in a future version; use decode_bayes instead.", + FutureWarning, + stacklevel=2, + ) # Occupancy + indexes = list(tuning_curves.keys()) if features is None: occupancy = np.ones_like(tuning_curves[indexes[0]]).flatten() else: @@ -223,42 +707,15 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N features[:, 0].values, features[:, 1].values, [binsxy[0], binsxy[1]] ) occupancy = occupancy.flatten() - - # Transforming to pure numpy array - tc = np.array([tuning_curves[i] for i in tuning_curves.keys()]) - tc = tc.reshape(tc.shape[0], np.prod(tc.shape[1:])) - tc = tc.T - ct = count.values - bin_size_s = nap.TsIndex.format_timestamps( - np.array([bin_size], dtype=np.float64), time_units - )[0] - - p1 = np.exp(-bin_size_s * np.nansum(tc, 1)) - p2 = occupancy / occupancy.sum() - - ct2 = np.tile(ct[:, np.newaxis, :], (1, tc.shape[0], 1)) - - p3 = np.nanprod(tc**ct2, -1) - - p = p1 * p2 * p3 - p = p / p.sum(1)[:, np.newaxis] - - idxmax = np.argmax(p, 1) - - p = p.reshape(p.shape[0], len(xy[0]), len(xy[1])) - - idxmax2d = np.unravel_index(idxmax, (len(xy[0]), len(xy[1]))) - - if features is not None: - cols = features.columns - else: - cols = np.arange(2) - - decoded = nap.TsdFrame( - t=count.index, - d=np.vstack((xy[0][idxmax2d[0]], xy[1][idxmax2d[1]])).T, - time_support=ep, - columns=cols, + return decode_bayes( + xr.DataArray( + data=[tuning_curves[i] for i in indexes], + coords={"unit": indexes, "0": xy[0], "1": xy[1]}, + attrs={"occupancy": occupancy}, + ), + nap.TsGroup(group) if isinstance(group, dict) else group, + ep, + bin_size, + time_units=time_units, + uniform_prior=features is None, ) - - return decoded, p diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 5fe05a4e8..68add1260 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd -from scipy.signal import butter, sosfiltfilt, sosfreqz +from scipy.signal import butter, filtfilt, sosfiltfilt, sosfreqz from .. import core as nap @@ -510,3 +510,96 @@ def get_filter_frequency_response( ) else: raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") + + +def detect_oscillatory_events( + data, + epoch, + freq_band, + thresh_band, + duration_band, + min_inter_duration, + fs=None, + wsize=51, +): + """ + Simple helper for detecting oscillatory events (e.g. ripples, spindles) + + Parameters + ---------- + data : Tsd + 1-dimensional time series + epoch : IntervalSet + The epoch for restricting the detection + freq_band : tuple + The (low, high) frequency to bandpass the signal + thresh_band : tuple + The (min, max) value for thresholding the normalized squared signal after filtering + duration_band : tuple + The (min, max) duration of an event in second + min_inter_duration : float + The minimum duration between two events otherwise they are merged (in seconds) + fs : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. 
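# Usage sketch for the detect_oscillatory_events helper introduced here (e.g. ripple
# detection on a simulated LFP trace). The frequency/threshold/duration values are
# illustrative, and the example assumes the helper is exposed at the package level
# like the other filtering functions.
import numpy as np
import pynapple as nap

fs = 1250.0
t = np.arange(0, 10, 1 / fs)
lfp = np.random.default_rng(0).normal(0, 1, len(t))
lfp[int(2 * fs):int(2.1 * fs)] += 5 * np.sin(2 * np.pi * 180 * t[: int(0.1 * fs)])
signal = nap.Tsd(t=t, d=lfp)
epoch = nap.IntervalSet(start=0, end=10)

events = nap.detect_oscillatory_events(
    signal,
    epoch,
    freq_band=(100, 250),
    thresh_band=(1, 10),
    duration_band=(0.02, 0.2),
    min_inter_duration=0.02,
    fs=fs,
)
# 'events' is an IntervalSet whose metadata holds power, amplitude, and peak_time.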
+ wsize : int, optional + The size of the window for digital filtering + + Returns + ------- + IntervalSet + The interval set of detected events with metadata containing + the power, amplitude, and peak_time + """ + data = data.restrict(epoch) + + if fs is None: + fs = data.rate + + signal = apply_bandpass_filter(data, freq_band, fs) + squared_signal = np.square(signal.values) + window = np.ones(wsize) / wsize + + nSS = filtfilt(window, 1, squared_signal) + nSS = (nSS - np.mean(nSS)) / np.std(nSS) + nSS = nap.Tsd(t=signal.index.values, d=nSS, time_support=epoch) + + # Detect oscillation periods by thresholding normalized signal + nSS2 = nSS.threshold(thresh_band[0], method="above") + nSS3 = nSS2.threshold(thresh_band[1], method="below") + + # Exclude oscillation where min_duration < length < max_duration + osc_ep = nSS3.time_support + osc_ep = osc_ep.drop_short_intervals(duration_band[0], time_units="s") + osc_ep = osc_ep.drop_long_intervals(duration_band[1], time_units="s") + + # Merge if inter-oscillation period is too short + osc_ep = osc_ep.merge_close_intervals(min_inter_duration, time_units="s") + + # Compute power, amplitude, and peak_time for each interval + powers = [] + amplitudes = [] + peak_times = [] + + for s, e in osc_ep.values: + seg = signal.get(s, e) + if len(seg) == 0: + powers.append(np.nan) + amplitudes.append(np.nan) + peak_times.append(np.nan) + continue + power = np.mean(np.square(seg)) + power_db = 10 * np.log10(power) + amplitude = np.max(np.abs(seg.values)) + peak_idx = np.argmax(np.abs(seg.values)) + peak_time = seg.index.values[peak_idx] + powers.append(power_db) + amplitudes.append(amplitude) + peak_times.append(peak_time) + + metadata = { + "power": powers, + "amplitude": amplitudes, + "peak_time": peak_times, + } + + return nap.IntervalSet(start=osc_ep.start, end=osc_ep.end, metadata=metadata) diff --git a/pynapple/process/tuning_curves.py b/pynapple/process/tuning_curves.py index fba5600b2..867d3868e 100644 --- a/pynapple/process/tuning_curves.py +++ b/pynapple/process/tuning_curves.py @@ -1,6 +1,5 @@ """ -Functions to compute tuning curves for features in 1 dimension or 2 dimension. - +Functions to compute n-dimensional tuning curves. """ import inspect @@ -10,10 +9,521 @@ import numpy as np import pandas as pd +import xarray as xr from .. import core as nap +def compute_tuning_curves( + data, + features, + bins=10, + range=None, + epochs=None, + fs=None, + feature_names=None, + return_pandas=False, + return_counts=False, +): + """ + Computes n-dimensional tuning curves relative to n features. + + Parameters + ---------- + data : TsGroup, TsdFrame, Ts, Tsd + The data for which the tuning curves will be computed. This usually corresponds to the activity of the + neurons, either as spike times (TsGroup or Ts) or continuous values (TsdFrame or Tsd). + features : Tsd, TsdFrame + The features (i.e. one column per feature). This usually corresponds to behavioral variables such as + position, head direction, speed, etc. + bins : sequence or int + The bin specification: + + * A sequence of arrays describing the monotonically increasing bin + edges along each dimension. + * The number of bins for each dimension (nx, ny, ... =bins) + * The number of bins for all dimensions (nx=ny=...=bins). + range : sequence, optional + A sequence of entries per feature, each an optional (lower, upper) tuple giving + the outer bin edges to be used if the edges are not given explicitly in + `bins`. 
+ An entry of None in the sequence results in the minimum and maximum + values being used for the corresponding dimension. + The default, None, is equivalent to passing a tuple of D None values. + epochs : IntervalSet, optional + The epochs on which tuning curves are computed. + If None, the epochs are the time support of the features. + fs : float, optional + The exact sampling frequency of the features used to normalise the tuning curves. + Unit should match that of the features. If not passed, it is estimated. + feature_names : list, optional + A list of feature names. If not passed, the column names in `features` are used. + return_pandas : bool, optional + If True, the function returns a pandas.DataFrame instead of an xarray.DataArray. + Note that this will not work if the features are not 1D and that occupancy and bin edges + will not be stored as attributes. + return_counts : bool, optional + If True, does not divide the spike counts by occupancy, but returns the counts directly. + The occupancy is stored in the xarray attributes, so the division can be performed after any + particular processing steps. + If the input is a TsdFrame, this does not do anything. + + Returns + ------- + xarray.DataArray + A tensor containing the tuning curves with labeled bin centres. + The bin edges and occupancy are stored as attributes. + + Examples + -------- + In the simplest case, we can pass a group of spikes per neuron and a single feature: + + >>> import pynapple as nap + >>> import numpy as np; np.random.seed(42) + >>> group = nap.TsGroup({ + ... 1: nap.Ts(np.arange(0, 100, 0.1)), + ... 2: nap.Ts(np.arange(0, 100, 0.2)) + ... }) + >>> feature = nap.Tsd(d=np.arange(0, 100, 0.1) % 1, t=np.arange(0, 100, 0.1)) + >>> tcs = nap.compute_tuning_curves(group, feature, bins=10) + >>> tcs + Size: 160B + array([[10., 10., 10., 10., 10., 10., 10., 10., 10., 10.], + [10., 0., 10., 0., 10., 0., 10., 0., 10., 0.]]) + Coordinates: + * unit (unit) int64 16B 1 2 + * 0 (0) float64 80B 0.045 0.135 0.225 0.315 ... 0.585 0.675 0.765 0.855 + Attributes: + occupancy: [100. 100. 100. 100. 100. 100. 100. 100. 100. 100.] + bin_edges: [array([0. , 0.09, 0.18, 0.27, 0.36, 0.45, 0.54, 0.63, 0.72,... + + The function can also take multiple features, in which case it computes n-dimensional tuning curves. + We can specify the number of bins for each feature: + + >>> features = nap.TsdFrame( + ... d=np.stack( + ... [ + ... np.arange(0, 100, 0.1) % 1, + ... np.arange(0, 100, 0.1) % 2 + ... ], + ... axis=1 + ... ), + ... t=np.arange(0, 100, 0.1) + ... ) + >>> tcs = nap.compute_tuning_curves(group, features, bins=[5, 3]) + >>> tcs + Size: 240B + array([[[10., 10., nan], + [10., 10., 10.], + [10., nan, 10.], + [10., 10., 10.], + [nan, 10., 10.]], + ... + [[ 5., 5., nan], + [ 5., 10., 0.], + [ 5., nan, 5.], + [10., 0., 5.], + [nan, 5., 5.]]]) + Coordinates: + * unit (unit) int64 16B 1 2 + * 0 (0) float64 40B 0.09 0.27 0.45 0.63 0.81 + * 1 (1) float64 24B 0.3167 0.95 1.583 + Attributes: + occupancy: [[100. 100. nan]\\n [100. 50. 50.]\\n [100. nan 100.]\\n [ 5... + bin_edges: [array([0. , 0.18, 0.36, 0.54, 0.72, 0.9 ]), array([0. ... + + Or even specify the bin edges directly: + + >>> tcs = nap.compute_tuning_curves( + ... group, + ... features, + ... bins=[np.linspace(0, 1, 5), np.linspace(0, 2, 3)] + ... ) + >>> tcs + Size: 128B + array([[[10. , 10. ], + [10. , 10. ], + [10. , 10. ], + [10. , 10. ]], + ... + [[ 6.66666667, 6.66666667], + [ 5. , 5. ], + [ 3.33333333, 3.33333333], + [ 5. , 5. 
]]]) + Coordinates: + * unit (unit) int64 16B 1 2 + * 0 (0) float64 32B 0.125 0.375 0.625 0.875 + * 1 (1) float64 16B 0.5 1.5 + Attributes: + occupancy: [[150. 150.]\\n [100. 100.]\\n [150. 150.]\\n [100. 100.]] + bin_edges: [array([0. , 0.25, 0.5 , 0.75, 1. ]), array([0., 1., 2.])] + + In all of these cases, it is also possible to pass continuous values instead of spikes (e.g. calcium imaging data), in that case the mean response is computed: + + >>> frame = nap.TsdFrame(d=np.random.rand(2000, 3), t=np.arange(0, 100, 0.05)) + >>> tcs = nap.compute_tuning_curves(frame, feature, bins=10) + >>> tcs + Size: 240B + array([[0.49147343, 0.50190395, 0.50971339, 0.50128013, 0.54332711, + 0.49712328, 0.49594611, 0.5110517 , 0.52247351, 0.52057658], + [0.51132036, 0.46410557, 0.47732505, 0.49830908, 0.53523019, + 0.53099429, 0.48668499, 0.44198555, 0.49222208, 0.47453398], + [0.46591801, 0.50662914, 0.46875882, 0.48734997, 0.51836574, + 0.50722266, 0.48943577, 0.49730095, 0.47944075, 0.48623693]]) + Coordinates: + * unit (unit) int64 24B 0 1 2 + * 0 (0) float64 80B 0.045 0.135 0.225 0.315 ... 0.585 0.675 0.765 0.855 + Attributes: + occupancy: [100. 100. 100. 100. 100. 100. 100. 100. 100. 100.] + bin_edges: [array([0. , 0.09, 0.18, 0.27, 0.36, 0.45, 0.54, 0.63, 0.72,... + """ + + # check data + if not isinstance(data, (nap.TsdFrame, nap.TsGroup, nap.Ts, nap.Tsd)): + raise TypeError("data should be a TsdFrame, TsGroup, Ts, or Tsd.") + + # check features + if not isinstance(features, (nap.TsdFrame, nap.Tsd)): + raise TypeError("features should be a Tsd or TsdFrame.") + + # check feature names + if feature_names is None: + feature_names = ( + features.columns if isinstance(features, nap.TsdFrame) else ["0"] + ) + else: + if ( + not hasattr(feature_names, "__len__") + or isinstance(feature_names, str) + or not all(isinstance(n, str) for n in feature_names) + ): + raise TypeError("feature_names should be a list of strings.") + if len(feature_names) != ( + 1 if isinstance(features, nap.Tsd) else features.shape[-1] + ): + raise ValueError("feature_names should match the number of features.") + + # check epochs + if epochs is None: + epochs = features.time_support + elif isinstance(epochs, nap.IntervalSet): + features = features.restrict(epochs) + else: + raise TypeError("epochs should be an IntervalSet.") + data = data.restrict(epochs) + + # check fs + if fs is None: + fs = 1 / np.mean(features.time_diff(epochs=epochs).values) + if not isinstance(fs, (int, float)): + raise TypeError("fs should be a number (int or float)") + + # check range + if range is not None and isinstance(range, tuple): + if features.ndim == 1 or features.shape[1] == 1: + range = [range] + else: + raise ValueError( + "range should be a sequence of tuples, one for each feature." 
+ ) + + # check return_pandas + if ( + return_pandas != 1 + and return_pandas != 0 + and not isinstance(return_pandas, bool) + ): + raise TypeError("return_pandas should be a boolean.") + + # check return_counts + if ( + return_counts != 1 + and return_counts != 0 + and not isinstance(return_counts, bool) + ): + raise TypeError("return_counts should be a boolean.") + + # occupancy + occupancy, bin_edges = np.histogramdd(features, bins=bins, range=range) + + # tuning curves + keys = ( + data.keys() + if isinstance(data, nap.TsGroup) + else data.columns if isinstance(data, nap.TsdFrame) else [0] + ) + tcs = np.zeros([len(keys), *occupancy.shape]) + if isinstance(data, (nap.TsGroup, nap.Ts)): + # SPIKES + if isinstance(data, nap.Ts): + data = {0: data} + for i, n in enumerate(keys): + tcs[i] = np.histogramdd( + data[n].value_from(features), + bins=bin_edges, + )[0] + occupancy[occupancy == 0.0] = np.nan + if not return_counts: + tcs = (tcs / occupancy) * fs + else: + # RATES + values = data.value_from(features) + if isinstance(data, nap.Tsd): + data = np.expand_dims(data.values, -1) + counts = np.histogramdd(values, bins=bin_edges)[0] + counts[counts == 0] = np.nan + for i, n in enumerate(keys): + tcs[i] = np.histogramdd( + values, + weights=data[:, i], + bins=bin_edges, + )[0] + tcs /= counts + tcs[np.isnan(tcs)] = 0.0 + tcs[:, occupancy == 0.0] = np.nan + + tcs = xr.DataArray( + tcs, + coords={ + "unit": keys, + **{ + str(feature_name): e[:-1] + np.diff(e) / 2 + for feature_name, e in zip(feature_names, bin_edges) + }, + }, + attrs={"occupancy": occupancy, "bin_edges": bin_edges, "fs": fs}, + ) + if return_pandas: + return tcs.to_pandas().T + else: + return tcs + + +def compute_response_per_epoch(data, epochs_dict, return_pandas=False): + """ + Compute mean response per epoch, given a dictionary of epochs. + + Parameters + ---------- + data : TsGroup, TsdFrame, Ts, Tsd + The data for which the tuning curves will be computed. + epochs_dict : dict + Dictionary of IntervalSets. + return_pandas : bool, optional + If True, the function returns a pandas.DataFrame instead of an xarray.DataArray. + + Returns + ------- + xarray.DataArray + A tensor containing the tuning curves with labeled epochs. + + Examples + -------- + This function is typically used for a set of discrete stimuli being presented for multiple epochs. + The stimulus epochs can overlap, though note that epochs within an IntervalSet can not overlap. + + >>> import pynapple as nap + >>> import numpy as np; np.random.seed(42) + >>> epochs_dict = { + ... "stim0": nap.IntervalSet(start=0, end=30), + ... "stim1":nap.IntervalSet(start=60, end=90) + ... } + >>> group = nap.TsGroup({ + ... 1: nap.Ts(np.arange(0, 100, 0.1)), + ... 2: nap.Ts(np.arange(0, 100, 0.2)) + ... 
}) + >>> tcs = nap.compute_response_per_epoch(group, epochs_dict) + >>> tcs + Size: 32B + array([[10.03333333, 10.03333333], + [ 5.03333333, 5.03333333]]) + Coordinates: + * unit (unit) int64 16B 1 2 + * epochs (epochs) >> frame = nap.TsdFrame(d=np.random.rand(2000, 3), t=np.arange(0, 100, 0.05)) + >>> tcs = nap.compute_response_per_epoch(frame, epochs_dict) + >>> tcs + Size: 48B + array([[0.50946668, 0.50897635], + [0.48343249, 0.48191892], + [0.50063158, 0.48748094]]) + Coordinates: + * unit (unit) int64 24B 0 1 2 + * epochs (epochs) >> import pynapple as nap + >>> import numpy as np; np.random.seed(42) + >>> epoch = nap.IntervalSet([0, 100]) + >>> t = np.arange(0, 100, 0.01) + >>> feature = nap.Tsd(t=t, d=np.clip(t*0.01 + np.random.normal(0, 0.02, len(t)), 0, 1), time_support=epoch) + >>> group = nap.TsGroup({ + ... 1: nap.Ts(t[(feature.values >= 0.2) & (feature.values < 0.3)]), + ... 2: nap.Ts(t[(feature.values >= 0.7) & (feature.values < 0.8)]) + ... }, time_support=epoch) + >>> tcs = nap.compute_tuning_curves(group, feature, bins=10) + >>> tcs + Size: 160B + array([[ 0., 0., 100., 0., 0., 0., 0., 0., 0., 0.], + [ 0., 0., 0., 0., 0., 0., 0., 100., 0., 0.]]) + Coordinates: + * unit (unit) int64 16B 1 2 + * 0 (0) float64 80B 0.05 0.15 0.25 0.35 0.45 0.55 0.65 0.75 0.85 0.95 + Attributes: + occupancy: [ 985. 1009. 1014. 996. 993. 1008. 991. 1008. 999. 997.] + bin_edges: [array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ])] + >>> MI = nap.compute_mutual_information(tcs) + >>> MI + bits/sec bits/spike + 1 33.480966 3.301870 + 2 33.369159 3.310432 + """ + if not isinstance(tuning_curves, xr.DataArray): + raise TypeError( + "tuning_curves should be an xr.DataArray as computed by compute_tuning_curves." + ) + + if "occupancy" not in tuning_curves.attrs: + raise ValueError("No occupancy found in tuning curves.") + occupancy = tuning_curves.attrs["occupancy"] + occupancy = occupancy / np.nansum(occupancy) + + fx = tuning_curves.values + axes = tuple(range(1, fx.ndim)) + fr_keepdims = np.nansum(fx * occupancy, axis=axes, keepdims=True) + fr_scalar = np.squeeze(fr_keepdims, axis=axes) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + fxfr = fx / fr_keepdims + logfx = np.log2(fxfr) + logfx[~np.isfinite(logfx)] = 0.0 + MI_bits_per_sec = np.nansum(occupancy * fx * logfx, axis=axes) + with np.errstate(divide="ignore", invalid="ignore"): + MI_bits_per_spike = MI_bits_per_sec / fr_scalar + + return pd.DataFrame( + data=np.stack([MI_bits_per_sec, MI_bits_per_spike], axis=1), + index=tuning_curves.coords["unit"], + columns=["bits/sec", "bits/spike"], + ) + + +# ===================================================================================== +# OLD FUNCTIONS, DEPRECATED +# ===================================================================================== + + def _validate_tuning_inputs(func): @wraps(func) def wrapper(*args, **kwargs): @@ -56,7 +566,7 @@ def wrapper(*args, **kwargs): if not isinstance(kwargs["dict_ep"], dict): raise TypeError("dict_ep should be a dictionary of IntervalSet") if not all( - [isinstance(v, nap.IntervalSet) for v in kwargs["dict_ep"].values()] + isinstance(v, nap.IntervalSet) for v in kwargs["dict_ep"].values() ): raise TypeError("dict_ep argument should contain only IntervalSet.") if "tc" in kwargs: @@ -82,309 +592,159 @@ def wrapper(*args, **kwargs): @_validate_tuning_inputs -def compute_discrete_tuning_curves(group, dict_ep): +def compute_1d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): """ - Compute discrete tuning 
curves of a TsGroup using a dictionary of epochs. - The function returns a pandas DataFrame with each row being a key of the dictionary of epochs - and each column being a neurons. - - This function can typically being used for a set of stimulus being presented for multiple epochs. - An example of the dictionary is : - - >>> dict_ep = { - "stim0": nap.IntervalSet(start=0, end=1), - "stim1":nap.IntervalSet(start=2, end=3) - } - In this case, the function will return a pandas DataFrame : - - >>> tc - neuron0 neuron1 neuron2 - stim0 0 Hz 1 Hz 2 Hz - stim1 3 Hz 4 Hz 5 Hz - - - Parameters - ---------- - group : nap.TsGroup - The group of Ts/Tsd for which the tuning curves will be computed - dict_ep : dict - Dictionary of IntervalSets - - Returns - ------- - pandas.DataFrame - Table of firing rate for each neuron and each IntervalSet - - Raises - ------ - RuntimeError - If group is not a TsGroup object. + .. deprecated:: 0.9.2 + `compute_1d_tuning_curves` will be removed in Pynapple 1.0.0, it is replaced by + `compute_tuning_curves` because the latter works for N dimensions. """ - idx = np.sort(list(dict_ep.keys())) - tuning_curves = pd.DataFrame(index=idx, columns=list(group.keys()), data=0.0) - - for k in dict_ep.keys(): - for n in group.keys(): - tuning_curves.loc[k, n] = float(len(group[n].restrict(dict_ep[k]))) - - tuning_curves.loc[k] = tuning_curves.loc[k] / dict_ep[k].tot_length("s") - - return tuning_curves + warnings.warn( + "compute_1d_tuning_curves is deprecated and will be removed in a future version;" + "use compute_tuning_curves instead.", + FutureWarning, + stacklevel=2, + ) + return ( + compute_tuning_curves( + group, + feature, + nb_bins, + range=None if minmax is None else [minmax], + epochs=ep, + ) + .to_pandas() + .T + ) @_validate_tuning_inputs -def compute_1d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): +def compute_1d_tuning_curves_continuous( + tsdframe, feature, nb_bins, ep=None, minmax=None +): """ - Computes 1-dimensional tuning curves relative to a 1d feature. - - Parameters - ---------- - group : TsGroup - The group of Ts/Tsd for which the tuning curves will be computed - feature : Tsd (or TsdFrame with 1 column only) - The 1-dimensional target feature (e.g. head-direction) - nb_bins : int - Number of bins in the tuning curve - ep : IntervalSet, optional - The epoch on which tuning curves are computed. - If None, the epoch is the time support of the feature. - minmax : tuple or list, optional - The min and max boundaries of the tuning curves. - If None, the boundaries are inferred from the target feature - - Returns - ------- - pandas.DataFrame - DataFrame to hold the tuning curves - - Raises - ------ - RuntimeError - If group is not a TsGroup object. - + .. deprecated:: 0.9.2 + `compute_1d_tuning_curves` will be removed in Pynapple 1.0.0, it is replaced by + `compute_tuning_curves` because the latter works for N dimensions and continuous data. 
""" - if minmax is not None and len(minmax) != 2: - raise ValueError("minmax should be of length 2.") - if ep is None: - ep = feature.time_support - - if minmax is None: - bins = np.linspace(np.nanmin(feature), np.nanmax(feature), nb_bins + 1) - else: - bins = np.linspace(minmax[0], minmax[1], nb_bins + 1) - - idx = bins[0:-1] + np.diff(bins) / 2 - - tuning_curves = pd.DataFrame(index=idx, columns=list(group.keys())) - - group_value = group.value_from(feature, ep) - - occupancy, _ = np.histogram(feature.restrict(ep).values, bins) - - for k in group_value: - count, _ = np.histogram(group_value[k].values, bins) - count = count / occupancy - tuning_curves[k] = count - tuning_curves[k] = count * feature.rate - - return tuning_curves + warnings.warn( + "compute_1d_tuning_curves_continuous is deprecated and will be removed in a future version;" + "use compute_tuning_curves instead.", + FutureWarning, + stacklevel=2, + ) + return ( + compute_tuning_curves( + tsdframe, + feature, + nb_bins, + range=None if minmax is None else [minmax], + epochs=ep, + ) + .to_pandas() + .T + ) @_validate_tuning_inputs def compute_2d_tuning_curves(group, features, nb_bins, ep=None, minmax=None): """ - Computes 2-dimensional tuning curves relative to a 2d features - - Parameters - ---------- - group : TsGroup - The group of Ts/Tsd for which the tuning curves will be computed - features : TsdFrame - The 2d features (i.e. 2 columns features). - nb_bins : int or tuple - Number of bins in the tuning curves (separate for 2 feature dimensions if tuple provided) - ep : IntervalSet, optional - The epoch on which tuning curves are computed. - If None, the epoch is the time support of the feature. - minmax : tuple or list, optional - The min and max boundaries of the tuning curves given as: - (minx, maxx, miny, maxy) - If None, the boundaries are inferred from the target features - - Returns - ------- - tuple - A tuple containing: \n - tc (dict): Dictionary of the tuning curves with dimensions (nb_bins, nb_bins).\n - xy (list): List of bins center in the two dimensions - - Raises - ------ - RuntimeError - If group is not a TsGroup object or if features is not 2 columns only. - + .. deprecated:: 0.9.2 + `compute_1d_tuning_curves` will be removed in Pynapple 1.0.0, it is replaced by + `compute_tuning_curves` because the latter works for N dimensions. """ - if minmax is not None and len(minmax) != 4: - raise ValueError("minmax should be of length 4.") - - if isinstance(nb_bins, tuple) and len(nb_bins) != 2: - raise ValueError( - "nb_bins should be of type int (or tuple with (int, int) for 2D tuning curves)." 
- ) - - if isinstance(nb_bins, int): - nb_bins = (nb_bins, nb_bins) - - if ep is None: - ep = features.time_support - else: - features = features.restrict(ep) - - groups_value = {} - binsxy = {} - - for i in range(2): - groups_value[i] = group.value_from(features[:, i], ep) - if minmax is None: - bins = np.linspace( - np.nanmin(features[:, i]), np.nanmax(features[:, i]), nb_bins[i] + 1 - ) - else: - bins = np.linspace(minmax[i + i % 2], minmax[i + 1 + i % 2], nb_bins[i] + 1) - binsxy[i] = bins - - occupancy, _, _ = np.histogram2d( - features[:, 0].values.flatten(), - features[:, 1].values.flatten(), - [binsxy[0], binsxy[1]], + warnings.warn( + "compute_2d_tuning_curves is deprecated and will be removed in a future version;" + "use compute_tuning_curves instead.", + FutureWarning, + stacklevel=2, ) - - tc = {} - for n in group.keys(): - count, _, _ = np.histogram2d( - groups_value[0][n].values.flatten(), - groups_value[1][n].values.flatten(), - [binsxy[0], binsxy[1]], - ) - count = count / occupancy - tc[n] = count * features.rate - - xy = [binsxy[i][0:-1] + np.diff(binsxy[i]) / 2 for i in range(2)] - - return tc, xy + xarray = compute_tuning_curves( + group, + features, + nb_bins, + range=( + None if minmax is None else [[minmax[0], minmax[1]], [minmax[2], minmax[3]]] + ), + epochs=ep, + ) + tcs = {c: xarray.sel(unit=c).values for c in xarray.coords["unit"].values} + bins = [xarray.coords[dim].values for dim in xarray.coords if dim != "unit"] + return tcs, bins @_validate_tuning_inputs -def compute_1d_mutual_info(tc, feature, ep=None, minmax=None, bitssec=False): +def compute_2d_tuning_curves_continuous( + tsdframe, features, nb_bins, ep=None, minmax=None +): """ - Mutual information of a tuning curve computed from a 1-d feature. - - See: - - Skaggs, W. E., McNaughton, B. L., & Gothard, K. M. (1993). - An information-theoretic approach to deciphering the hippocampal code. - In Advances in neural information processing systems (pp. 1030-1037). - - Parameters - ---------- - tc : pandas.DataFrame or numpy.ndarray - Tuning curves in columns - feature : Tsd (or TsdFrame with 1 column only) - The 1-dimensional target feature (e.g. head-direction) - ep : IntervalSet, optional - The epoch over which the tuning curves were computed - If None, the epoch is the time support of the feature. - minmax : tuple or list, optional - The min and max boundaries of the tuning curves. - If None, the boundaries are inferred from the target feature - bitssec : bool, optional - By default, the function return bits per spikes. - Set to true for bits per seconds - - Returns - ------- - pandas.DataFrame - Spatial Information (default is bits/spikes) + .. deprecated:: 0.9.2 + `compute_1d_tuning_curves` will be removed in Pynapple 1.0.0, it is replaced by + `compute_tuning_curves` because the latter works for N dimensions and continuous data. 
""" - if isinstance(tc, pd.DataFrame): - columns = tc.columns.values - fx = np.atleast_2d(tc.values) - else: - fx = np.atleast_2d(tc) - columns = np.arange(tc.shape[1]) - - nb_bins = tc.shape[0] + 1 - if minmax is None: - bins = np.linspace(np.nanmin(feature), np.nanmax(feature), nb_bins) - else: - bins = np.linspace(minmax[0], minmax[1], nb_bins) + warnings.warn( + "compute_2d_tuning_curves_continuous is deprecated and will be removed in a future version;" + "use compute_tuning_curves instead.", + FutureWarning, + stacklevel=2, + ) + xarray = compute_tuning_curves( + tsdframe, + features, + nb_bins, + range=( + None if minmax is None else [[minmax[0], minmax[1]], [minmax[2], minmax[3]]] + ), + epochs=ep, + ) + tcs = {c: xarray.sel(unit=c).values for c in xarray.coords["unit"].values} + bins = [xarray.coords[dim].values for dim in xarray.coords if dim != "unit"] + return tcs, bins - if isinstance(ep, nap.IntervalSet): - occupancy, _ = np.histogram(feature.restrict(ep).values, bins) - else: - occupancy, _ = np.histogram(feature.values, bins) - occupancy = occupancy / occupancy.sum() - occupancy = occupancy[:, np.newaxis] - fr = np.sum(fx * occupancy, 0) - fxfr = fx / fr - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - logfx = np.log2(fxfr) - logfx[np.isinf(logfx)] = 0.0 - SI = np.sum(occupancy * fx * logfx, 0) +@_validate_tuning_inputs +def compute_discrete_tuning_curves(group, dict_ep): + """ + .. deprecated:: 0.9.2 + `compute_discrete_tuning_curves` will be removed in Pynapple 1.0.0, it is replaced by + `compute_response_per_epoch`. + """ + warnings.warn( + "compute_discrete_tuning_curves is deprecated and will be removed in a future version;" + "use compute_response_per_epoch instead.", + FutureWarning, + stacklevel=2, + ) - if bitssec: - SI = pd.DataFrame(index=columns, columns=["SI"], data=SI) - return SI - else: - SI = SI / fr - SI = pd.DataFrame(index=columns, columns=["SI"], data=SI) - return SI + return compute_response_per_epoch(group, dict_ep, return_pandas=True) @_validate_tuning_inputs def compute_2d_mutual_info(dict_tc, features, ep=None, minmax=None, bitssec=False): """ - Mutual information of a tuning curve computed from 2-d features. - - See: - - Skaggs, W. E., McNaughton, B. L., & Gothard, K. M. (1993). - An information-theoretic approach to deciphering the hippocampal code. - In Advances in neural information processing systems (pp. 1030-1037). - - Parameters - ---------- - dict_tc : dict of numpy.ndarray or numpy.ndarray - If array, first dimension should be the neuron - features : TsdFrame - The 2 columns features that were used to compute the tuning curves - ep : IntervalSet, optional - The epoch over which the tuning curves were computed - If None, the epoch is the time support of the feature. - minmax : tuple or list, optional - The min and max boundaries of the tuning curves. - If None, the boundaries are inferred from the target features - bitssec : bool, optional - By default, the function return bits per spikes. - Set to true for bits per seconds - - Returns - ------- - pandas.DataFrame - Spatial Information (default is bits/spikes) + .. deprecated:: 0.9.2 + `compute_2d_mutual_info` will be removed in Pynapple 1.0.0, it is replaced by + `compute_mutual_information` because the latter works for N dimensions. 
""" - # A bit tedious here + warnings.warn( + "compute_2d_mutual_info is deprecated and will be removed in a future version;" + "use compute_mutual_information instead.", + FutureWarning, + stacklevel=2, + ) if type(dict_tc) is dict: - fx = np.array([dict_tc[i] for i in dict_tc.keys()]) - idx = list(dict_tc.keys()) + tcs = xr.DataArray( + np.array([dict_tc[i] for i in dict_tc.keys()]), + coords={"unit": list(dict_tc.keys())}, + dims=["unit", "0", "1"], + ) else: - fx = dict_tc - idx = np.arange(len(dict_tc)) - - nb_bins = (fx.shape[1] + 1, fx.shape[2] + 1) + tcs = xr.DataArray( + dict_tc, + coords={"unit": np.arange(len(dict_tc))}, + dims=["unit", "0", "1"], + ) + nb_bins = (tcs.shape[1] + 1, tcs.shape[2] + 1) bins = [] for i in range(2): if minmax is None: @@ -408,190 +768,48 @@ def compute_2d_mutual_info(dict_tc, features, ep=None, minmax=None, bitssec=Fals ) occupancy = occupancy / occupancy.sum() - fr = np.nansum(fx * occupancy, (1, 2)) - fr = fr[:, np.newaxis, np.newaxis] - fxfr = fx / fr - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - logfx = np.log2(fxfr) - logfx[np.isinf(logfx)] = 0.0 - SI = np.nansum(occupancy * fx * logfx, (1, 2)) + tcs.attrs["occupancy"] = occupancy + MI = compute_mutual_information(tcs) - if bitssec: - SI = pd.DataFrame(index=idx, columns=["SI"], data=SI) - return SI - else: - SI = SI / fr[:, 0, 0] - SI = pd.DataFrame(index=idx, columns=["SI"], data=SI) - return SI + column = "bits/sec" if bitssec else "bits/spike" + return MI[[column]].rename({column: "SI"}, axis=1) @_validate_tuning_inputs -def compute_1d_tuning_curves_continuous( - tsdframe, feature, nb_bins, ep=None, minmax=None -): +def compute_1d_mutual_info(tc, feature, ep=None, minmax=None, bitssec=False): """ - Computes 1-dimensional tuning curves relative to a feature with continuous data. - - Parameters - ---------- - tsdframe : Tsd or TsdFrame - Input data (e.g. continuous calcium data - where each column is the calcium activity of one neuron) - feature : Tsd (or TsdFrame with 1 column only) - The 1-dimensional target feature (e.g. head-direction) - nb_bins : int - Number of bins in the tuning curves - ep : IntervalSet, optional - The epoch on which tuning curves are computed. - If None, the epoch is the time support of the feature. - minmax : tuple or list, optional - The min and max boundaries of the tuning curves. - If None, the boundaries are inferred from the target feature - - Returns - ------- - pandas.DataFrame to hold the tuning curves - - Raises - ------ - RuntimeError - If tsdframe is not a Tsd or a TsdFrame object. - + .. deprecated:: 0.9.2 + `compute_1d_mutual_info` will be removed in Pynapple 1.0.0, it is replaced by + `compute_mutual_information` because the latter works for N dimensions. 
""" - if minmax is not None and len(minmax) != 2: - raise ValueError("minmax should be of length 2.") - - feature = np.squeeze(feature) - - if isinstance(ep, nap.IntervalSet): - feature = feature.restrict(ep) - tsdframe = tsdframe.restrict(ep) + warnings.warn( + "compute_1d_mutual_info is deprecated and will be removed in a future version;" + "use compute_mutual_information instead.", + FutureWarning, + stacklevel=2, + ) + if isinstance(tc, pd.DataFrame): + tcs = xr.DataArray( + tc.values.T, coords={"unit": tc.columns.values, "0": tc.index} + ) else: - tsdframe = tsdframe.restrict(feature.time_support) - - if isinstance(tsdframe, nap.Tsd): - tsdframe = tsdframe[:, np.newaxis] + tcs = xr.DataArray( + tc.T, coords={"unit": np.arange(tc.shape[1])}, dims=["unit", "0"] + ) + nb_bins = tc.shape[0] + 1 if minmax is None: - bins = np.linspace(np.nanmin(feature), np.nanmax(feature), nb_bins + 1) + bins = np.linspace(np.nanmin(feature), np.nanmax(feature), nb_bins) else: - bins = np.linspace(minmax[0], minmax[1], nb_bins + 1) - - align_times = tsdframe.value_from(feature) - idx = np.digitize(align_times.values, bins) - 1 - - tc = np.zeros((len(bins) - 1, tsdframe.shape[1])) - for i in range(0, nb_bins): - tc[i] = np.mean(tsdframe.values[idx == i], axis=0) - tc[np.isnan(tc)] = 0.0 - - # Assigning nans if bin is not visited. - occupancy, _ = np.histogram(feature, bins) - tc[occupancy == 0.0] = np.nan - - tc = pd.DataFrame( - index=bins[0:-1] + np.diff(bins) / 2, data=tc, columns=tsdframe.columns - ) - return tc - - -@_validate_tuning_inputs -def compute_2d_tuning_curves_continuous( - tsdframe, features, nb_bins, ep=None, minmax=None -): - """ - Computes 2-dimensional tuning curves relative to a 2d feature with continuous data. - - Parameters - ---------- - tsdframe : Tsd or TsdFrame - Input data (e.g. continuous calcium data - where each column is the calcium activity of one neuron) - features : TsdFrame - The 2d feature (two columns) - nb_bins : int or tuple - Number of bins in the tuning curves (separate for 2 feature dimensions if tuple provided) - ep : IntervalSet, optional - The epoch on which tuning curves are computed. - If None, the epoch is the time support of the feature. - minmax : tuple or list, optional - The min and max boundaries of the tuning curves. - Should be a tuple of minx, maxx, miny, maxy - If None, the boundaries are inferred from the target feature - - Returns - ------- - tuple - A tuple containing: \n - tc (dict): Dictionary of the tuning curves with dimensions (nb_bins, nb_bins).\n - xy (list): List of bins center in the two dimensions - - Raises - ------ - RuntimeError - If tsdframe is not a Tsd/TsdFrame or if features is not 2 columns - - """ - if minmax is not None and len(minmax) != 4: - raise ValueError("minmax should be of length 4.") - - if isinstance(nb_bins, tuple) and len(nb_bins) != 2: - raise ValueError( - "nb_bins should be of type int (or tuple with (int, int) for 2D tuning curves)." 
- ) + bins = np.linspace(minmax[0], minmax[1], nb_bins) if isinstance(ep, nap.IntervalSet): - features = features.restrict(ep) - tsdframe = tsdframe.restrict(ep) + occupancy, _ = np.histogram(feature.restrict(ep).values, bins) else: - tsdframe = tsdframe.restrict(features.time_support) - - if isinstance(tsdframe, nap.Tsd): - tsdframe = tsdframe[:, np.newaxis] - - if isinstance(nb_bins, int): - nb_bins = (nb_bins, nb_bins) - - binsxy = [] - idxs = [] - - for i in range(2): - if minmax is None: - bins = np.linspace( - np.nanmin(features[:, i]), np.nanmax(features[:, i]), nb_bins[i] + 1 - ) - else: - bins = np.linspace(minmax[i + i % 2], minmax[i + 1 + i % 2], nb_bins[i] + 1) - - align_times = tsdframe.value_from(features[:, i], ep) - idxs.append(np.digitize(align_times.values.flatten(), bins) - 1) - binsxy.append(bins) - - idxs = np.transpose(np.array(idxs)) - - tc = np.zeros((tsdframe.shape[1], nb_bins[0], nb_bins[1])) - - for i in range(nb_bins[0]): - for j in range(nb_bins[1]): - tc[:, i, j] = np.mean( - tsdframe.values[np.logical_and(idxs[:, 0] == i, idxs[:, 1] == j)], 0 - ) - - tc[np.isnan(tc)] = 0.0 - - # Assigning nans if bin is not visited. - occupancy, _, _ = np.histogram2d( - features[:, 0].values.flatten(), - features[:, 1].values.flatten(), - [binsxy[0], binsxy[1]], - ) - occupancy = occupancy[np.newaxis, :, :] - occupancy = np.repeat(occupancy, len(tc), axis=0) - tc[occupancy == 0.0] = np.nan - - xy = [binsxy[i][0:-1] + np.diff(binsxy[i]) / 2 for i in range(2)] - - tc = {c: tc[i] for i, c in enumerate(tsdframe.columns)} + occupancy, _ = np.histogram(feature.values, bins) + occupancy = occupancy / occupancy.sum() + tcs.attrs["occupancy"] = occupancy + MI = compute_mutual_information(tcs) - return tc, xy + column = "bits/sec" if bitssec else "bits/spike" + return MI[[column]].rename({column: "SI"}, axis=1) diff --git a/pyproject.toml b/pyproject.toml index 4b66ca66a..1470f95e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,8 @@ dependencies = [ "pynwb>=2.0.0", "tabulate", "h5py", - "rich" + "rich", + "xarray>=2023.1.0", ] requires-python = ">=3.8" diff --git a/scripts/check_parameter_naming.py b/scripts/check_parameter_naming.py new file mode 100644 index 000000000..a99d12fab --- /dev/null +++ b/scripts/check_parameter_naming.py @@ -0,0 +1,247 @@ +import ast +import difflib +import itertools +import os +import pathlib +from collections import defaultdict +from typing import Dict, List, Optional + +# Pairs of parameter names that are lexically similar but intentionally allowed. + +# During parameter name similarity checks, some pairs of names may be flagged +# as potentially inconsistent due to their high string similarity. This list +# enumerates such known, acceptable pairs that should be *excluded* from warnings. + +# Each pair is stored as a set of two strings (e.g., {"a", "a_1"}), and comparison +# is done using set equality, i.e., order does not matter. 
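# A tiny illustration (with made-up names) of the matching logic used by this script:
# candidate parameter names are compared with difflib.get_close_matches, and a match
# is dropped when the pair appears in VALID_PAIRS, which is why pairs are stored as
# order-independent sets.
import difflib

known_names = ["start", "start1", "time_units"]
candidate = "starts"
matches = difflib.get_close_matches(candidate, known_names, n=100, cutoff=0.8)
# matches -> ['start', 'start1']
allowed = [{"starts", "start"}, {"starts", "start1"}]
flagged = [m for m in matches if {m, candidate} not in allowed]
# flagged -> []: both matches are whitelisted, so no inconsistency is reported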
+ +# These typically include: +# - semantically equivalent alternatives (e.g., {"conv_time_series", "time_series"}) +# - mirrored structures (e.g., {"inhib_a", "inhib_b"}) +# - systematic naming conventions (e.g., {"basis1", "basis2"}) +# - commonly used argument patterns (e.g., {"args", "kwargs"}) +VALID_PAIRS = [ + {"ep", "sep"}, + {"ts", "tsd"}, + {"args", "kwargs"}, + {"channel", "n_channels"}, + {"interval_size", "intervalset"}, + {"new_time_support", "time_support"}, + {"ufunc", "func"}, + {"keys", "key"}, + {"value", "values"}, + *({a, b} for (a, b) in itertools.combinations(["starts", "start", "start1", "start2"], r=2)), + *({a, b} for (a, b) in itertools.combinations(["ends", "end", "end1", "end2"], r=2)), + {"windowsize", "window"}, + {"windowsize", "windows"}, +] + + +def handle_matches( + current_parameter: str, + current_path: str, + matches: List[str], + results: Dict, + valid_pairs: List[set[str]], +): + """ + Handle matched parameter names by updating or creating groups in the results dictionary. + + A parameter is considered valid if it has no matches or if all its matches appear in + `valid_pairs` as a set with the current parameter. Valid parameters are added as new entries + in the results dictionary. Invalid parameters (i.e., those with partial or conflicting matches) + are added to existing groups if any of their matches are already present in those groups. + + Note: This function allows overlapping groups. If `current_parameter` is similar to multiple + parameter groups (e.g., "timin" may match both "time" and "timing"), it will be added to each + of the matching groups independently. + + Parameters + ---------- + current_parameter : + The name of the parameter currently being processed. + + current_path : + The path or context in which the parameter was found (e.g., a file or data structure path). + + matches : + A list of other parameter names that are similar to ``current_parameter``. + + results : + A dictionary of grouped parameters. Keys are group names, and values are dictionaries + containing: + - "unique_names": a set of parameter names in the group. + - "info": a list of (parameter, path) tuples for matched entries. + + valid_pairs : + A list of valid two-element sets. Each set contains a pair of parameter names that are + considered equivalent or compatible. + + """ + # a parameter name is valid if no matches or all matches in valid pairs + list_invalid = [ + match for match in matches if {match, current_parameter} not in valid_pairs + ] + if len(list_invalid) == 0: + # if all matches are valid, create a new group for this parameter + results[current_parameter] = { + "unique_names": {current_parameter}, + "info": [(current_parameter, current_path)], + } + else: + + # if there is an invalid match, then add to existing result entry + for k, v in results.items(): + # Otherwise, add the parameter to any existing groups where it has a match + # + # Note: We *intentionally allow overlapping groups*. If `current_parameter` + # is similar to multiple different parameter groups + # (e.g. "timin" may be similar to both "time" and "timing", but "time" and "timing" may + # belong to two different groups), + # it will be added to each of those groups. 
+ is_in_category = any(match in v["unique_names"] for match in list_invalid) + if is_in_category: + v["info"].append((current_parameter, current_path)) + v["unique_names"].add(current_parameter) + + +def extract_parameters_from_ast( + tree: ast.Module, + file_path: pathlib.Path, + results: Dict, + valid_pairs: List[set[str]], + unique_param_names: set, + similarity_cutoff: float, +): + + class ParamVisitor(ast.NodeVisitor): + def __init__(self): + self.class_name = None + + def visit_ClassDef(self, node): + prev_class = self.class_name + self.class_name = node.name + self.generic_visit(node) + self.class_name = prev_class + + def visit_FunctionDef(self, node): + qualified_name = ( + f"{self.class_name}.{node.name}" if self.class_name else node.name + ) + param_names = [str(arg.arg) for arg in node.args.args] + for par in param_names: + # if perfect match is present just add there + if par in results: + results[par]["unique_names"].add(par) + results[par]["info"].append( + (par, f"{file_path.as_posix()}:{qualified_name}") + ) + continue + + matches = difflib.get_close_matches( + par, unique_param_names, n=100, cutoff=similarity_cutoff + ) + handle_matches( + par, + f"{file_path.as_posix()}:{qualified_name}", + matches, + results, + valid_pairs, + ) + unique_param_names.add(par) + self.generic_visit(node) + + def visit_AsyncFunctionDef(self, node): + self.visit_FunctionDef(node) + + ParamVisitor().visit(tree) + + +def collect_similar_parameter_names_ast( + root_dir: str | pathlib.Path, + similarity_cutoff: float = 0.8, + valid_pairs: Optional[List[set[str]]] = None, +) -> Dict[str, Dict]: + if valid_pairs is None: + valid_pairs = VALID_PAIRS + + results = {} + unique_param_names = set() + + for dirpath, _, filenames in os.walk(root_dir): + dirpath = pathlib.Path(dirpath) + + if "third_party" in dirpath.parts: + continue + + for filename in filenames: + if filename.endswith(".py"): + full_path = dirpath / filename + try: + with open(full_path, "r", encoding="utf-8") as f: + source = f.read() + tree = ast.parse(source, filename=full_path) + extract_parameters_from_ast( + tree, + full_path, + results, + valid_pairs, + unique_param_names, + similarity_cutoff, + ) + except (UnicodeDecodeError, FileNotFoundError): + continue + + return results + + +if __name__ == "__main__": + import argparse + import logging + import sys + + default_path = pathlib.Path(__file__).parent.parent / "pynapple" + + parser = argparse.ArgumentParser( + description="Check parameter naming consistency using AST." 
+ ) + parser.add_argument( + "--path", + "-p", + type=pathlib.Path, + help="Root path to the package (source folder).", + default=default_path, + ) + parser.add_argument( + "--threshold", + "-t", + type=float, + default=0.8, + help="Similarity threshold for parameter name grouping.", + ) + args = parser.parse_args() + + logger = logging.getLogger("check_parameter_naming") + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + params = collect_similar_parameter_names_ast( + args.path, similarity_cutoff=args.threshold + ) + invalid = [name for name, d in params.items() if len(d["unique_names"]) > 1] + + if invalid: + msg_lines = ["Inconsistency in parameter naming found!\n"] + for name in invalid: + msg_lines.append(f"{name}:\n") + grouped_info = defaultdict(list) + for param_name, path in sorted(params[name]["info"], key=lambda x: x[1]): + grouped_info[param_name].append(path) + for param_name in sorted(params[name]["unique_names"]): + msg_lines.append(f"\t- {param_name}:\n") + for path in grouped_info[param_name]: + msg_lines.append(f"\t\t- {path}\n") + msg_lines.append("\n") + logger.error("".join(msg_lines)) + sys.exit(1) + else: + logger.info("No parameter naming inconsistencies found.") diff --git a/tests/test_decoding.py b/tests/test_decoding.py index c5f95802c..7b3838212 100644 --- a/tests/test_decoding.py +++ b/tests/test_decoding.py @@ -1,22 +1,395 @@ -# -*- coding: utf-8 -*- -# @Author: gviejo -# @Date: 2022-03-30 11:16:39 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-01-29 11:15:41 -#!/usr/bin/env python - """Tests of decoding for `pynapple` package.""" +from contextlib import nullcontext as does_not_raise +from itertools import product + import numpy as np -import pandas as pd import pytest import pynapple as nap +def get_testing_set_n(n_features=1, binned=False, bin_size=1.0, time_units="s"): + combos = np.array(list(product([0, 1], repeat=n_features))) # (2^F, F) + reps = 5 + feature_data = np.tile(combos, (reps, 1)) # (T, F) + times = np.arange(len(feature_data)) + + features = nap.TsdFrame(t=times, d=feature_data) + epochs = nap.IntervalSet(start=0, end=len(times)) + + data = nap.TsGroup( + { + i: nap.Ts(t=times[np.all(feature_data == combo, axis=1)]) + for i, combo in enumerate(combos) + } + ) + + if binned: + frame = data.count(bin_size=bin_size, ep=epochs, time_units=time_units) + data = nap.TsdFrame( + frame.times() - 0.5, + frame.values, + time_support=epochs, + ) + + tuning_curves = nap.compute_tuning_curves( + data, features, bins=2, range=[(-0.5, 1.5)] * n_features + ) + + return { + "features": features, + "tuning_curves": tuning_curves, + "data": data, + "epochs": epochs, + "bin_size": bin_size, + } + + +@pytest.mark.filterwarnings("ignore") +@pytest.mark.parametrize( + "overwrite_default_args, expectation", + [ + # tuning_curves + ( + {"tuning_curves": []}, + pytest.raises( + TypeError, + match="tuning_curves should be an xarray.DataArray as computed by compute_tuning_curves.", + ), + ), + ( + {"tuning_curves": 1}, + pytest.raises( + TypeError, + match="tuning_curves should be an xarray.DataArray as computed by compute_tuning_curves.", + ), + ), + ( + {"tuning_curves": get_testing_set_n()["tuning_curves"].to_pandas().T}, + pytest.raises( + TypeError, + match="tuning_curves should be an xarray.DataArray as computed by compute_tuning_curves.", + ), + ), + ( + {"tuning_curves": get_testing_set_n(2)["tuning_curves"]}, + pytest.raises( + ValueError, + match="Different shapes for tuning_curves and data.", + ), + ), + ( + 
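+            # same shape mismatch when the 2D tuning curves were computed from binned data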
{"tuning_curves": get_testing_set_n(2, binned=True)["tuning_curves"]}, + pytest.raises( + ValueError, + match="Different shapes for tuning_curves and data.", + ), + ), + ( + { + "tuning_curves": get_testing_set_n()["tuning_curves"].assign_coords( + unit=[2, 3] + ) + }, + pytest.raises( + ValueError, + match="Different indices for tuning curves and data keys.", + ), + ), + ( + { + "tuning_curves": get_testing_set_n(binned=True)[ + "tuning_curves" + ].assign_coords(unit=[2, 3]) + }, + pytest.raises( + ValueError, + match="Different indices for tuning curves and data keys.", + ), + ), + ({}, does_not_raise()), + (get_testing_set_n(1), does_not_raise()), + (get_testing_set_n(2), does_not_raise()), + # data + ( + {"data": []}, + pytest.raises( + TypeError, + match="Unknown format for data.", + ), + ), + ( + {"data": 1}, + pytest.raises( + TypeError, + match="Unknown format for data.", + ), + ), + ( + {"data": get_testing_set_n(2)["data"]}, + pytest.raises( + ValueError, + match="Different shapes for tuning_curves and data.", + ), + ), + ( + { + "data": nap.TsGroup( + {2: nap.Ts(t=np.arange(0, 50)), 3: nap.Ts(t=np.arange(0, 50))} + ) + }, + pytest.raises( + ValueError, + match="Different indices for tuning curves and data keys.", + ), + ), + ( + {"data": get_testing_set_n(binned=True)["data"]}, + does_not_raise(), + ), + ( + get_testing_set_n(2, binned=True), + does_not_raise(), + ), + ( + get_testing_set_n(3, binned=True), + does_not_raise(), + ), + # bin_size + ( + {"data": get_testing_set_n(binned=True)["data"], "bin_size": None}, + pytest.raises( + ValueError, + match="bin_size should be a number.", + ), + ), + ( + {"data": get_testing_set_n(binned=True)["data"], "bin_size": "1.0"}, + pytest.raises( + ValueError, + match="bin_size should be a number.", + ), + ), + ( + {"data": get_testing_set_n(binned=True)["data"], "bin_size": 2.0}, + pytest.warns( + UserWarning, + match="passed bin_size is different from actual data bin size.", + ), + ), + ( + {"data": get_testing_set_n(binned=True)["data"], "bin_size": 1.0}, + does_not_raise(), + ), + ], +) +def test_decode_input_errors(overwrite_default_args, expectation): + default_args = get_testing_set_n() + default_args.update(overwrite_default_args) + default_args.pop("features") + with expectation: + nap.decode_bayes(**default_args) + nap.decode_template(**default_args) + + +@pytest.mark.filterwarnings("ignore") +@pytest.mark.parametrize( + "overwrite_default_args, expectation", + [ + # smoothing + ( + {"sliding_window_size": "1"}, + pytest.raises( + ValueError, + match="sliding_window_size should be a integer.", + ), + ), + ( + {"sliding_window_size": 0}, + pytest.raises( + ValueError, + match="sliding_window_size should be >= 1.", + ), + ), + ( + {"sliding_window_size": 1}, + does_not_raise(), + ), + ( + {"sliding_window_size": None}, + does_not_raise(), + ), + ( + { + "data": get_testing_set_n(binned=True)["data"], + "sliding_window_size": 1, + }, + does_not_raise(), + ), + ( + { + **get_testing_set_n(2, binned=True), + "sliding_window_size": 1, + }, + does_not_raise(), + ), + ], +) +def test_decode_input_errors_sliding_window_size(overwrite_default_args, expectation): + default_args = get_testing_set_n() + default_args.update(overwrite_default_args) + default_args.pop("features") + with expectation: + nap.decode_bayes(**default_args) + nap.decode_template(**default_args) + + +@pytest.mark.filterwarnings("ignore") +@pytest.mark.parametrize( + "overwrite_default_args, expectation", + [ + # uniform_prior + ( + { + "uniform_prior": False, + 
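+                # clearing attrs drops the stored occupancy, so decode_bayes cannot derive a prior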
"tuning_curves": (lambda x: (x.attrs.clear(), x)[1])( + get_testing_set_n()["tuning_curves"] + ), + }, + pytest.raises( + ValueError, + match="uniform_prior set to False but no occupancy found in tuning curves.", + ), + ), + ( + {"uniform_prior": True}, + does_not_raise(), + ), + ], +) +def test_decode_bayes_input_errors(overwrite_default_args, expectation): + default_args = get_testing_set_n() + default_args.update(overwrite_default_args) + default_args.pop("features") + with expectation: + nap.decode_bayes(**default_args) + + +@pytest.mark.parametrize("uniform_prior", [True, False]) +@pytest.mark.parametrize("binned", [True, False]) +@pytest.mark.parametrize("sliding_window_size", [None, 1, 3]) +@pytest.mark.parametrize( + "n_features, bin_size, time_units", + [ + (1, 1.0, "s"), + (2, 1.0, "s"), + (3, 1.0, "s"), + (2, 1.0, "s"), + (3, 1.0, "s"), + (1, 1e3, "ms"), + (2, 1e3, "ms"), + (3, 1e3, "ms"), + (2, 1e3, "ms"), + (3, 1e3, "ms"), + (1, 1e6, "us"), + (2, 1e6, "us"), + (3, 1e6, "us"), + (2, 1e6, "us"), + (3, 1e6, "us"), + ], +) +def test_decode_bayes( + n_features, binned, bin_size, sliding_window_size, time_units, uniform_prior +): + features, tuning_curves, data, epochs, bin_size = get_testing_set_n( + n_features, binned=binned, bin_size=bin_size, time_units=time_units + ).values() + decoded, proba = nap.decode_bayes( + tuning_curves=tuning_curves, + data=data, + epochs=epochs, + bin_size=bin_size, + sliding_window_size=sliding_window_size, + time_units=time_units, + uniform_prior=uniform_prior, + ) + assert isinstance(decoded, nap.Tsd if features.shape[1] == 1 else nap.TsdFrame) + + if sliding_window_size is None or sliding_window_size == 1: + np.testing.assert_array_almost_equal(decoded.values, features.values.squeeze()) + + assert isinstance( + proba, + nap.TsdFrame if features.shape[1] == 1 else nap.TsdTensor, + ) + expected_proba = np.zeros((len(features), *tuning_curves.shape[1:])) + target_indices = [np.arange(len(features))] + [ + features[:, d] for d in range(features.shape[1]) + ] + expected_proba[tuple(target_indices)] = 1.0 + np.testing.assert_array_almost_equal(proba.values, expected_proba) + + +@pytest.mark.parametrize("metric", ["correlation", "euclidean", "cosine"]) +@pytest.mark.parametrize("binned", [True, False]) +@pytest.mark.parametrize("sliding_window_size", [None, 1, 3]) +@pytest.mark.parametrize( + "n_features, bin_size, time_units", + [ + (1, 1.0, "s"), + (2, 1.0, "s"), + (3, 1.0, "s"), + (2, 1.0, "s"), + (3, 1.0, "s"), + (1, 1e3, "ms"), + (2, 1e3, "ms"), + (3, 1e3, "ms"), + (2, 1e3, "ms"), + (3, 1e3, "ms"), + (1, 1e6, "us"), + (2, 1e6, "us"), + (3, 1e6, "us"), + (2, 1e6, "us"), + (3, 1e6, "us"), + ], +) +def test_decode_template( + metric, n_features, binned, bin_size, sliding_window_size, time_units +): + features, tuning_curves, data, epochs, bin_size = get_testing_set_n( + n_features, binned=binned, bin_size=bin_size, time_units=time_units + ).values() + decoded, dist = nap.decode_template( + tuning_curves=tuning_curves, + data=data, + epochs=epochs, + metric=metric, + bin_size=bin_size, + sliding_window_size=sliding_window_size, + time_units=time_units, + ) + assert isinstance(decoded, nap.Tsd if features.shape[1] == 1 else nap.TsdFrame) + assert isinstance( + dist, + nap.TsdFrame if features.shape[1] == 1 else nap.TsdTensor, + ) + + if sliding_window_size is None or sliding_window_size == 1: + np.testing.assert_array_almost_equal( + decoded.values.astype(int), features.values.squeeze() + ) + + +# 
------------------------------------------------------------------------------------ +# OLD DECODING TESTS +# ------------------------------------------------------------------------------------ + + +@pytest.mark.filterwarnings("ignore") def get_testing_set_1d(): feature = nap.Tsd(t=np.arange(0, 100, 1), d=np.repeat(np.arange(0, 2), 50)) - group = nap.TsGroup({i: nap.Ts(t=np.arange(0, 50) + 50 * i) for i in range(2)}) + group = nap.TsGroup({i: nap.Ts(t=np.arange(0, 50) + 50 * i) for i in range(3)}) tc = nap.compute_1d_tuning_curves( group=group, feature=feature, nb_bins=2, minmax=(-0.5, 1.5) ) @@ -24,6 +397,7 @@ def get_testing_set_1d(): return feature, group, tc, ep +@pytest.mark.filterwarnings("ignore") def test_decode_1d(): feature, group, tc, ep = get_testing_set_1d() decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1) @@ -38,6 +412,7 @@ def test_decode_1d(): np.testing.assert_array_almost_equal(proba.values, tmp) +@pytest.mark.filterwarnings("ignore") def test_decode_1d_with_TsdFrame(): feature, group, tc, ep = get_testing_set_1d() count = group.count(bin_size=1, ep=ep) @@ -53,6 +428,7 @@ def test_decode_1d_with_TsdFrame(): np.testing.assert_array_almost_equal(proba.values, tmp) +@pytest.mark.filterwarnings("ignore") def test_decode_1d_with_feature(): feature, group, tc, ep = get_testing_set_1d() decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1, feature=feature) @@ -68,6 +444,7 @@ def test_decode_1d_with_feature(): np.testing.assert_array_almost_equal(proba.values, tmp) +@pytest.mark.filterwarnings("ignore") def test_decode_1d_with_dict(): feature, group, tc, ep = get_testing_set_1d() group = dict(group) @@ -84,6 +461,7 @@ def test_decode_1d_with_dict(): np.testing.assert_array_almost_equal(proba.values, tmp) +@pytest.mark.filterwarnings("ignore") def test_decode_1d_with_wrong_feature(): feature, group, tc, ep = get_testing_set_1d() with pytest.raises(RuntimeError) as e_info: @@ -91,6 +469,7 @@ def test_decode_1d_with_wrong_feature(): assert str(e_info.value) == "Unknown format for feature in decode_1d" +@pytest.mark.filterwarnings("ignore") def test_decode_1d_with_time_units(): feature, group, tc, ep = get_testing_set_1d() for t, tu in zip([1, 1e3, 1e6], ["s", "ms", "us"]): @@ -98,25 +477,7 @@ def test_decode_1d_with_time_units(): np.testing.assert_array_almost_equal(feature.values, decoded.values) -def test_decoded_1d_raise_errors(): - feature, group, tc, ep = get_testing_set_1d() - with pytest.raises(Exception) as e_info: - nap.decode_1d(tc, np.random.rand(10), ep, 1) - assert str(e_info.value) == "Unknown format for group" - - feature, group, tc, ep = get_testing_set_1d() - tc[2] = np.random.rand(2) - with pytest.raises(Exception) as e_info: - nap.decode_1d(tc, group, ep, 1) - assert str(e_info.value) == "Different shapes for tuning_curves and group" - - feature, group, tc, ep = get_testing_set_1d() - tc.columns = [0, 2] - with pytest.raises(Exception) as e_info: - nap.decode_1d(tc, group, ep, 1) - assert str(e_info.value) == "Different indices for tuning curves and group keys" - - +@pytest.mark.filterwarnings("ignore") def get_testing_set_2d(): features = nap.TsdFrame( t=np.arange(0, 100, 1), @@ -138,12 +499,13 @@ def get_testing_set_2d(): return features, group, tc, ep, tuple(xy) +@pytest.mark.filterwarnings("ignore") def test_decode_2d(): features, group, tc, ep, xy = get_testing_set_2d() decoded, proba = nap.decode_2d(tc, group, ep, 1, xy) assert isinstance(decoded, nap.TsdFrame) - assert isinstance(proba, np.ndarray) + assert isinstance(proba, nap.TsdTensor) 
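+    # the second output (proba) is now a time-indexed nap.TsdTensor rather than a plain ndarray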
np.testing.assert_array_almost_equal(features.values, decoded.values) assert len(decoded) == 100 assert len(proba) == 100 @@ -158,13 +520,14 @@ def test_decode_2d(): np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) +@pytest.mark.filterwarnings("ignore") def test_decode_2d_with_TsdFrame(): features, group, tc, ep, xy = get_testing_set_2d() count = group.count(bin_size=1, ep=ep) decoded, proba = nap.decode_2d(tc, count, ep, 1, xy) assert isinstance(decoded, nap.TsdFrame) - assert isinstance(proba, np.ndarray) + assert isinstance(proba, nap.TsdTensor) np.testing.assert_array_almost_equal(features.values, decoded.values) assert len(decoded) == 100 assert len(proba) == 100 @@ -179,13 +542,14 @@ def test_decode_2d_with_TsdFrame(): np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) +@pytest.mark.filterwarnings("ignore") def test_decode_2d_with_dict(): features, group, tc, ep, xy = get_testing_set_2d() group = dict(group) decoded, proba = nap.decode_2d(tc, group, ep, 1, xy) assert isinstance(decoded, nap.TsdFrame) - assert isinstance(proba, np.ndarray) + assert isinstance(proba, nap.TsdTensor) np.testing.assert_array_almost_equal(features.values, decoded.values) assert len(decoded) == 100 assert len(proba) == 100 @@ -200,33 +564,16 @@ def test_decode_2d_with_dict(): np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) +@pytest.mark.filterwarnings("ignore") def test_decode_2d_with_feature(): features, group, tc, ep, xy = get_testing_set_2d() decoded, proba = nap.decode_2d(tc, group, ep, 1, xy) np.testing.assert_array_almost_equal(features.values, decoded.values) +@pytest.mark.filterwarnings("ignore") def test_decode_2d_with_time_units(): features, group, tc, ep, xy = get_testing_set_2d() for t, tu in zip([1, 1e3, 1e6], ["s", "ms", "us"]): decoded, proba = nap.decode_2d(tc, group, ep, 1.0 * t, xy, time_units=tu) np.testing.assert_array_almost_equal(features.values, decoded.values) - - -def test_decoded_2d_raise_errors(): - features, group, tc, ep, xy = get_testing_set_2d() - with pytest.raises(Exception) as e_info: - nap.decode_2d(tc, np.random.rand(10), ep, 1, xy) - assert str(e_info.value) == "Unknown format for group" - - features, group, tc, ep, xy = get_testing_set_2d() - tc[5] = np.random.rand(2, 2) - with pytest.raises(Exception) as e_info: - nap.decode_2d(tc, group, ep, 1, xy) - assert str(e_info.value) == "Different shapes for tuning_curves and group" - - features, group, tc, ep, xy = get_testing_set_2d() - tc = {k: tc[i] for k, i in zip(np.arange(0, 40, 10), tc.keys())} - with pytest.raises(Exception) as e_info: - nap.decode_2d(tc, group, ep, 1, xy) - assert str(e_info.value) == "Different indices for tuning curves and group keys" diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 81e9a4aec..499bda874 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -523,3 +523,55 @@ def test_get_filter_frequency_response_error(): ValueError, match="Unrecognized filter mode. 
Choose either 'butter' or 'sinc'" ): nap.get_filter_frequency_response(250, 1000, "lowpass", "a", 4, 0.02) + + +@pytest.mark.parametrize( + "freq_band, thresh_band, num_events, start, end", + [ + ((10, 30), (1, 10), 1, 0, 2), + ((40, 60), (1, 10), 1, 3, 5), + ((100, 150), (1, 10), 0, None, None), + ], +) +def test_detect_oscillatory_events(freq_band, thresh_band, num_events, start, end): + fs = 1000 + duration = 5 + min_dur = 0.1 + max_dur = 2 + min_inter = 0.02 + + t = np.linspace(0, duration, int(fs * duration), endpoint=False) + signal = np.zeros_like(t) + + # 25 Hz oscillation from 0-2s + freq_1 = 25 + mask1 = (t >= 0) & (t < 2) + signal[mask1] = np.sin(2 * np.pi * freq_1 * t[mask1]) + + # 50 Hz oscillation from 3-5s + freq_2 = 50 + mask2 = (t >= 3) & (t < 5) + signal[mask2] = np.sin(2 * np.pi * freq_2 * t[mask2]) + + ts = nap.Tsd(t=t, d=signal) + epoch = nap.IntervalSet(start=0, end=duration) + osc_ep = nap.filtering.detect_oscillatory_events( + ts, epoch, freq_band, thresh_band, (min_dur, max_dur), min_inter + ) + + assert len(osc_ep) == num_events # Only one event in given freq_band + + if num_events > 0: + # Start and end should be close to actuals +/- a small amount + detected_start = osc_ep.start[0] + detected_end = osc_ep.end[0] + assert np.isclose(start, detected_start, atol=0.05) + assert np.isclose(end, detected_end, atol=0.05) + + # Check we store power, amplitude, and peak_time + for key in ["power", "amplitude", "peak_time"]: + assert key in osc_ep._metadata + + # Check peak_time is within the interval + peak_time = osc_ep._metadata["peak_time"][0] + assert start <= peak_time <= end diff --git a/tests/test_metadata.py b/tests/test_metadata.py index b5c95a85c..905ea37f1 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import pytest +import xarray import pynapple as nap from pynapple.core.metadata_class import _Metadata @@ -1464,6 +1465,75 @@ def test_drop_metadata_error(self, obj, obj_len, drop, error): if isinstance(drop, list) and ("label" in drop): assert "label" in obj.metadata_columns + def test_restrict_metadata(self, obj, obj_len): + """ + Test for restricting metadata with restrict_info. 
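+
+        Only the requested columns should survive; for TsGroup the "rate" column
+        must always remain.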
+ """ + info = np.ones(obj_len) + obj.set_info(l1=info, l2=info * 2, l3=info * 3) + for col in ["l1", "l2", "l3"]: + assert col in obj.metadata_columns + + # restrict to 1 key + obj.restrict_info("l1") + assert "l1" in obj.metadata_columns + for col in ["l2", "l3"]: + assert col not in obj.metadata_columns + + # rate should always be present in TsGroup + if isinstance(obj, nap.TsGroup): + assert "rate" in obj.metadata_columns + + # restrict to multiple keys + obj.set_info(l2=info * 2, l3=info * 3, l4=info * 4) + obj.restrict_info(["l1", "l2"]) + for col in ["l1", "l2"]: + assert col in obj.metadata_columns + for col in ["l3", "l4"]: + assert col not in obj.metadata_columns + + # rate should always be present in TsGroup + if isinstance(obj, nap.TsGroup): + assert "rate" in obj.metadata_columns + + @pytest.mark.parametrize( + "keep, error", + [ + ( + "not_info", + pytest.raises( + KeyError, + match=r"Metadata column\(s\) \['not_info'\] not found", + ), + ), + ( + ["not_info", "not_info2"], + pytest.raises( + KeyError, + match=r"Metadata column\(s\) \['not_info', 'not_info2'\] not found", + ), + ), + ( + ["label", 0], + pytest.raises(KeyError, match=r"Metadata column\(s\) \[0\] not found"), + ), + (0, pytest.raises(TypeError, match="Invalid metadata column")), + ], + ) + def test_restrict_metadata_error(self, obj, obj_len, keep, error): + """ + Test for errors when dropping metadata. + """ + info = np.ones(obj_len) + obj.set_info(label=info, other=info * 2) + + with error: + obj.restrict_info(keep) + + # make sure nothing gets dropped + assert "label" in obj.metadata_columns + assert "other" in obj.metadata_columns + # test naming overlap of shared attributes @pytest.mark.parametrize( "name", @@ -1786,18 +1856,18 @@ def test_metadata_groupby_apply_func_kwargs( "func, ep, err", [ ( # input_key is not string - nap.compute_1d_tuning_curves, + nap.compute_tuning_curves, 1, pytest.raises(TypeError, match="input_key must be a string"), ), ( # input_key does not exist in function - nap.compute_1d_tuning_curves, + nap.compute_tuning_curves, "epp", pytest.raises(KeyError, match="does not have input parameter"), ), ( # function missing required inputs, or incorrect input type - nap.compute_1d_tuning_curves, - "ep", + nap.compute_tuning_curves, + "epochs", pytest.raises(TypeError), ), ], @@ -2375,7 +2445,7 @@ def tsdframe_gba(self): def test_metadata_groupby_apply_tuning_curves(self, tsgroup_gba, iset_gba): """ - Test for groupby_apply with nap.compute_1d_tuning_curves when: + Test for groupby_apply with nap.compute_tuning_curves when: 1. a TsGroup is grouped 2. an IntervalSet is grouped and makes sure the outputs are different. 
@@ -2385,30 +2455,31 @@ def test_metadata_groupby_apply_tuning_curves(self, tsgroup_gba, iset_gba): # apply to intervalset out = iset_gba.groupby_apply( "label", - nap.compute_1d_tuning_curves, - "ep", - group=tsgroup_gba, - feature=feature, - nb_bins=5, + nap.compute_tuning_curves, + "epochs", + data=tsgroup_gba, + features=feature, + bins=5, ) for grp, idx in iset_gba.groupby("label").items(): - tmp = nap.compute_1d_tuning_curves( - tsgroup_gba, feature, nb_bins=5, ep=iset_gba[idx] + tmp = nap.compute_tuning_curves( + tsgroup_gba, feature, bins=5, epochs=iset_gba[idx] ) - pd.testing.assert_frame_equal(out[grp], tmp) + xarray.testing.assert_identical(out[grp], tmp) # apply to tsgroup out2 = tsgroup_gba.groupby_apply( "label", - nap.compute_1d_tuning_curves, - feature=feature, - nb_bins=5, + nap.compute_tuning_curves, + features=feature, + bins=5, ) + # make sure groups are different assert out2.keys() != out.keys() for grp, idx in tsgroup_gba.groupby("label").items(): - tmp = nap.compute_1d_tuning_curves(tsgroup_gba[idx], feature, nb_bins=5) - pd.testing.assert_frame_equal(out2[grp], tmp) + tmp = nap.compute_tuning_curves(tsgroup_gba[idx], feature, bins=5) + xarray.testing.assert_identical(out2[grp], tmp) def test_metadata_groupby_apply_tsgroup_lambda(self, tsgroup_gba): """ @@ -2525,7 +2596,7 @@ def test_no_conflict_between_class_and_metadatamixin(nap_class): conflicting_members = iset_members.intersection(metadatamixin_members) # set_info, get_info, drop_info, groupby, and groupby_apply are overwritten for class-specific examples in docstrings - assert len(conflicting_members) == 5, ( + assert len(conflicting_members) == 6, ( f"Conflict detected! The following methods/attributes are " f"overwritten in IntervalSet: {conflicting_members}" ) diff --git a/tests/test_npz_file.py b/tests/test_npz_file.py index ef83677b9..c525eb47f 100644 --- a/tests/test_npz_file.py +++ b/tests/test_npz_file.py @@ -52,6 +52,25 @@ }, minfo=[1, 2, 3], ), + "tsdgroup": nap.TsGroup( + { + 0: nap.Tsd(t=np.arange(0, 200), d=np.random.rand(200)), + 1: nap.Tsd( + t=np.arange(0, 200, 0.5), d=np.arange(0, 200, 0.5) + 1, time_units="s" + ), + 2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"), + } + ), + "tsdgroup_minfo": nap.TsGroup( + { + 0: nap.Tsd(t=np.arange(0, 200), d=np.random.rand(200)), + 1: nap.Tsd( + t=np.arange(0, 200, 0.5), d=np.arange(0, 200, 0.5) + 1, time_units="s" + ), + 2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"), + }, + minfo=[1, 2, 3], + ), "iset": nap.IntervalSet(start=np.array([0.0, 5.0]), end=np.array([1.0, 6.0])), "iset_minfo": nap.IntervalSet( start=np.array([0.0, 5.0]), end=np.array([1.0, 6.0]), metadata={"minfo": [1, 2]} @@ -82,6 +101,8 @@ def test_init(path): "tsdframe_minfo", "tsgroup", "tsgroup_minfo", + "tsdgroup", + "tsdgroup_minfo", "iset", "iset_minfo", ], @@ -109,6 +130,20 @@ def test_load_tsgroup(path, k): ) +@pytest.mark.parametrize("path", [path]) +@pytest.mark.parametrize("k", ["tsdgroup", "tsdgroup_minfo"]) +def test_load_tsdgroup(path, k): + file_path = path / (k + ".npz") + file = nap.NPZFile(file_path) + tmp = file.load() + assert isinstance(tmp, type(data[k])) + assert tmp.keys() == data[k].keys() + assert np.all(tmp[neu] == data[k][neu] for neu in tmp.keys()) + np.testing.assert_array_almost_equal( + tmp.time_support.values, data[k].time_support.values + ) + + @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ["tsgroup", "tsgroup_minfo"]) def test_load_tsgroup_backward_compatibility(path, k): diff --git a/tests/test_numpy_compatibility.py 
b/tests/test_numpy_compatibility.py index c4323147f..c993dd64a 100644 --- a/tests/test_numpy_compatibility.py +++ b/tests/test_numpy_compatibility.py @@ -1,3 +1,6 @@ +import warnings +from numbers import Number + import numpy as np import numpy.core.umath as _umath import pytest @@ -500,3 +503,492 @@ def test_concatenate(self, tsd): def test_fft(self, tsd): with pytest.raises(TypeError): np.fft.fft(tsd) + + +@pytest.mark.parametrize( + "func, kwargs", + [ + ("concatenate", {}), + ("concatenate", {"axis": 0}), + ("concatenate", {"axis": 1}), + ("concatenate", {"axis": 2}), + ("stack", {}), + ("stack", {"axis": 0}), + ("stack", {"axis": 1}), + ("stack", {"axis": 2}), + ("stack", {"axis": -1}), + ("vstack", {}), + ("hstack", {}), + ("dstack", {}), + ("column_stack", {}), + ], +) +@pytest.mark.parametrize( + "tsds", + [ + ( + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + nap.Tsd(t=np.arange(10) + 15, d=np.random.rand(10)), + ), + ( + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + ), + ( + nap.Tsd(t=np.arange(10) + 15, d=np.random.rand(10)), + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + ), + ( + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + nap.TsdFrame(t=np.arange(10) + 15, d=np.random.rand(10, 5)), + ), + ( + nap.TsdFrame(t=np.arange(10) + 15, d=np.random.rand(10, 5)), + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + ), + ( + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + ), + ( + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 5, 2)), + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 5, 2)), + ), + ( + nap.TsdTensor(t=np.arange(10) + 15, d=np.random.rand(10, 5, 2)), + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 5, 2)), + ), + ], +) +def test_concatenate_all(func, kwargs, tsds): + tsd1, tsd2 = tsds + try: + b = getattr(np, func)((tsd1.values, tsd2.values), **kwargs) + except (ValueError, RuntimeError): + pytest.skip("Skipping invalid axis operation") + + try: + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + a = getattr(np, func)((tsd1, tsd2), **kwargs) + except (ValueError, RuntimeError) as e: + error_msg = str(e) + assert ( + error_msg + == "The order of the time series indexes should be strictly increasing and non overlapping." 
+ ) + return + + if a.ndim == tsd1.ndim: + if a.shape[0] == tsd1.shape[0] + tsd2.shape[0]: # Stacking vertically + assert isinstance(a, tsd1.__class__) + np.testing.assert_array_almost_equal( + a.index, np.concatenate((tsd1.index, tsd2.index)) + ) + np.testing.assert_array_almost_equal(a.values, b) + np.testing.assert_array_equal( + np.vstack((tsd1.time_support.values, tsd2.time_support.values)), + a.time_support.values, + ) + else: + # Check if operation was allowed + if isinstance(a, tsd1.__class__): + np.testing.assert_array_almost_equal(tsd1.index, tsd2.index) + np.testing.assert_array_almost_equal(a.values, b) + np.testing.assert_array_equal(a.index, tsd1.index) + if hasattr(tsd1, "columns") and hasattr(tsd2, "columns"): + np.testing.assert_array_equal( + a.columns, np.concatenate((tsd1.columns, tsd2.columns), axis=0) + ) + else: + assert isinstance(a, np.ndarray) + np.testing.assert_array_almost_equal(a, b) + else: + # Check if operation was allowed + if hasattr(a, "nap_class"): + np.testing.assert_array_almost_equal(tsd1.index, tsd2.index) + np.testing.assert_array_almost_equal(a.values, b) + np.testing.assert_array_equal(a.index, tsd1.index) + np.testing.assert_array_equal( + tsd1.time_support.values, tsd2.time_support.values + ) + else: + assert isinstance(a, np.ndarray) + np.testing.assert_array_almost_equal(a, b) + + if len(record) > 0: + warning_msg = str(record[0].message) + assert warning_msg in [ + "Time indexes and time supports are not all equals up to pynapple precision. Returning numpy array!", + "Time indexes are not all equals up to pynapple precision. Returning numpy array!", + "Time supports are not all equals up to pynapple precision. Returning numpy array!", + ] + assert isinstance(a, np.ndarray) + + +@pytest.mark.parametrize( + "tsd", + [ + nap.TsdFrame( + t=np.arange(10), + d=np.random.rand(10, 10), + ), + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 10, 10), time_units="s"), + ], +) +@pytest.mark.parametrize( + "func, kwargs", + [ + ("sum", {}), + ("sum", {"axis": 0}), + ("sum", {"axis": 1}), + ("sum", {"axis": -1}), + ("sum", {"axis": (0, 1)}), + ("sum", {"axis": None}), + ], +) +def test_square_arrays(tsd, func, kwargs): + a = getattr(np, func)(tsd, **kwargs) + b = getattr(np, func)(tsd.values, **kwargs) + + if "axis" in kwargs: + axis = kwargs["axis"] + else: + axis = None + + if axis is None or np.isscalar(b): + assert np.isscalar(a) + assert a == b + else: + if (axis == 0) or (isinstance(axis, tuple) and 0 in axis): + assert isinstance(a, (np.ndarray, Number)) + np.testing.assert_array_almost_equal(a, b) + else: + assert not isinstance(a, np.ndarray) + np.testing.assert_array_almost_equal(a.index, tsd.index) + np.testing.assert_array_almost_equal(a.values, b) + + +@pytest.mark.parametrize( + "tsd", + [ + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 10)), + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 10, 10), time_units="s"), + ], +) +@pytest.mark.parametrize( + "func, kwargs, expected_type", + [ + ("transpose", {}, (nap.Tsd, np.ndarray)), + ("transpose", {"axes": (2, 0, 1)}, np.ndarray), + ("transpose", {"axes": (0, 2, 1)}, nap.TsdTensor), + ("moveaxis", {"source": 0, "destination": 1}, np.ndarray), + ("moveaxis", {"source": 1, "destination": 0}, np.ndarray), + ("moveaxis", {"source": 2, "destination": 1}, nap.TsdTensor), + ("swapaxes", {"axis1": 0, "axis2": 1}, np.ndarray), + ("swapaxes", {"axis1": 1, "axis2": 2}, nap.TsdTensor), + ("swapaxes", {"axis1": 2, "axis2": 0}, np.ndarray), + ( 
+ "rollaxis", + {"axis": 0, "start": 1}, + {"Tsd": np.ndarray, "TsdFrame": np.ndarray, "TsdTensor": np.ndarray}, + ), + ("rollaxis", {"axis": 1, "start": 0}, np.ndarray), + ("rollaxis", {"axis": 1, "start": 2}, (nap.TsdTensor, nap.TsdFrame)), + ("flipud", {}, np.ndarray), + ("fliplr", {}, (nap.TsdFrame, nap.TsdTensor)), + ("flip", {"axis": 0}, np.ndarray), + ("flip", {"axis": None}, np.ndarray), + ("flip", {"axis": 1}, (nap.TsdFrame, nap.TsdTensor)), + ("flip", {"axis": 2}, nap.TsdTensor), + ("rot90", {}, np.ndarray), + ("rot90", {"k": 2}, np.ndarray), + ("roll", {"shift": 2, "axis": 0}, np.ndarray), + ("roll", {"shift": -2, "axis": 1}, (nap.TsdFrame, nap.TsdTensor)), + ("roll", {"shift": 1, "axis": 2}, nap.TsdTensor), + ], +) +def test_axis_moving(tsd, func, kwargs, expected_type): + try: + b = getattr(np, func)(tsd.values, **kwargs) + except (ValueError, RuntimeError): + pytest.skip("Skipping invalid axis operation") + + a = getattr(np, func)(tsd, **kwargs) + + if isinstance(expected_type, dict): + assert isinstance(a, expected_type[tsd.nap_class]) + else: + assert isinstance(a, expected_type) + + if not isinstance(a, np.ndarray): + np.testing.assert_array_almost_equal(a.index, tsd.index) + + if hasattr(a, "values"): + np.testing.assert_array_almost_equal(a.values, b) + else: + np.testing.assert_array_almost_equal(a, b) + + +@pytest.mark.parametrize( + "tsd", + [ + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 10)), + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 10, 10)), + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 10, 1)), + ], +) +@pytest.mark.parametrize( + "func, kwargs, expected_type", + [ + ("expand_dims", {"axis": 0}, np.ndarray), + ("expand_dims", {"axis": 1}, (nap.TsdFrame, nap.TsdTensor)), + ("expand_dims", {"axis": -1}, (nap.TsdFrame, nap.TsdTensor)), + ( + "expand_dims", + {"axis": -2}, + {"Tsd": np.ndarray, "TsdFrame": nap.TsdTensor, "TsdTensor": nap.TsdTensor}, + ), + ("squeeze", {}, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)), + ( + "ravel", + {}, + {"Tsd": nap.Tsd, "TsdFrame": np.ndarray, "TsdTensor": np.ndarray}, + ), + ( + "ravel", + {"order": "F"}, + {"Tsd": nap.Tsd, "TsdFrame": np.ndarray, "TsdTensor": np.ndarray}, + ), + ( + "tile", + {"reps": 2}, + {"Tsd": np.ndarray, "TsdFrame": nap.TsdFrame, "TsdTensor": nap.TsdTensor}, + ), + ( + "tile", + {"reps": (2, 1)}, + {"Tsd": np.ndarray, "TsdFrame": np.ndarray, "TsdTensor": nap.TsdTensor}, + ), + ( + "tile", + {"reps": (1, 2)}, + {"Tsd": np.ndarray, "TsdFrame": nap.TsdFrame, "TsdTensor": nap.TsdTensor}, + ), + ], +) +def test_shape_change(tsd, func, kwargs, expected_type): + try: + b = getattr(np, func)(tsd.values, **kwargs) + except (ValueError, RuntimeError): + pytest.skip("Skipping invalid axis operation") + + a = getattr(np, func)(tsd, **kwargs) + + if isinstance(expected_type, dict): + assert isinstance(a, expected_type[tsd.nap_class]) + else: + assert isinstance(a, expected_type) + + if not isinstance(a, np.ndarray): + np.testing.assert_array_almost_equal(a.index, tsd.index) + + if hasattr(a, "values"): + np.testing.assert_array_almost_equal(a.values, b) + else: + np.testing.assert_array_almost_equal(a, b) + + +@pytest.mark.parametrize( + "tsd, slicing, expected_type", + [ + ( + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + lambda x: x[None, :], + np.ndarray, + ), + ( + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + lambda x: x[:, None], + nap.TsdFrame, + ), + ( + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 10)), + lambda x: x[:, 
None], + nap.TsdTensor, + ), + ( + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 10)), + lambda x: x[:, :, None], + nap.TsdTensor, + ), + ( + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 10, 10)), + lambda x: x[:, None], + nap.TsdTensor, + ), + ( + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 10, 1)), + lambda x: x[None, :], + np.ndarray, + ), + ], +) +def test_shape_change_2(tsd, slicing, expected_type): + a = slicing(tsd) + assert isinstance(a, expected_type) + if hasattr(a, "index"): + np.testing.assert_array_almost_equal(a.index, tsd.index) + if hasattr(a, "values"): + np.testing.assert_array_almost_equal(a.values, slicing(tsd.values)) + + +@pytest.mark.parametrize( + "a, b, expected_type", + [ + ( + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + float, + ), + ( + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + np.random.randn(5, 3), + nap.TsdFrame, + ), + ( + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 4, 2)), + np.random.rand(2, 3), + nap.TsdTensor, + ), + ( + np.random.rand(5, 10), + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + np.ndarray, + ), + ], +) +def test_dot_product(a, b, expected_type): + out = np.dot(a, b) + + assert isinstance(out, expected_type) + + if hasattr(a, "values") and hasattr(b, "values"): + out2 = np.dot(a.values, b.values) + elif hasattr(a, "values"): + out2 = np.dot(a.values, b) + elif hasattr(b, "values"): + out2 = np.dot(a, b.values) + else: + out2 = np.dot(a, b) + + if hasattr(out, "values"): + np.testing.assert_array_almost_equal(out.values, out2) + else: + if isinstance(out2, float): + assert out == out2 + else: + np.testing.assert_array_almost_equal(out, out2) + + +@pytest.mark.parametrize( + "a, b, expected_type", + [ + ( + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + np.random.randn(5, 3), + nap.TsdFrame, + ), + ( + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 4, 2)), + np.random.rand(2, 3), + nap.TsdTensor, + ), + ( + np.random.rand(5, 10), + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + np.ndarray, + ), + ], +) +def test_matmul_product(a, b, expected_type): + out = np.matmul(a, b) + + assert isinstance(out, expected_type) + + if hasattr(a, "values") and hasattr(b, "values"): + out2 = np.dot(a.values, b.values) + elif hasattr(a, "values"): + out2 = np.dot(a.values, b) + elif hasattr(b, "values"): + out2 = np.dot(a, b.values) + else: + out2 = np.dot(a, b) + + if hasattr(out, "values"): + np.testing.assert_array_almost_equal(out.values, out2) + else: + if isinstance(out2, float): + assert out == out2 + else: + np.testing.assert_array_almost_equal(out, out2) + + +@pytest.mark.parametrize( + "a, b, subscripts, expected_type", + [ + ( + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + nap.Tsd(t=np.arange(10), d=np.random.rand(10)), + "i,i->", + float, + ), + ( + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + np.random.randn(5, 3), + "ij,jk->ik", + nap.TsdFrame, + ), + ( + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 4, 2)), + np.random.rand(2, 3), + "ijk,kl->ijl", + nap.TsdTensor, + ), + ( + np.random.rand(5, 10), + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + "ij,jk->ik", + np.ndarray, + ), + ( + nap.TsdFrame(t=np.arange(10), d=np.random.rand(10, 5)), + np.random.randn(5, 10), + "ij,ji->i", + nap.Tsd, + ), + ], +) +def test_einsum(a, b, subscripts, expected_type): + out = np.einsum(subscripts, a, b) + + assert isinstance(out, expected_type) + + if hasattr(a, "values") and hasattr(b, 
"values"): + out2 = np.einsum(subscripts, a.values, b.values) + elif hasattr(a, "values"): + out2 = np.einsum(subscripts, a.values, b) + elif hasattr(b, "values"): + out2 = np.einsum(subscripts, a, b.values) + else: + out2 = np.einsum(subscripts, a, b) + + if hasattr(out, "values"): + np.testing.assert_array_almost_equal(out.values, out2) + else: + if isinstance(out2, float): + assert out == out2 + else: + np.testing.assert_array_almost_equal(out, out2) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 4472b634f..a2e1005c2 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -539,6 +539,29 @@ def test_restrict_inherit_time_support(self, tsd): np.testing.assert_approx_equal(tsd2.time_support.start[0], ep.start[0]) np.testing.assert_approx_equal(tsd2.time_support.end[0], ep.end[0]) + @pytest.mark.parametrize( + "ep, true_inds, false_inds", + [ + (nap.IntervalSet(start=0, end=50), np.arange(51), np.arange(51, 100)), + ( + nap.IntervalSet(start=[0, 20], end=[10, 30]), + np.hstack((np.arange(11), np.arange(20, 31))), + np.hstack((np.arange(11, 20), np.arange(31, 100))), + ), + ], + ) + def test_in_interval(self, tsd, ep, true_inds, false_inds): + tsd2 = tsd.in_interval(ep) + assert isinstance(tsd2, nap.Tsd) + assert all(tsd2.time_support.start == tsd.time_support.start) + assert all(tsd2.time_support.end == tsd.time_support.end) + assert all(tsd2[true_inds] == True) + assert all(tsd2[false_inds] == False) + + def test_in_interval_error(self, tsd): + with pytest.raises(TypeError, match=r"Argument should be IntervalSet"): + tsd.in_interval([0, 1]) + def test_get_interval(self, tsd): tsd2 = tsd.get(10, 20) assert len(tsd2) == 11 @@ -1604,6 +1627,27 @@ def test_vert_and_horz_slicing(self, tsdframe, row, col, expected): output, tsdframe.values[row, col] ) + def test_tsd_indexing(self, tsdframe): + tsd_index = tsdframe[:, 0] > 0 + output = tsdframe[tsd_index] + np.testing.assert_array_almost_equal( + output.values, tsdframe.values[tsd_index.values] + ) + assert isinstance(output, nap.TsdFrame) + + with pytest.raises(ValueError, match="must contain boolean values"): + tsdframe[tsd_index + 1] + + tsdframe_index = tsdframe > 0 + output = tsdframe[tsdframe_index] + np.testing.assert_array_almost_equal( + output, tsdframe.values[tsdframe_index.values] + ) + assert isinstance(output, np.ndarray) + + with pytest.raises(ValueError, match="must contain boolean values"): + tsdframe[tsdframe_index + 1] + @pytest.mark.parametrize("index", [0, [0, 2]]) def test_str_indexing(self, tsdframe, index): columns = tsdframe.columns diff --git a/tests/test_tuning_curves.py b/tests/test_tuning_curves.py index bb54891a2..d5534945d 100644 --- a/tests/test_tuning_curves.py +++ b/tests/test_tuning_curves.py @@ -1,154 +1,1317 @@ -"""Tests of tuning curves for `pynapple` package.""" +"""Tests of N-dimensional tuning curves for `pynapple` package.""" from contextlib import nullcontext as does_not_raise import numpy as np import pandas as pd import pytest +import xarray as xr import pynapple as nap -######################## -# Type Error -######################## -def get_group(): - return nap.TsGroup({0: nap.Ts(t=np.arange(0, 100))}) - - -def get_feature(): - return nap.Tsd( - t=np.arange(0, 100, 0.1), - d=np.arange(0, 100, 0.1) % 1.0, - time_support=nap.IntervalSet(0, 100), +def get_group_n(n): + return nap.TsGroup( + {i + 1: nap.Ts(t=np.arange(0, 100, 10 ** (i - 1))) for i in range(n)} ) -def get_features(): - tmp = np.vstack( - (np.repeat(np.arange(0, 100), 10), np.tile(np.arange(0, 
100), 10)) - ).T +def get_features_n(n, fs=10.0): return nap.TsdFrame( - t=np.arange(0, 200, 0.1), - d=np.vstack((tmp, tmp[::-1])), - time_support=nap.IntervalSet(0, 200), + t=np.arange(0, 100, 1 / fs), + d=np.stack( + [np.arange(0, 100, 1 / fs) % 10 * i for i in range(1, n + 1)], axis=1 + ), + columns=[f"feature{i}" for i in range(n)], ) -def get_ep(): - return nap.IntervalSet(start=0, end=50) +@pytest.mark.parametrize( + "data, features, kwargs, expectation", + [ + # data + ( + [1], + get_features_n(1), + {}, + pytest.raises( + TypeError, match="data should be a TsdFrame, TsGroup, Ts, or Tsd." + ), + ), + ( + None, + get_features_n(1), + {}, + pytest.raises( + TypeError, match="data should be a TsdFrame, TsGroup, Ts, or Tsd." + ), + ), + ( + {1: nap.Ts([1, 2, 3])}, + get_features_n(1), + {}, + pytest.raises( + TypeError, match="data should be a TsdFrame, TsGroup, Ts, or Tsd." + ), + ), + (get_group_n(1), get_features_n(1), {}, does_not_raise()), + (get_group_n(3), get_features_n(1), {}, does_not_raise()), + (get_group_n(1).count(0.1), get_features_n(1), {}, does_not_raise()), + (get_group_n(3).count(0.1), get_features_n(1), {}, does_not_raise()), + (nap.Tsd(t=[1, 2, 3], d=[1, 1, 1]), get_features_n(1), {}, does_not_raise()), + (nap.Ts([1, 2, 3]), get_features_n(1), {}, does_not_raise()), + # features + ( + get_group_n(1), + [1], + {}, + pytest.raises(TypeError, match="features should be a Tsd or TsdFrame"), + ), + ( + get_group_n(1), + None, + {}, + pytest.raises(TypeError, match="features should be a Tsd or TsdFrame"), + ), + ( + get_group_n(1), + nap.Tsd(d=[1, 1, 1], t=[1, 2, 3]), + {}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(3), + {}, + does_not_raise(), + ), + # epochs + ( + get_group_n(1), + get_features_n(1), + {"epochs": 1}, + pytest.raises(TypeError, match="epochs should be an IntervalSet."), + ), + ( + get_group_n(1), + get_features_n(1), + {"epochs": [1, 2]}, + pytest.raises(TypeError, match="epochs should be an IntervalSet."), + ), + ( + get_group_n(1), + get_features_n(1), + {"epochs": None}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"epochs": nap.IntervalSet(0.0, 50.0)}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"epochs": nap.IntervalSet([0.0, 30.0], [10.0, 50.0])}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"epochs": nap.IntervalSet([0.0, 1000.0])}, + does_not_raise(), + ), + # range + ( + get_group_n(1), + get_features_n(2), + {"range": (0, 1)}, + pytest.raises( + ValueError, + match="range should be a sequence of tuples, one for each feature.", + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"range": (0, 1)}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"range": [(0, 1)]}, + does_not_raise(), + ), + # fs + ( + get_group_n(1), + get_features_n(1), + {"fs": "1"}, + pytest.raises(TypeError, match="fs should be a number"), + ), + ( + get_group_n(1), + get_features_n(1), + {"fs": []}, + pytest.raises(TypeError, match="fs should be a number"), + ), + ( + get_group_n(1), + get_features_n(1), + {"fs": 1}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"fs": 1.0}, + does_not_raise(), + ), + # feature names + ( + get_group_n(1), + get_features_n(1), + {"feature_names": "feature0"}, + pytest.raises( + TypeError, + match="feature_names should be a list of strings.", + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"feature_names": 0}, + pytest.raises( + TypeError, + match="feature_names should be a list of 
strings.", + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"feature_names": [1]}, + pytest.raises( + TypeError, + match="feature_names should be a list of strings.", + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"feature_names": [(1,)]}, + pytest.raises( + TypeError, + match="feature_names should be a list of strings.", + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"feature_names": [(1, 1)]}, + pytest.raises( + TypeError, + match="feature_names should be a list of strings.", + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"feature_names": ["feature0", "feature1"]}, + pytest.raises( + ValueError, match="feature_names should match the number of features." + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"feature_names": ["feature0"]}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(2), + {"feature_names": ["feature0", "feature1"]}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"feature_names": np.array(["feature0"])}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(2), + {"feature_names": np.array(["feature0", "feature1"])}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"feature_names": ("feature0",)}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(2), + {"feature_names": ("feature0", "feature1")}, + does_not_raise(), + ), + # return_pandas + ( + get_group_n(1), + get_features_n(1), + {"return_pandas": 2}, + pytest.raises( + TypeError, + match="return_pandas should be a boolean.", + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"return_pandas": "1"}, + pytest.raises( + TypeError, + match="return_pandas should be a boolean.", + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"return_pandas": 0}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"return_pandas": 1}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"return_pandas": True}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"return_pandas": False}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(2), + {"return_pandas": True}, + pytest.raises( + ValueError, + match="Cannot convert arrays with 3 dimensions into pandas objects. 
Requires 2 or fewer dimensions.", + ), + ), + # return_counts + ( + get_group_n(1), + get_features_n(1), + {"return_counts": 2}, + pytest.raises( + TypeError, + match="return_counts should be a boolean.", + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"return_counts": "1"}, + pytest.raises( + TypeError, + match="return_counts should be a boolean.", + ), + ), + ( + get_group_n(1), + get_features_n(1), + {"return_counts": 0}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"return_counts": 1}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"return_counts": True}, + does_not_raise(), + ), + ( + get_group_n(1), + get_features_n(1), + {"return_counts": False}, + does_not_raise(), + ), + ], +) +def test_compute_tuning_curves_type_errors(data, features, kwargs, expectation): + with expectation: + nap.compute_tuning_curves(data, features, **kwargs) -def get_tsdframe(): - return nap.TsdFrame(t=np.arange(0, 100), d=np.ones((100, 2))) +@pytest.mark.parametrize( + "data, features, kwargs, expectation", + [ + # single rate unit, single feature + ( + get_group_n(1).count(1.0), + get_features_n(1), + {}, + xr.DataArray( + np.full((1, 10), 10.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495}, + attrs={ + "fs": 10.0, + "occupancy": np.full(10, 100.0), + "bin_edges": [np.linspace(0, 9.9, 11)], + }, + ), + ), + # multiple rate units, single feature + ( + get_group_n(2).count(1.0), + get_features_n(1), + {}, + xr.DataArray( + np.concatenate([np.full((1, 10), 10.0), np.full((1, 10), 1.0)]), + dims=["unit", "feature0"], + coords={ + "unit": [1, 2], + "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495, + }, + attrs={ + "fs": 10.0, + "occupancy": np.full(10, 100.0), + "bin_edges": [np.linspace(0, 9.9, 11)], + }, + ), + ), + # multiple rate units, multiple features + ( + get_group_n(2).count(1.0), + get_features_n(2), + {}, + xr.DataArray( + np.stack( + [ + np.where(np.eye(10), 10.0, np.nan), + np.where(np.eye(10), 1.0, np.nan), + ] + ), + dims=["unit", "feature0", "feature1"], + coords={ + "unit": [1, 2], + "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495, + "feature1": np.linspace(0, 19.8, 11)[:-1] + 0.99, + }, + attrs={ + "fs": 10.0, + "occupancy": np.where(np.eye(10), 100.0, 0.0), + "bin_edges": [np.linspace(0, i * 9.9, 11) for i in range(1, 3)], + }, + ), + ), + # single unit, single feature + ( + get_group_n(1), + get_features_n(1), + {}, + xr.DataArray( + np.full((1, 10), 10.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495}, + attrs={ + "fs": 10.0, + "occupancy": np.full(10, 100.0), + "bin_edges": [np.linspace(0, 9.9, 11)], + }, + ), + ), + # multiple units, single feature + ( + get_group_n(2), + get_features_n(1), + {}, + xr.DataArray( + np.concatenate([np.full((1, 10), 10.0), np.full((1, 10), 1.0)]), + dims=["unit", "feature0"], + coords={ + "unit": [1, 2], + "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495, + }, + attrs={ + "fs": 10.0, + "occupancy": np.full(10, 100.0), + "bin_edges": [np.linspace(0, 9.9, 11)], + }, + ), + ), + # multiple units, multiple features + ( + get_group_n(2), + get_features_n(2), + {}, + xr.DataArray( + np.stack( + [ + np.where(np.eye(10), 10.0, np.nan), + np.where(np.eye(10), 1.0, np.nan), + ] + ), + dims=["unit", "feature0", "feature1"], + coords={ + "unit": [1, 2], + "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495, + "feature1": np.linspace(0, 19.8, 11)[:-1] + 0.99, + }, + attrs={ + "fs": 10.0, + "occupancy": 
np.where(np.eye(10), 100.0, np.nan), + "bin_edges": [np.linspace(0, i * 9.9, 11) for i in range(1, 3)], + }, + ), + ), + # single unit, single feature, specified number of bins + ( + get_group_n(1), + get_features_n(1), + {"bins": 5}, + xr.DataArray( + np.full((1, 5), 10.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.linspace(0, 9.9, 6)[:-1] + 0.99}, + attrs={ + "fs": 10.0, + "occupancy": np.full(5, 200.0), + "bin_edges": [np.linspace(0, 9.9, 6)], + }, + ), + ), + # single unit, multiple features, specified number of bins + ( + get_group_n(1), + get_features_n(2), + {"bins": 5}, + xr.DataArray( + np.where(np.eye(5), 10.0, np.nan)[None, :], + dims=["unit", "feature0", "feature1"], + coords={ + "unit": [1], + "feature0": np.linspace(0, 9.9, 6)[:-1] + 0.99, + "feature1": np.linspace(0, 19.8, 6)[:-1] + 1.98, + }, + attrs={ + "fs": 10.0, + "occupancy": np.where(np.eye(5), 200.0, np.nan), + "bin_edges": [np.linspace(0, i * 9.9, 6) for i in range(1, 3)], + }, + ), + ), + # single unit, multiple features, specified number of bins per feature + ( + get_group_n(1), + get_features_n(2), + {"bins": (5, 4)}, + xr.DataArray( + np.array( + [ + [ + [10.0, np.nan, np.nan, np.nan], + [10.0, 10.0, np.nan, np.nan], + [np.nan, 10.0, 10.0, np.nan], + [np.nan, np.nan, 10.0, 10.0], + [np.nan, np.nan, np.nan, 10.0], + ] + ] + ), + dims=["unit", "feature0", "feature1"], + coords={ + "unit": [1], + "feature0": np.linspace(0, 9.9, 6)[:-1] + 0.99, + "feature1": np.linspace(0, 19.8, 5)[:-1] + 2.475, + }, + attrs={ + "fs": 10.0, + "occupancy": np.array( + [ + [200.0, np.nan, np.nan, np.nan], + [50.0, 150.0, np.nan, np.nan], + [np.nan, 100.0, 100.0, np.nan], + [np.nan, np.nan, 150.0, 50.0], + [np.nan, np.nan, np.nan, 200.0], + ] + ), + "bin_edges": [np.linspace(0, 9.9, 6), np.linspace(0, 19.8, 5)], + }, + ), + ), + # single unit, single feature, specified bins + ( + get_group_n(1), + get_features_n(1), + {"bins": [np.linspace(0, 10, 6)]}, + xr.DataArray( + np.full((1, 5), 10.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.arange(1, 11, 2)}, + attrs={ + "fs": 10.0, + "occupancy": np.full(5, 200.0), + "bin_edges": [np.linspace(0, 10, 6)], + }, + ), + ), + # single unit, multiple features, specified bins + ( + get_group_n(1), + get_features_n(2), + {"bins": [np.linspace(0, 10, 6), np.linspace(0, 20, 6)]}, + xr.DataArray( + np.where(np.eye(5), 10.0, np.nan)[None, :], + dims=["unit", "feature0", "feature1"], + coords={ + "unit": [1], + "feature0": np.arange(1, 11, 2), + "feature1": np.arange(2, 22, 4), + }, + attrs={ + "fs": 10.0, + "occupancy": np.where(np.eye(5), 200.0, np.nan), + "bin_edges": [np.linspace(0, i * 10, 6) for i in range(1, 3)], + }, + ), + ), + # single unit, single feature, specified range + ( + get_group_n(1), + get_features_n(1), + {"range": [(0, 5)]}, + xr.DataArray( + np.full((1, 10), 10.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.linspace(0, 5.0, 11)[:-1] + 0.25}, + attrs={ + "fs": 10.0, + "occupancy": np.concatenate([np.full(9, 50.0), [60]]), + "bin_edges": [np.linspace(0, 5.0, 11)], + }, + ), + ), + # single unit, multiple features, specified range per feature + ( + get_group_n(1), + get_features_n(2), + {"range": [(0, 5), (0, 10)]}, + xr.DataArray( + np.where(np.eye(10), 10.0, np.nan)[None, :], + dims=["unit", "feature0", "feature1"], + coords={ + "unit": [1], + "feature0": np.linspace(0, 5.0, 11)[:-1] + 0.25, + "feature1": np.linspace(0, 10.0, 11)[:-1] + 0.5, + }, + attrs={ + "fs": 10.0, + "occupancy": np.where( + np.eye(10), 
+ 50.0 + + 10.0 * (np.arange(10) == 9)[:, None] * (np.arange(10) == 9), + np.nan, + ), + "bin_edges": [np.linspace(0, i * 5, 11) for i in range(1, 3)], + }, + ), + ), + # single unit, single feature, specified range and number of bins + ( + get_group_n(1), + get_features_n(1), + {"bins": 10, "range": [(0, 5)]}, + xr.DataArray( + np.full((1, 10), 10.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.linspace(0, 5.0, 11)[:-1] + 0.25}, + attrs={ + "fs": 10.0, + "occupancy": np.concatenate([np.full(9, 50.0), [60]]), + "bin_edges": [np.linspace(0, 5.0, 11)], + }, + ), + ), + # single unit, multiple features, specified range per feature and number of bins + ( + get_group_n(1), + get_features_n(2), + {"bins": 10, "range": [(0, 5), (0, 10)]}, + xr.DataArray( + np.where(np.eye(10), 10.0, np.nan)[None, :], + dims=["unit", "feature0", "feature1"], + coords={ + "unit": [1], + "feature0": np.linspace(0, 5.0, 11)[:-1] + 0.25, + "feature1": np.linspace(0, 10.0, 11)[:-1] + 0.5, + }, + attrs={ + "fs": 10.0, + "occupancy": np.where( + np.eye(10), + 50.0 + + 10.0 * (np.arange(10) == 9)[:, None] * (np.arange(10) == 9), + np.nan, + ), + "bin_edges": [np.linspace(0, i * 5, 11) for i in range(1, 3)], + }, + ), + ), + # single unit, multiple features, specified range and number of bins per feature + ( + get_group_n(1), + get_features_n(2), + {"bins": (10, 10), "range": [(0, 5), (0, 10)]}, + xr.DataArray( + np.where(np.eye(10), 10.0, np.nan)[None, :], + dims=["unit", "feature0", "feature1"], + coords={ + "unit": [1], + "feature0": np.linspace(0, 5.0, 11)[:-1] + 0.25, + "feature1": np.linspace(0, 10.0, 11)[:-1] + 0.5, + }, + attrs={ + "fs": 10.0, + "occupancy": np.where( + np.eye(10), + 50.0 + + 10.0 * (np.arange(10) == 9)[:, None] * (np.arange(10) == 9), + np.nan, + ), + "bin_edges": [np.linspace(0, i * 5, 11) for i in range(1, 3)], + }, + ), + ), + # single unit, single feature, specified epochs (smaller) + ( + get_group_n(1), + get_features_n(1), + {"epochs": nap.IntervalSet([0.0, 50.0])}, + xr.DataArray( + np.full((1, 10), 10.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495}, + attrs={ + "fs": 10.0, + "occupancy": np.concatenate([[51], np.full(9, 50.0)]), + "bin_edges": [np.linspace(0, 9.9, 11)], + }, + ), + ), + # single unit, single feature, specified epochs (larger) + ( + get_group_n(1), + get_features_n(1), + {"epochs": nap.IntervalSet([0.0, 200.0])}, + xr.DataArray( + np.full((1, 10), 10.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495}, + attrs={ + "fs": 10.0, + "occupancy": np.full(10, 100.0), + "bin_edges": [np.linspace(0, 9.9, 11)], + }, + ), + ), + # single unit, single feature, specified epochs (multiple) + ( + get_group_n(1), + get_features_n(1), + {"epochs": nap.IntervalSet([0.0, 50.0], [20.0, 70.0])}, + xr.DataArray( + np.full((1, 10), 10.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495}, + attrs={ + "fs": 10.0, + "occupancy": np.concatenate([[42], np.full(9, 40.0)]), + "bin_edges": [np.linspace(0, 9.9, 11)], + }, + ), + ), + # single unit, single feature, specified feature name + ( + get_group_n(1), + get_features_n(1), + {"feature_names": ["f0"]}, + xr.DataArray( + np.full((1, 10), 10.0), + dims=["unit", "f0"], + coords={"unit": [1], "f0": np.linspace(0, 9.9, 11)[:-1] + 0.495}, + attrs={ + "fs": 10.0, + "occupancy": np.full(10, 100.0), + "bin_edges": [np.linspace(0, 9.9, 11)], + }, + ), + ), + # single unit, 
multiple features, specified feature names + ( + get_group_n(1), + get_features_n(2), + {"feature_names": ["f0", "f1"]}, + xr.DataArray( + np.where(np.eye(10), 10.0, np.nan)[None, :], + dims=["unit", "f0", "f1"], + coords={ + "unit": [1], + "f0": np.linspace(0, 9.9, 11)[:-1] + 0.495, + "f1": np.linspace(0, 19.8, 11)[:-1] + 0.99, + }, + attrs={ + "fs": 10.0, + "occupancy": np.where(np.eye(10), 100.0, np.nan), + "bin_edges": [np.linspace(0, i * 9.9, 11) for i in range(1, 3)], + }, + ), + ), + # single unit, single feature, return_pandas=True + ( + get_group_n(1), + get_features_n(1), + {"return_pandas": True}, + xr.DataArray( + np.full((1, 10), 10.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495}, + attrs={ + "fs": 10.0, + "occupancy": np.full(10, 100.0), + "bin_edges": [np.linspace(0, i * 9.9, 11) for i in range(1, 2)], + }, + ) + .to_pandas() + .T, + ), + # single unit, single feature, return_counts=True + ( + get_group_n(1), + get_features_n(1), + {"return_counts": True}, + xr.DataArray( + np.full((1, 10), 100.0), + dims=["unit", "feature0"], + coords={"unit": [1], "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495}, + attrs={ + "fs": 10.0, + "occupancy": np.full(10, 100.0), + "bin_edges": [np.linspace(0, i * 9.9, 11) for i in range(1, 2)], + }, + ), + ), + # multiple units, multiple features, return_counts=True + ( + get_group_n(2), + get_features_n(2), + {"return_counts": True}, + xr.DataArray( + np.stack( + [ + np.where(np.eye(10), 100.0, 0.0), + np.where(np.eye(10), 10.0, 0.0), + ] + ), + dims=["unit", "feature0", "feature1"], + coords={ + "unit": [1, 2], + "feature0": np.linspace(0, 9.9, 11)[:-1] + 0.495, + "feature1": np.linspace(0, 19.8, 11)[:-1] + 0.99, + }, + attrs={ + "fs": 10.0, + "occupancy": np.where(np.eye(10), 100.0, np.nan), + "bin_edges": [np.linspace(0, i * 9.9, 11) for i in range(1, 3)], + }, + ), + ), + ], +) +def test_compute_tuning_curves(data, features, kwargs, expectation): + tcs = nap.compute_tuning_curves(data, features, **kwargs) + if isinstance(expectation, pd.DataFrame): + pd.testing.assert_frame_equal(tcs, expectation) + else: + xr.testing.assert_allclose(tcs, expectation) + for attribute in expectation.attrs: + assert attribute in tcs.attrs + if isinstance(expectation.attrs[attribute], (np.ndarray, float)): + print(tcs.attrs[attribute]) + np.testing.assert_array_almost_equal( + tcs.attrs[attribute], expectation.attrs[attribute] + ) + else: + for i in range(len(expectation.attrs[attribute])): + np.testing.assert_array_almost_equal( + tcs.attrs[attribute][i], expectation.attrs[attribute][i] + ) + + +# ------------------------------------------------------------------------------------ +# DISCRETE TUNING CURVE TESTS +# ------------------------------------------------------------------------------------ @pytest.mark.parametrize( - "group, dict_ep, expected_exception", + "data, epochs_dict, kwargs, expectation", [ + # data ( - "a", - { - 0: nap.IntervalSet(start=0, end=50), - 1: nap.IntervalSet(start=50, end=100), - }, - pytest.raises(TypeError, match="group should be a TsGroup."), + [1], + {"0": nap.IntervalSet(0, 100)}, + {}, + pytest.raises( + TypeError, match="data should be a TsdFrame, TsGroup, Ts, or Tsd." + ), ), ( - get_group(), - "a", + None, + {"0": nap.IntervalSet(0, 100)}, + {}, pytest.raises( - TypeError, match="dict_ep should be a dictionary of IntervalSet" + TypeError, match="data should be a TsdFrame, TsGroup, Ts, or Tsd." 
), ), ( - get_group(), - {0: "a", 1: nap.IntervalSet(start=50, end=100)}, + {1: nap.Ts([1, 2, 3])}, + {"0": nap.IntervalSet(0, 100)}, + {}, pytest.raises( - TypeError, match="dict_ep argument should contain only IntervalSet." + TypeError, match="data should be a TsdFrame, TsGroup, Ts, or Tsd." + ), + ), + (get_group_n(1), {"0": nap.IntervalSet(0, 100)}, {}, does_not_raise()), + (get_group_n(3), {"0": nap.IntervalSet(0, 100)}, {}, does_not_raise()), + ( + get_group_n(1).count(0.1), + {"0": nap.IntervalSet(0, 100)}, + {}, + does_not_raise(), + ), + ( + get_group_n(3).count(0.1), + {"0": nap.IntervalSet(0, 100)}, + {}, + does_not_raise(), + ), + ( + nap.Tsd(t=[1, 2, 3], d=[1, 1, 1]), + {"0": nap.IntervalSet(0, 100)}, + {}, + does_not_raise(), + ), + (nap.Ts([1, 2, 3]), {"0": nap.IntervalSet(0, 100)}, {}, does_not_raise()), + # epochs_dict + ( + get_group_n(1), + 1, + {}, + pytest.raises( + TypeError, match="epochs_dict should be a dictionary of IntervalSets." + ), + ), + ( + get_group_n(1), + None, + {}, + pytest.raises( + TypeError, match="epochs_dict should be a dictionary of IntervalSets." + ), + ), + ( + get_group_n(1), + nap.IntervalSet(0, 100), + {}, + pytest.raises( + TypeError, match="epochs_dict should be a dictionary of IntervalSets." + ), + ), + ( + get_group_n(1), + [nap.IntervalSet(0, 100)], + {}, + pytest.raises( + TypeError, match="epochs_dict should be a dictionary of IntervalSets." + ), + ), + ( + get_group_n(1), + {"0": nap.IntervalSet(0, 100), "1": 0}, + {}, + pytest.raises( + TypeError, match="epochs_dict should be a dictionary of IntervalSets." ), ), + ( + get_group_n(1), + {"0": nap.IntervalSet(0, 100)}, + {}, + does_not_raise(), + ), + ( + get_group_n(1), + {"0": nap.IntervalSet(0, 100), "1": nap.IntervalSet(0, 50)}, + {}, + does_not_raise(), + ), + # return pandas + ( + get_group_n(1), + {"0": nap.IntervalSet(0, 100)}, + {"return_pandas": 2}, + pytest.raises( + TypeError, + match="return_pandas should be a boolean.", + ), + ), + ( + get_group_n(1), + {"0": nap.IntervalSet(0, 100)}, + {"return_pandas": "1"}, + pytest.raises( + TypeError, + match="return_pandas should be a boolean.", + ), + ), + ( + get_group_n(1), + {"0": nap.IntervalSet(0, 100)}, + {"return_pandas": True}, + does_not_raise(), + ), + ( + get_group_n(1), + {"0": nap.IntervalSet(0, 100)}, + {"return_pandas": False}, + does_not_raise(), + ), + ( + get_group_n(1), + {"0": nap.IntervalSet(0, 100)}, + {"return_pandas": 0}, + does_not_raise(), + ), + ( + get_group_n(1), + {"0": nap.IntervalSet(0, 100)}, + {"return_pandas": 1}, + does_not_raise(), + ), ], ) -def test_compute_discrete_tuning_curves_errors(group, dict_ep, expected_exception): - with expected_exception: - nap.compute_discrete_tuning_curves(group, dict_ep) +def test_compute_response_per_epoch_type_errors(data, epochs_dict, kwargs, expectation): + with expectation: + nap.compute_response_per_epoch(data, epochs_dict, **kwargs) @pytest.mark.parametrize( - "group, feature, nb_bins, ep, minmax, expected_exception", + "data, epochs_dict, kwargs, expectation", [ - ("a", get_feature(), 10, get_ep(), (0, 1), "group should be a TsGroup."), + # single rate unit, single epoch ( - get_group(), - "a", - 10, - get_ep(), - (0, 1), - r"feature should be a Tsd \(or TsdFrame with 1 column only\)", + get_group_n(1).count(1.0), + {"0": nap.IntervalSet(0, 50)}, + {}, + xr.DataArray( + [[10.0]], + dims=["unit", "epochs"], + coords={"unit": [1], "epochs": ["0"]}, + ), ), + # two rate units, single epoch ( - get_group(), - get_feature(), - "a", - get_ep(), - (0, 1), - 
r"nb_bins should be of type int \(or tuple with \(int, int\) for 2D tuning curves\).", + get_group_n(2).count(1.0), + {"0": nap.IntervalSet(0, 50)}, + {}, + xr.DataArray( + [[10.0], [1.0]], + dims=["unit", "epochs"], + coords={"unit": [1, 2], "epochs": ["0"]}, + ), ), - (get_group(), get_feature(), 10, "a", (0, 1), r"ep should be an IntervalSet"), + # two rate units, multiple epochs ( - get_group(), - get_feature(), - 10, - get_ep(), - 1, - r"minmax should be a tuple\/list of 2 numbers", + get_group_n(2).count(1.0), + {"0": nap.IntervalSet(0, 50), "1": nap.IntervalSet(50, 100)}, + {}, + xr.DataArray( + [[10.0, 10.0], [1.0, 1.0]], + dims=["unit", "epochs"], + coords={"unit": [1, 2], "epochs": ["0", "1"]}, + ), + ), + # two rate units, multiple epochs, overlapping + ( + get_group_n(2).count(1.0), + {"0": nap.IntervalSet(0, 100), "1": nap.IntervalSet(50, 100)}, + {}, + xr.DataArray( + [[10.0, 10.0], [1.0, 1.0]], + dims=["unit", "epochs"], + coords={"unit": [1, 2], "epochs": ["0", "1"]}, + ), + ), + # two rate units, multiple epochs, multiple intervals + ( + get_group_n(2).count(1.0), + { + "0": nap.IntervalSet([0, 20], [10, 30]), + "1": nap.IntervalSet([50, 70], [60, 80]), + }, + {}, + xr.DataArray( + [[10.0, 10.0], [1.0, 1.0]], + dims=["unit", "epochs"], + coords={"unit": [1, 2], "epochs": ["0", "1"]}, + ), + ), + # single unit, single epoch + ( + get_group_n(1), + {"0": nap.IntervalSet(50, 100)}, + {}, + xr.DataArray( + [[10.0]], + dims=["unit", "epochs"], + coords={"unit": [1], "epochs": ["0"]}, + ), + ), + # two units, single epoch + ( + get_group_n(2), + {"0": nap.IntervalSet(0, 100)}, + {}, + xr.DataArray( + [[10.0], [1.0]], + dims=["unit", "epochs"], + coords={"unit": [1, 2], "epochs": ["0"]}, + ), + ), + # two units, multiple epochs + ( + get_group_n(2), + {"0": nap.IntervalSet(0, 49.9999), "1": nap.IntervalSet(50, 100)}, + {}, + xr.DataArray( + [[10.0, 10.0], [1.0, 1.0]], + dims=["unit", "epochs"], + coords={"unit": [1, 2], "epochs": ["0", "1"]}, + ), + ), + # two units, multiple epochs, overlapping + ( + get_group_n(2), + {"0": nap.IntervalSet(0, 100), "1": nap.IntervalSet(50, 100)}, + {}, + xr.DataArray( + [[10.0, 10.0], [1.0, 1.0]], + dims=["unit", "epochs"], + coords={"unit": [1, 2], "epochs": ["0", "1"]}, + ), + ), + # two units, multiple epochs, multiple intervals + ( + get_group_n(2), + { + "0": nap.IntervalSet([0, 20], [10, 30]), + "1": nap.IntervalSet([50, 70], [60, 80]), + }, + {}, + xr.DataArray( + [[10.1, 10.1], [1.1, 1.1]], + dims=["unit", "epochs"], + coords={"unit": [1, 2], "epochs": ["0", "1"]}, + ), + ), + # two units, multiple epochs, return_pandas=True + ( + get_group_n(2), + {"0": nap.IntervalSet(0, 100), "1": nap.IntervalSet(50, 100)}, + {"return_pandas": True}, + xr.DataArray( + [[10.0, 10.0], [1.0, 1.0]], + dims=["unit", "epochs"], + coords={"unit": [1, 2], "epochs": ["0", "1"]}, + ) + .to_pandas() + .T, ), ], ) -def test_compute_1d_tuning_curves_errors( - group, feature, nb_bins, ep, minmax, expected_exception -): - with pytest.raises(TypeError, match=expected_exception): - nap.compute_1d_tuning_curves(group, feature, nb_bins, ep, minmax) +def test_compute_response_per_epoch(data, epochs_dict, kwargs, expectation): + tcs = nap.compute_response_per_epoch(data, epochs_dict, **kwargs) + if isinstance(expectation, pd.DataFrame): + pd.testing.assert_frame_equal(tcs, expectation) + else: + xr.testing.assert_allclose(tcs, expectation) + + +# ------------------------------------------------------------------------------------ +# MUTUAL INFORMATION TESTS +# 
------------------------------------------------------------------------------------ + + +def get_testing_set(n_units=1, n_features=1, pattern="uniform"): + dims = ["unit"] + [f"dim_{i}" for i in range(n_features)] + coords = {"unit": np.arange(n_units)} + shape = (n_units,) + (2,) * n_features # 2 bins per feature, for simplicity + for i in range(n_features): + coords[f"dim_{i}"] = np.arange(2) + + # Build tuning curves + data = np.zeros(shape) + + if pattern == "uniform": + data[:] = 1.0 + expected_mi_per_sec = 0.0 + expected_mi_per_spike = 0.0 + + elif pattern == "onehot": + # Each unit fires in a unique location only + for u in range(n_units): + index = [u] + [0] * n_features + data[tuple(index)] = 1.0 + + n_bins = np.prod(shape[1:]) + expected_mi_per_spike = np.log2(n_bins) + mean_rate = 1.0 / n_bins + expected_mi_per_sec = mean_rate * expected_mi_per_spike + + else: + raise ValueError("Unknown firing_pattern. Use 'uniform' or 'onehot'.") + + tuning_curves = xr.DataArray( + data, + coords=coords, + dims=dims, + attrs={"occupancy": np.ones(shape[1:]) / np.prod(shape[1:])}, + ) + + MI = pd.DataFrame( + data=np.stack( + [ + np.full(n_units, expected_mi_per_sec), + np.full(n_units, expected_mi_per_spike), + ], + axis=1, + ), + index=coords["unit"], + columns=["bits/sec", "bits/spike"], + ) + + return tuning_curves, MI @pytest.mark.parametrize( - "group, features, nb_bins, ep, minmax, expected_exception", + "tuning_curves, expectation", [ - ("a", get_features(), 10, get_ep(), (0, 1), "group should be a TsGroup."), ( - get_group(), - "a", - 10, - get_ep(), - (0, 1), - r"features should be a TsdFrame with 2 columns", + [], + pytest.raises( + TypeError, + match="tuning_curves should be an xr.DataArray as computed by compute_tuning_curves.", + ), ), ( - get_group(), - get_features(), - "a", - get_ep(), - (0, 1), - r"nb_bins should be of type int \(or tuple with \(int, int\) for 2D tuning curves\).", + 1, + pytest.raises( + TypeError, + match="tuning_curves should be an xr.DataArray as computed by compute_tuning_curves.", + ), ), - (get_group(), get_features(), 10, "a", (0, 1), r"ep should be an IntervalSet"), ( - get_group(), - get_features(), - 10, - get_ep(), - 1, - r"minmax should be a tuple\/list of 2 numbers", + get_testing_set()[0].to_pandas().T, + pytest.raises( + TypeError, + match="tuning_curves should be an xr.DataArray as computed by compute_tuning_curves.", + ), + ), + ( + (lambda x: (x.attrs.clear(), x)[1])(get_testing_set()[0]), + pytest.raises( + ValueError, + match="No occupancy found in tuning curves.", + ), ), + (get_testing_set(1, 2)[0], does_not_raise()), + (get_testing_set(1, 3)[0], does_not_raise()), + (get_testing_set(2, 1)[0], does_not_raise()), + (get_testing_set(2, 2)[0], does_not_raise()), + (get_testing_set(2, 3)[0], does_not_raise()), ], ) -def test_compute_2d_tuning_curves_errors( - group, features, nb_bins, ep, minmax, expected_exception -): - with pytest.raises(TypeError, match=expected_exception): - nap.compute_2d_tuning_curves(group, features, nb_bins, ep, minmax) +def test_compute_mutual_information_errors(tuning_curves, expectation): + with expectation: + nap.compute_mutual_information(tuning_curves) + + +@pytest.mark.parametrize( + "n_units, n_features", + [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)], +) +@pytest.mark.parametrize( + "pattern", + ["uniform", "onehot"], +) +def test_compute_mutual_information(n_units, n_features, pattern): + tuning_curves, expectation = get_testing_set(n_units, n_features, pattern) + actual = 
nap.compute_mutual_information(tuning_curves) + pd.testing.assert_frame_equal(actual, expectation) + + +# ------------------------------------------------------------------------------------ +# OLD MUTUAL INFORMATION TESTS +# ------------------------------------------------------------------------------------ + + +def get_group(): + return nap.TsGroup({0: nap.Ts(t=np.arange(0, 100))}) + + +def get_feature(): + return nap.Tsd( + t=np.arange(0, 100, 0.1), + d=np.arange(0, 100, 0.1) % 1.0, + time_support=nap.IntervalSet(0, 100), + ) + + +def get_features(): + tmp = np.vstack( + (np.repeat(np.arange(0, 100), 10), np.tile(np.arange(0, 100), 10)) + ).T + return nap.TsdFrame( + t=np.arange(0, 200, 0.1), + d=np.vstack((tmp, tmp[::-1])), + time_support=nap.IntervalSet(0, 200), + ) + + +def get_ep(): + return nap.IntervalSet(start=0, end=50) + + +def get_tsdframe(): + return nap.TsdFrame(t=np.arange(0, 100), d=np.ones((100, 2))) @pytest.mark.parametrize( @@ -248,26 +1411,207 @@ def test_compute_1d_mutual_info_errors( ), ], ) -def test_compute_2d_mutual_info_errors( - dict_tc, features, ep, minmax, bitssec, expected_exception -): - with pytest.raises(TypeError, match=expected_exception): - nap.compute_2d_mutual_info(dict_tc, features, ep, minmax, bitssec) +def test_compute_2d_mutual_info_errors( + dict_tc, features, ep, minmax, bitssec, expected_exception +): + with pytest.raises(TypeError, match=expected_exception): + nap.compute_2d_mutual_info(dict_tc, features, ep, minmax, bitssec) + + +@pytest.mark.filterwarnings("ignore") +@pytest.mark.parametrize( + "args, kwargs, expected", + [ + ( + ( + pd.DataFrame(index=np.arange(0, 2), data=np.array([0, 10])), + nap.Tsd(t=np.arange(100), d=np.tile(np.arange(2), 50)), + ), + {}, + np.array([[1.0]]), + ), + ( + ( + pd.DataFrame(index=np.arange(0, 2), data=np.array([0, 10])), + nap.Tsd(t=np.arange(100), d=np.tile(np.arange(2), 50)), + ), + {"bitssec": True}, + np.array([[5.0]]), + ), + ( + ( + pd.DataFrame(index=np.arange(0, 2), data=np.array([0, 10])), + nap.Tsd(t=np.arange(100), d=np.tile(np.arange(2), 50)), + ), + {"ep": nap.IntervalSet(start=0, end=49)}, + np.array([[1.0]]), + ), + ( + ( + pd.DataFrame(index=np.arange(0, 2), data=np.array([0, 10])), + nap.Tsd(t=np.arange(100), d=np.tile(np.arange(2), 50)), + ), + {"minmax": (0, 1)}, + np.array([[1.0]]), + ), + ( + ( + np.array([[0], [10]]), + nap.Tsd(t=np.arange(100), d=np.tile(np.arange(2), 50)), + ), + {"minmax": (0, 1)}, + np.array([[1.0]]), + ), + ], +) +def test_compute_1d_mutual_info(args, kwargs, expected): + tc = args[0] + feature = args[1] + si = nap.compute_1d_mutual_info(tc, feature, **kwargs) + assert isinstance(si, pd.DataFrame) + assert list(si.columns) == ["SI"] + if isinstance(tc, pd.DataFrame): + assert list(si.index.values) == list(tc.columns) + np.testing.assert_approx_equal(si.values, expected) + + +@pytest.mark.filterwarnings("ignore") +@pytest.mark.parametrize( + "args, kwargs, expected", + [ + ( + ( + {0: np.array([[0, 1], [0, 0]])}, + nap.TsdFrame( + t=np.arange(100), + d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T, + ), + ), + {}, + np.array([[2.0]]), + ), + ( + ( + np.array([[[0, 1], [0, 0]]]), + nap.TsdFrame( + t=np.arange(100), + d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T, + ), + ), + {}, + np.array([[2.0]]), + ), + ( + ( + {0: np.array([[0, 1], [0, 0]])}, + nap.TsdFrame( + t=np.arange(100), + d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T, + ), + ), + {"bitssec": True}, + np.array([[0.5]]), + ), + ( + ( + {0: np.array([[0, 1], [0, 0]])}, + 
nap.TsdFrame( + t=np.arange(100), + d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T, + ), + ), + {"ep": nap.IntervalSet(start=0, end=7)}, + np.array([[2.0]]), + ), + ( + ( + {0: np.array([[0, 1], [0, 0]])}, + nap.TsdFrame( + t=np.arange(100), + d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T, + ), + ), + {"minmax": (0, 1, 0, 1)}, + np.array([[2.0]]), + ), + ], +) +def test_compute_2d_mutual_info(args, kwargs, expected): + dict_tc = args[0] + features = args[1] + si = nap.compute_2d_mutual_info(dict_tc, features, **kwargs) + assert isinstance(si, pd.DataFrame) + assert list(si.columns) == ["SI"] + if isinstance(dict_tc, dict): + assert list(si.index.values) == list(dict_tc.keys()) + np.testing.assert_approx_equal(si.values, expected) + + +# ------------------------------------------------------------------------------------ +# OLD TUNING CURVE TESTS +# ------------------------------------------------------------------------------------ + + +@pytest.mark.filterwarnings("ignore") +@pytest.mark.parametrize( + "group, dict_ep, expectation", + [ + ( + "a", + { + 0: nap.IntervalSet(start=0, end=50), + 1: nap.IntervalSet(start=50, end=100), + }, + pytest.raises(TypeError, match="group should be a TsGroup."), + ), + ( + get_group(), + "a", + pytest.raises( + TypeError, match="dict_ep should be a dictionary of IntervalSet" + ), + ), + ( + get_group(), + {0: "a", 1: nap.IntervalSet(start=50, end=100)}, + pytest.raises( + TypeError, match="dict_ep argument should contain only IntervalSet." + ), + ), + ], +) +def test_compute_discrete_tuning_curves_errors(group, dict_ep, expectation): + with expectation: + nap.compute_discrete_tuning_curves(group, dict_ep) + + +@pytest.mark.filterwarnings("ignore") +@pytest.mark.parametrize("group", [get_group()]) +@pytest.mark.parametrize( + "dict_ep", + [ + {0: nap.IntervalSet(start=0, end=50), 1: nap.IntervalSet(start=50, end=100)}, + { + "0": nap.IntervalSet(start=0, end=50), + "1": nap.IntervalSet(start=50, end=100), + }, + ], +) +def test_compute_discrete_tuning_curves(group, dict_ep): + tc = nap.compute_discrete_tuning_curves(group, dict_ep) + assert len(tc) == 2 + assert list(tc.columns) == list(group.keys()) + assert list(tc.index.values) == list(dict_ep.keys()) + np.testing.assert_almost_equal(tc.iloc[0, 0], 51 / 50) + np.testing.assert_almost_equal(tc.iloc[1, 0], 1) @pytest.mark.parametrize( - "tsdframe, feature, nb_bins, ep, minmax, expected_exception", + "group, feature, nb_bins, ep, minmax, expected_exception", [ + ("a", get_feature(), 10, get_ep(), (0, 1), "group should be a TsGroup."), ( - "a", - get_feature(), - 10, - get_ep(), - (0, 1), - "Argument tsdframe should be of type Tsd or TsdFrame.", - ), - ( - get_tsdframe(), + get_group(), "a", 10, get_ep(), @@ -275,23 +1619,16 @@ def test_compute_2d_mutual_info_errors( r"feature should be a Tsd \(or TsdFrame with 1 column only\)", ), ( - get_tsdframe(), + get_group(), get_feature(), "a", get_ep(), (0, 1), r"nb_bins should be of type int \(or tuple with \(int, int\) for 2D tuning curves\).", ), + (get_group(), get_feature(), 10, "a", (0, 1), r"ep should be an IntervalSet"), ( - get_tsdframe(), - get_feature(), - 10, - "a", - (0, 1), - r"ep should be an IntervalSet", - ), - ( - get_tsdframe(), + get_group(), get_feature(), 10, get_ep(), @@ -300,26 +1637,19 @@ def test_compute_2d_mutual_info_errors( ), ], ) -def test_compute_1d_tuning_curves_continuous_errors( - tsdframe, feature, nb_bins, ep, minmax, expected_exception +def test_compute_1d_tuning_curves_errors( + group, feature, nb_bins, ep, 
minmax, expected_exception ): with pytest.raises(TypeError, match=expected_exception): - nap.compute_1d_tuning_curves_continuous(tsdframe, feature, nb_bins, ep, minmax) + nap.compute_1d_tuning_curves(group, feature, nb_bins, ep, minmax) @pytest.mark.parametrize( - "tsdframe, features, nb_bins, ep, minmax, expected_exception", + "group, features, nb_bins, ep, minmax, expected_exception", [ + ("a", get_features(), 10, get_ep(), (0, 1), "group should be a TsGroup."), ( - "a", - get_features(), - 10, - get_ep(), - (0, 1), - "Argument tsdframe should be of type Tsd or TsdFrame.", - ), - ( - get_tsdframe(), + get_group(), "a", 10, get_ep(), @@ -327,23 +1657,16 @@ def test_compute_1d_tuning_curves_continuous_errors( r"features should be a TsdFrame with 2 columns", ), ( - get_tsdframe(), + get_group(), get_features(), "a", get_ep(), (0, 1), r"nb_bins should be of type int \(or tuple with \(int, int\) for 2D tuning curves\).", ), + (get_group(), get_features(), 10, "a", (0, 1), r"ep should be an IntervalSet"), ( - get_tsdframe(), - get_features(), - 10, - "a", - (0, 1), - r"ep should be an IntervalSet", - ), - ( - get_tsdframe(), + get_group(), get_features(), 10, get_ep(), @@ -352,100 +1675,16 @@ def test_compute_1d_tuning_curves_continuous_errors( ), ], ) -def test_compute_2d_tuning_curves_continuous_errors( - tsdframe, features, nb_bins, ep, minmax, expected_exception +def test_compute_2d_tuning_curves_errors( + group, features, nb_bins, ep, minmax, expected_exception ): with pytest.raises(TypeError, match=expected_exception): - nap.compute_2d_tuning_curves_continuous(tsdframe, features, nb_bins, ep, minmax) - - -######################## -# ValueError test -######################## -@pytest.mark.parametrize( - "func, args, minmax, expected", - [ - ( - nap.compute_1d_tuning_curves, - (get_group(), get_feature(), 10), - (0, 1, 2), - "minmax should be of length 2.", - ), - ( - nap.compute_1d_tuning_curves_continuous, - (get_tsdframe(), get_feature(), 10), - (0, 1, 2), - "minmax should be of length 2.", - ), - ( - nap.compute_2d_tuning_curves, - (get_group(), get_features(), 10), - (0, 1, 2), - "minmax should be of length 4.", - ), - ( - nap.compute_2d_tuning_curves_continuous, - (get_tsdframe(), get_features(), 10), - (0, 1, 2), - "minmax should be of length 4.", - ), - ( - nap.compute_2d_tuning_curves, - (get_group(), nap.TsdFrame(t=np.arange(10), d=np.ones((10, 3))), 10), - (0, 1), - "features should have 2 columns only.", - ), - ( - nap.compute_1d_tuning_curves, - (get_group(), nap.TsdFrame(t=np.arange(10), d=np.ones((10, 3))), 10), - (0, 1), - r"feature should be a Tsd \(or TsdFrame with 1 column only\)", - ), - ( - nap.compute_2d_tuning_curves, - (get_group(), get_features(), (0, 1, 2)), - (0, 1, 2, 3), - r"nb_bins should be of type int \(or tuple with \(int, int\) for 2D tuning curves\).", - ), - ( - nap.compute_2d_tuning_curves_continuous, - (get_tsdframe(), get_features(), (0, 1, 2)), - (0, 1, 2, 3), - r"nb_bins should be of type int \(or tuple with \(int, int\) for 2D tuning curves\).", - ), - ], -) -def test_compute_tuning_curves_value_error(func, args, minmax, expected): - with pytest.raises(ValueError, match=expected): - func(*args, minmax=minmax) - - -######################## -# Normal test -######################## -@pytest.mark.parametrize("group", [get_group()]) -@pytest.mark.parametrize( - "dict_ep", - [ - {0: nap.IntervalSet(start=0, end=50), 1: nap.IntervalSet(start=50, end=100)}, - { - "0": nap.IntervalSet(start=0, end=50), - "1": nap.IntervalSet(start=50, end=100), - }, - ], -) 
-def test_compute_discrete_tuning_curves(group, dict_ep): - tc = nap.compute_discrete_tuning_curves(group, dict_ep) - assert len(tc) == 2 - assert list(tc.columns) == list(group.keys()) - assert list(tc.index.values) == list(dict_ep.keys()) - np.testing.assert_almost_equal(tc.iloc[0, 0], 51 / 50) - np.testing.assert_almost_equal(tc.iloc[1, 0], 1) + nap.compute_2d_tuning_curves(group, features, nb_bins, ep, minmax) @pytest.mark.filterwarnings("ignore") @pytest.mark.parametrize( - "args, kwargs, expected", + "args, kwargs, expectation", [ ((get_group(), get_feature(), 10), {}, np.array([10.0] + [0.0] * 9)[:, None]), ( @@ -465,7 +1704,7 @@ def test_compute_discrete_tuning_curves(group, dict_ep): ), ], ) -def test_compute_1d_tuning_curves(args, kwargs, expected): +def test_compute_1d_tuning_curves(args, kwargs, expectation): tc = nap.compute_1d_tuning_curves(*args, **kwargs) # Columns assert list(tc.columns) == list(args[0].keys()) @@ -479,19 +1718,19 @@ def test_compute_1d_tuning_curves(args, kwargs, expected): np.testing.assert_almost_equal(tmp[0:-1] + np.diff(tmp) / 2, tc.index.values) # Array - np.testing.assert_almost_equal(tc.values, expected) + np.testing.assert_almost_equal(tc.values, expectation) @pytest.mark.filterwarnings("ignore") @pytest.mark.parametrize( - "args, kwargs, expected", + "args, kwargs, expectation", [ ((get_group(), get_features(), 10), {}, np.ones((10, 10)) * 0.5), ((get_group(), get_features(), (10, 10)), {}, np.ones((10, 10)) * 0.5), ( (get_group(), get_features(), 10), {"ep": nap.IntervalSet(0, 400)}, - np.ones((10, 10)) * 0.25, + np.ones((10, 10)) * 0.5, ), ( (get_group(), get_features(), 10), @@ -505,7 +1744,7 @@ def test_compute_1d_tuning_curves(args, kwargs, expected): ), ], ) -def test_compute_2d_tuning_curves(args, kwargs, expected): +def test_compute_2d_tuning_curves(args, kwargs, expectation): tc, xy = nap.compute_2d_tuning_curves(*args, **kwargs) assert isinstance(tc, dict) @@ -531,141 +1770,12 @@ def test_compute_2d_tuning_curves(args, kwargs, expected): # Values for i in tc.keys(): assert tc[i].shape == nb_bins - np.testing.assert_almost_equal(tc[i], expected) - - -@pytest.mark.filterwarnings("ignore") -@pytest.mark.parametrize( - "args, kwargs, expected", - [ - ( - ( - pd.DataFrame(index=np.arange(0, 2), data=np.array([0, 10])), - nap.Tsd(t=np.arange(100), d=np.tile(np.arange(2), 50)), - ), - {}, - np.array([[1.0]]), - ), - ( - ( - pd.DataFrame(index=np.arange(0, 2), data=np.array([0, 10])), - nap.Tsd(t=np.arange(100), d=np.tile(np.arange(2), 50)), - ), - {"bitssec": True}, - np.array([[5.0]]), - ), - ( - ( - pd.DataFrame(index=np.arange(0, 2), data=np.array([0, 10])), - nap.Tsd(t=np.arange(100), d=np.tile(np.arange(2), 50)), - ), - {"ep": nap.IntervalSet(start=0, end=49)}, - np.array([[1.0]]), - ), - ( - ( - pd.DataFrame(index=np.arange(0, 2), data=np.array([0, 10])), - nap.Tsd(t=np.arange(100), d=np.tile(np.arange(2), 50)), - ), - {"minmax": (0, 1)}, - np.array([[1.0]]), - ), - ( - ( - np.array([[0], [10]]), - nap.Tsd(t=np.arange(100), d=np.tile(np.arange(2), 50)), - ), - {"minmax": (0, 1)}, - np.array([[1.0]]), - ), - ], -) -def test_compute_1d_mutual_info(args, kwargs, expected): - tc = args[0] - feature = args[1] - si = nap.compute_1d_mutual_info(tc, feature, **kwargs) - assert isinstance(si, pd.DataFrame) - assert list(si.columns) == ["SI"] - if isinstance(tc, pd.DataFrame): - assert list(si.index.values) == list(tc.columns) - np.testing.assert_approx_equal(si.values, expected) - - -@pytest.mark.filterwarnings("ignore") 
-@pytest.mark.parametrize( - "args, kwargs, expected", - [ - ( - ( - {0: np.array([[0, 1], [0, 0]])}, - nap.TsdFrame( - t=np.arange(100), - d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T, - ), - ), - {}, - np.array([[2.0]]), - ), - ( - ( - np.array([[[0, 1], [0, 0]]]), - nap.TsdFrame( - t=np.arange(100), - d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T, - ), - ), - {}, - np.array([[2.0]]), - ), - ( - ( - {0: np.array([[0, 1], [0, 0]])}, - nap.TsdFrame( - t=np.arange(100), - d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T, - ), - ), - {"bitssec": True}, - np.array([[0.5]]), - ), - ( - ( - {0: np.array([[0, 1], [0, 0]])}, - nap.TsdFrame( - t=np.arange(100), - d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T, - ), - ), - {"ep": nap.IntervalSet(start=0, end=7)}, - np.array([[2.0]]), - ), - ( - ( - {0: np.array([[0, 1], [0, 0]])}, - nap.TsdFrame( - t=np.arange(100), - d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T, - ), - ), - {"minmax": (0, 1, 0, 1)}, - np.array([[2.0]]), - ), - ], -) -def test_compute_2d_mutual_info(args, kwargs, expected): - dict_tc = args[0] - features = args[1] - si = nap.compute_2d_mutual_info(dict_tc, features, **kwargs) - assert isinstance(si, pd.DataFrame) - assert list(si.columns) == ["SI"] - if isinstance(dict_tc, dict): - assert list(si.index.values) == list(dict_tc.keys()) - np.testing.assert_approx_equal(si.values, expected) + np.testing.assert_almost_equal(tc[i], expectation) @pytest.mark.filterwarnings("ignore") @pytest.mark.parametrize( - "args, kwargs, expected", + "args, kwargs, expectation", [ ( (get_tsdframe(), get_feature(), 10), @@ -699,7 +1809,7 @@ def test_compute_2d_mutual_info(args, kwargs, expected): ), ], ) -def test_compute_1d_tuning_curves_continuous(args, kwargs, expected): +def test_compute_1d_tuning_curves_continuous(args, kwargs, expectation): tsdframe, feature, nb_bins = args tc = nap.compute_1d_tuning_curves_continuous(tsdframe, feature, nb_bins, **kwargs) # Columns @@ -710,14 +1820,15 @@ def test_compute_1d_tuning_curves_continuous(args, kwargs, expected): if "minmax" in kwargs: tmp = np.linspace(kwargs["minmax"][0], kwargs["minmax"][1], nb_bins + 1) else: - tmp = np.linspace(np.min(args[1]), np.max(args[1]), nb_bins + 1) + tmp = np.linspace(np.min(feature), np.max(feature), nb_bins + 1) np.testing.assert_almost_equal(tmp[0:-1] + np.diff(tmp) / 2, tc.index.values) # Array - np.testing.assert_almost_equal(tc.values, expected) + np.testing.assert_almost_equal(tc.values, expectation) +@pytest.mark.filterwarnings("ignore") @pytest.mark.parametrize( - "tsdframe, nb_bins, kwargs, expected", + "tsdframe, nb_bins, kwargs, expectation", [ ( nap.TsdFrame( @@ -736,13 +1847,13 @@ def test_compute_1d_tuning_curves_continuous(args, kwargs, expected): ), 2, {}, - {"x": np.array([[1, 0], [0, 0]]), "y": np.array([[2, 0], [0, 0]])}, + {"x": np.ones((2, 2)), "y": np.ones((2, 2)) * 2}, ), ( nap.Tsd(t=np.arange(0, 100), d=np.hstack((np.ones((100,)) * 2))), 2, {}, - {0: np.array([[2, 0], [0, 0]])}, + {0: np.ones((2, 2)) * 2}, ), ( nap.TsdFrame( @@ -751,7 +1862,7 @@ def test_compute_1d_tuning_curves_continuous(args, kwargs, expected): ), (1, 2), {}, - {0: np.array([[1.0, 0.0]]), 1: np.array([[2.0, 0.0]])}, + {0: np.array([[1.0, 1.0]]), 1: np.array([[2.0, 2.0]])}, ), ( nap.TsdFrame( @@ -782,7 +1893,7 @@ def test_compute_1d_tuning_curves_continuous(args, kwargs, expected): ), ], ) -def test_compute_2d_tuning_curves_continuous(tsdframe, nb_bins, kwargs, expected): +def test_compute_2d_tuning_curves_continuous(tsdframe, 
nb_bins, kwargs, expectation): features = nap.TsdFrame( t=np.arange(100), d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T ) @@ -816,4 +1927,4 @@ def test_compute_2d_tuning_curves_continuous(tsdframe, nb_bins, kwargs, expected # Values for i in tc.keys(): assert tc[i].shape == nb_bins - np.testing.assert_almost_equal(tc[i], expected[i]) + np.testing.assert_almost_equal(tc[i], expectation[i]) diff --git a/tox.ini b/tox.ini index e6ad4d7ad..15e27741d 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,6 @@ [tox] isolated_build = True envlist = py310,py311 -requires = tox-conda [testenv] # means we'll run the equivalent of `pip install .[dev]`, also installing pytest @@ -18,6 +17,10 @@ commands = coverage run --source=pynapple --branch -m pytest tests/ coverage report -m +[testenv:params] +commands= + python scripts/check_parameter_naming.py + [gh-actions] python = 3.10: py310
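
A minimal usage sketch of the new xarray-based tuning-curve API that the tests in this diff exercise. It is not part of the diff itself: the spike times, feature values, and epoch labels below are made up for illustration; only the function names (`compute_tuning_curves`, `compute_mutual_information`, `compute_response_per_epoch`) and the keyword arguments shown come from the tests above.

    import numpy as np
    import pynapple as nap

    # hypothetical data: one unit spiking at 10 Hz and two behavioural features
    spikes = nap.TsGroup({1: nap.Ts(t=np.arange(0, 100, 0.1))})
    features = nap.TsdFrame(
        t=np.arange(0, 100, 0.1),
        d=np.stack([np.arange(1000) % 10, np.arange(1000) % 20], axis=1),
    )

    # n-dimensional tuning curves returned as an xarray DataArray; the tests above
    # expect occupancy and bin_edges to be stored in the DataArray attrs
    tcs = nap.compute_tuning_curves(spikes, features, bins=10)

    # mutual information per unit, reported in bits/sec and bits/spike
    mi = nap.compute_mutual_information(tcs)

    # mean response of each unit within labelled epochs (dict of IntervalSets)
    epochs = {"run": nap.IntervalSet(0, 50), "rest": nap.IntervalSet(50, 100)}
    responses = nap.compute_response_per_epoch(spikes, epochs)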