From 741a0f95acc238bc1a16e51ce6e1d992c4174ca0 Mon Sep 17 00:00:00 2001 From: ershook Date: Fri, 1 May 2026 17:17:31 -0400 Subject: [PATCH] Add column / row wise normalization to imshow and pyrshow --- TESTS/unitTests.py | 146 ++++++++++++++++++++++++- src/pyrtools/tools/display.py | 197 ++++++++++++++++++++++++++++------ 2 files changed, 312 insertions(+), 31 deletions(-) diff --git a/TESTS/unitTests.py b/TESTS/unitTests.py index abc48cb..e029ee1 100755 --- a/TESTS/unitTests.py +++ b/TESTS/unitTests.py @@ -9,6 +9,7 @@ import pyrtools as pt from pyrtools.pyramids.pyramid import Pyramid +from pyrtools.tools.display import colormap_range import scipy.io import os @@ -1485,7 +1486,150 @@ def test_pyrshow_2d_shape_err(self): with self.assertRaises(ValueError): pt.pyrshow(pyr.pyr_coeffs) - +def _get_clims(fig): + """get vmin vmax for each image in fig as list of tuples (vmin, vmax)""" + return [ax.images[0].get_clim() for ax in fig.axes if ax.images] + +# define test images such that each image has a different range of values, +# so we can test that the correct vrange is applied to each one +IMAGES = [np.arange(4 * i, 4 * i + 4, dtype=float).reshape(2, 2) for i in range(4)] + +class TestVrange(unittest.TestCase): + + def tearDown(self): + plt.close("all") + + def _imshow(self, vrange): + return pt.imshow(IMAGES, vrange=vrange, zoom=1, col_wrap=2) + + def _expected_clims(self, vrange): + clims, _ = colormap_range(image=IMAGES, vrange=vrange, cmap=None, n_rows=2, n_cols=2) + return clims + + def test_global_vrange_all_images_share_clim(self): + for mode in range(4): + with self.subTest(mode=mode): + clims = _get_clims(self._imshow(f"auto{mode}")) + self.assertTrue(all(c == clims[0] for c in clims)) + + def test_global_vrange_vmin(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + vmin, _ = _get_clims(self._imshow(f"auto{mode}"))[img_idx] + exp_vmin, _ = self._expected_clims(f"auto{mode}")[img_idx] + self.assertTrue(np.isclose(vmin, exp_vmin, atol=1e-6)) + + def test_global_vrange_vmax(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + _, vmax = _get_clims(self._imshow(f"auto{mode}"))[img_idx] + _, exp_vmax = self._expected_clims(f"auto{mode}")[img_idx] + self.assertTrue(np.isclose(vmax, exp_vmax, atol=1e-6)) + + def test_global_vrange_title_matches_clim(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + fig = self._imshow(f"auto{mode}") + clim_vmin, clim_vmax = _get_clims(fig)[img_idx] + title_vmin, title_vmax = _get_title_clims(fig)[img_idx] + self.assertEqual("{:.1e}".format(clim_vmin), "{:.1e}".format(title_vmin)) + self.assertEqual("{:.1e}".format(clim_vmax), "{:.1e}".format(title_vmax)) + + def test_row_vrange_same_row_shares_clim(self): + for mode in range(4): + with self.subTest(mode=mode): + clims = _get_clims(self._imshow(f"auto{mode}row")) + self.assertEqual(clims[0], clims[1], "row 0 images differ") + self.assertEqual(clims[2], clims[3], "row 1 images differ") + + def test_row_vrange_vmin(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + vmin, _ = _get_clims(self._imshow(f"auto{mode}row"))[img_idx] + exp_vmin, _ = self._expected_clims(f"auto{mode}row")[img_idx] + self.assertTrue(np.isclose(vmin, exp_vmin, atol=1e-6)) + + def test_row_vrange_vmax(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + _, vmax = _get_clims(self._imshow(f"auto{mode}row"))[img_idx] + _, exp_vmax = self._expected_clims(f"auto{mode}row")[img_idx] + self.assertTrue(np.isclose(vmax, exp_vmax, atol=1e-6)) + + def test_row_vrange_title_matches_clim(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + fig = self._imshow(f"auto{mode}row") + clim_vmin, clim_vmax = _get_clims(fig)[img_idx] + title_vmin, title_vmax = _get_title_clims(fig)[img_idx] + self.assertEqual("{:.1e}".format(clim_vmin), "{:.1e}".format(title_vmin)) + self.assertEqual("{:.1e}".format(clim_vmax), "{:.1e}".format(title_vmax)) + + def test_col_vrange_same_col_shares_clim(self): + for mode in range(4): + with self.subTest(mode=mode): + clims = _get_clims(self._imshow(f"auto{mode}col")) + self.assertEqual(clims[0], clims[2], "col 0 images differ") + self.assertEqual(clims[1], clims[3], "col 1 images differ") + + def test_col_vrange_vmin(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + vmin, _ = _get_clims(self._imshow(f"auto{mode}col"))[img_idx] + exp_vmin, _ = self._expected_clims(f"auto{mode}col")[img_idx] + self.assertTrue(np.isclose(vmin, exp_vmin, atol=1e-6)) + + def test_col_vrange_vmax(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + _, vmax = _get_clims(self._imshow(f"auto{mode}col"))[img_idx] + _, exp_vmax = self._expected_clims(f"auto{mode}col")[img_idx] + self.assertTrue(np.isclose(vmax, exp_vmax, atol=1e-6)) + + def test_col_vrange_title_matches_clim(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + fig = self._imshow(f"auto{mode}col") + clim_vmin, clim_vmax = _get_clims(fig)[img_idx] + title_vmin, title_vmax = _get_title_clims(fig)[img_idx] + self.assertEqual("{:.1e}".format(clim_vmin), "{:.1e}".format(title_vmin)) + self.assertEqual("{:.1e}".format(clim_vmax), "{:.1e}".format(title_vmax)) + + def test_indep_vrange_vmin(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + vmin, _ = _get_clims(self._imshow(f"indep{mode}"))[img_idx] + exp_vmin, _ = self._expected_clims(f"indep{mode}")[img_idx] + self.assertTrue(np.isclose(vmin, exp_vmin, atol=1e-6)) + + def test_indep_vrange_vmax(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + _, vmax = _get_clims(self._imshow(f"indep{mode}"))[img_idx] + _, exp_vmax = self._expected_clims(f"indep{mode}")[img_idx] + self.assertTrue(np.isclose(vmax, exp_vmax, atol=1e-6)) + + def test_indep_vrange_title_matches_clim(self): + for mode in range(4): + for img_idx in range(4): + with self.subTest(mode=mode, img_idx=img_idx): + fig = self._imshow(f"indep{mode}") + clim_vmin, clim_vmax = _get_clims(fig)[img_idx] + title_vmin, title_vmax = _get_title_clims(fig)[img_idx] + self.assertEqual("{:.1e}".format(clim_vmin), "{:.1e}".format(title_vmin)) + self.assertEqual("{:.1e}".format(clim_vmax), "{:.1e}".format(title_vmax)) + def main(): unittest.main() diff --git a/src/pyrtools/tools/display.py b/src/pyrtools/tools/display.py index 39b8daa..16b2330 100644 --- a/src/pyrtools/tools/display.py +++ b/src/pyrtools/tools/display.py @@ -1,3 +1,4 @@ +import math import warnings import numpy as np import matplotlib.pyplot as plt @@ -225,7 +226,7 @@ def reshape_axis(ax, axis_size_pix): return ax -def colormap_range(image, vrange='indep1', cmap=None): +def colormap_range(image, vrange= 'indep1', cmap=None, n_rows = None, n_cols = None): """Find the appropriate ranges for colormaps of provided images Arguments @@ -236,10 +237,13 @@ def colormap_range(image, vrange='indep1', cmap=None): dimension), or list of 2d arrays. all images will be automatically rescaled so they're displayed at the same size. thus, their sizes must be scalar multiples of each other. - vrange : `tuple` or `str` + vrange : `tuple` or `list` or `str` If a 2-tuple, specifies the image values vmin/vmax that are mapped to (ie. clipped to) the minimum and maximum value of the colormap, respectively. + If a list of 2-tuples, the length of number of images, each image has an + independent vmin/vmax that are mapped to the minimum/maximum value of + the colormap, respectively. If a string: * `'auto0'`: all images have same vmin/vmax, which have the same absolute value, and come from the minimum or maximum across @@ -253,6 +257,17 @@ def colormap_range(image, vrange='indep1', cmap=None): the display intensity range. For example: vmin is the 10th percentile image value minus 1/8 times the difference between the 90th and 10th percentile + * `'auto[X]row'`: each row of the figure has the same vmin/vmax, which are + computed using the auto[X] methods described above Eg. + `'auto1row'`, `'auto2row'`, or `'auto3row'`Ie. min/max, + mean minus/plus 2 std dev, or percentile statistics are + computed across all images in a given row, and those + values are used as the vmin/vmax for all images in that row. + High pass and low pass residuals have independent vmin/vmax + based on min/max of the residual image itself. + * `'auto[X]col'`: each column of the figure has the same vmin/vmax, which are + computed using the auto[X] methods described above Eg. + `'auto1col'`, `'auto2col'`, or `'auto3col'`. * `'indep0'`: each image has an independent vmin/vmax, which have the same absolute value, which comes from either their minimum or maximum value, whichever has the larger absolute value. @@ -263,7 +278,6 @@ def colormap_range(image, vrange='indep1', cmap=None): * `'indep3'`: each image has an independent vmin/vmax, chosen so that the 10th/90th percentile values map to the 10th/90th percentile intensities. - Returns ------- vrange_list : `list` @@ -274,10 +288,27 @@ def colormap_range(image, vrange='indep1', cmap=None): # flatimg is one long 1d array, which enables the min, max, mean, std, and percentile calls to # operate on the values from each of the images simultaneously. flatimg = np.concatenate([i.flatten() for i in image]).flatten() - + if isinstance(vrange, str): if vrange[:4] == 'auto': - if vrange == 'auto0': + if 'row' in vrange: + assert n_cols is not None, "n_cols must be provided when using row-wise vrange (e.g. 'auto1row')" + vrange_list = [] + for i in range(math.ceil(len(image) / n_cols)): + vr, _ = colormap_range( + image[n_cols * i : n_cols * (i + 1)], vrange.split('row')[0] + ) + vrange_list.extend(vr) + elif 'col' in vrange: + assert n_cols is not None, "n_cols must be provided when using col-wise vrange" + vrange_list = [None] * len(image) + for j in range(n_cols): + col_images = [image[i] for i in range(j, len(image), n_cols)] + vr, _ = colormap_range(col_images, vrange.split('col')[0]) + for k, i in enumerate(range(j, len(image), n_cols)): + vrange_list[i] = vr[k] + + elif vrange == 'auto0': M = np.nanmax([np.abs(np.nanmin(flatimg)), np.abs(np.nanmax(flatimg))]) vrange_list = [-M, M] elif vrange == 'auto1' or vrange == 'auto': @@ -290,8 +321,9 @@ def colormap_range(image, vrange='indep1', cmap=None): p2 = np.nanpercentile(flatimg, 90) vrange_list = [p1-(p2-p1)/8.0, p2+(p2-p1)/8.0] - # make sure to return as many ranges as there are images - vrange_list = [vrange_list] * len(image) + if 'row' not in vrange and 'col' not in vrange: + # make sure to return as many ranges as there are images + vrange_list = [vrange_list] * len(image) elif vrange[:5] == 'indep': # independent vrange from recursive calls of this function per image @@ -301,17 +333,22 @@ def colormap_range(image, vrange='indep1', cmap=None): vrange_list, _ = colormap_range(image, vrange='auto1') warnings.warn('Unknown vrange argument, using auto1 instead') else: - # two numbers were passed, either as a list or tuple - if len(vrange) != 2: - raise Exception("If you're passing numbers to vrange," - "there must be 2 of them!") - vrange_list = [tuple(vrange)] * len(image) + # either a single 2-tuple was passed or a list of 2-tuples one for each image was passed. + if len(vrange) == 2: + vrange_list = [tuple(vrange)] * len(image) + elif len(vrange) == len(image): + # explicitly cast as tuple in case of list of lists + vrange_list = [tuple(v) for v in vrange] + else: + raise Exception("If you're passing numbers to vrange," + "there must be a single 2-tuple or as many as there are images!") + # double check that we're returning the right number of vranges assert len(image) == len(vrange_list) if cmap is None: - if '0' in vrange: + if isinstance(vrange, str) and '0' in vrange: cmap = cm.RdBu_r else: cmap = cm.gray @@ -630,7 +667,7 @@ def _setup_figure(ax, col_wrap, image, zoom, max_shape, vert_pct): else: fig = ax.figure axes = [reshape_axis(ax, zoom * max_shape)] - return fig, axes + return fig, axes, n_cols, n_rows def imshow(image, vrange='indep1', zoom=1, title='', col_wrap=None, ax=None, @@ -648,10 +685,12 @@ def imshow(image, vrange='indep1', zoom=1, title='', col_wrap=None, ax=None, `(n,h,w)` for multiple grayscale images). all images will be automatically rescaled so they're displayed at the same size. thus, their sizes must be scalar multiples of each other. - vrange : `tuple` or `str` + vrange : `tuple` or `list` or `str` If a 2-tuple, specifies the image values vmin/vmax that are mapped to - the minimum and maximum value of the colormap, respectively. If a - string: + the minimum and maximum value of the colormap, respectively. If a list + of 2-tuples, each image has an independent vmin/vmax, where each images + minimum/maximum values are specified by the 2-tuples in the list ordered + from first image to last. If a string: * `'auto0'`: all images have same vmin/vmax, which have the same absolute value, and come from the minimum or maximum across all @@ -665,6 +704,17 @@ def imshow(image, vrange='indep1', zoom=1, title='', col_wrap=None, ax=None, the display intensity range. For example: vmin is the 10th percentile image value minus 1/8 times the difference between the 90th and 10th percentile + * `'auto[X]row'`: each row of the figure has the same vmin/vmax, which are + computed using the auto[X] methods described above Eg. + `'auto1row'`, `'auto2row'`, or `'auto3row'`Ie. min/max, + mean minus/plus 2 std dev, or percentile statistics are + computed across all images in a given row, and those + values are used as the vmin/vmax for all images in that row. + High pass and low pass residuals have independent vmin/vmax + based on min/max of the residual image itself. + * `'auto[X]col'`: each column of the figure has the same vmin/vmax, which are + computed using the auto[X] methods described above Eg. + `'auto1col'`, `'auto2col'`, or `'auto3col'`. * `'indep0'`: each image has an independent vmin/vmax, which have the same absolute value, which comes from either their minimum or maximum value, whichever has the larger @@ -735,14 +785,13 @@ def imshow(image, vrange='indep1', zoom=1, title='', col_wrap=None, ax=None, # Process complex images for plotting, double-check image size to see if we # have RGB(A) images image, title, contains_rgb = _process_signal(image, title, plot_complex) - # make sure we can properly zoom all images zooms, max_shape = _check_zooms(image, zoom, contains_rgb) # get the figure and axes created - fig, axes = _setup_figure(ax, col_wrap, image, zoom, max_shape, vert_pct) + fig, axes, n_cols, n_rows = _setup_figure(ax, col_wrap, image, zoom, max_shape, vert_pct) - vrange_list, cmap = colormap_range(image=image, vrange=vrange, cmap=cmap) + vrange_list, cmap = colormap_range(image=image, vrange=vrange, cmap=cmap, n_rows = n_rows, n_cols = n_cols) for im, a, r, t, z in zip(image, axes, vrange_list, title, zooms): _showIm(im, a, r, z, t, cmap, **kwargs) @@ -776,10 +825,12 @@ def animshow(video, framerate=2., as_html5=True, repeat=False, Requires ipython to be installed. repeat : `bool` whether to loop the animation or just play it once - vrange : `tuple` or `str` + vrange : `tuple` or `list` or `str` If a 2-tuple, specifies the image values vmin/vmax that are mapped to the minimum and - maximum value of the colormap, respectively. If a string: - + maximum value of the colormap, respectively. If a list + of 2-tuples, each image has an independent vmin/vmax, where each images + minimum/maximum values are specified by the 2-tuples in the list ordered + from first image to last. If a string: * `'auto/auto1'`: all images have same vmin/vmax, which are the minimum/maximum values across all images * `'auto2'`: all images have same vmin/vmax, which are the mean (across all images) minus/ @@ -788,6 +839,17 @@ def animshow(video, framerate=2., as_html5=True, repeat=False, values to the 10th/90th percentile of the display intensity range. For example: vmin is the 10th percentile image value minus 1/8 times the difference between the 90th and 10th percentile + * `'auto[X]row'`: each row of the figure has the same vmin/vmax, which are + computed using the auto[X] methods described above Eg. + `'auto1row'`, `'auto2row'`, or `'auto3row'`Ie. min/max, + mean minus/plus 2 std dev, or percentile statistics are + computed across all images in a given row, and those + values are used as the vmin/vmax for all images in that row. + High pass and low pass residuals have independent vmin/vmax + based on min/max of the residual image itself. + * `'auto[X]col'`: each column of the figure has the same vmin/vmax, which are + computed using the auto[X] methods described above Eg. + `'auto1col'`, `'auto2col'`, or `'auto3col'`. * `'indep1'`: each image has an independent vmin/vmax, which are their minimum/maximum values * `'indep2'`: each image has an independent vmin/vmax, which is their mean minus/plus 2 @@ -844,8 +906,9 @@ def animshow(video, framerate=2., as_html5=True, repeat=False, title, vert_pct = _convert_title_to_list(title, video) video, title, contains_rgb = _process_signal(video, title, plot_complex, video=True) zooms, max_shape = _check_zooms(video, zoom, contains_rgb, video=True) - fig, axes = _setup_figure(ax, col_wrap, video, zoom, max_shape, vert_pct) - vrange_list, cmap = colormap_range(image=video, vrange=vrange, cmap=cmap) + fig, axes, n_cols, n_rows = _setup_figure(ax, col_wrap, video, zoom, max_shape, vert_pct) + vrange_list, cmap = colormap_range(image=video, vrange=vrange, + cmap=cmap, n_rows = n_rows, n_cols = n_cols) first_image = [v[0] for v in video] for im, a, r, t, z in zip(first_image, axes, vrange_list, title, zooms): @@ -889,9 +952,11 @@ def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1 is_complex : `bool` default False, indicates whether the pyramids is real or complex indicating whether the pyramid is complex or real - vrange : `tuple` or `str` - If a 2-tuple, specifies the image values vmin/vmax that are mapped to the minimum and - maximum value of the colormap, respectively. If a string: + vrange : `tuple` or `list` or `str` + If a single 2-tuple, specifies the image values vmin/vmax that are mapped to the minimum and + maximum value of the colormap, respectively. + If a list of 2-tuples, each image has an independent vmin/vmax, where each images minimum/maximum + values are specified by the 2-tuples in the list ordered from first image to last. If a string: * `'auto/auto1'`: all images have same vmin/vmax, which are the minimum/maximum values across all images @@ -901,6 +966,17 @@ def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1 values to the 10th/90th percentile of the display intensity range. For example: vmin is the 10th percentile image value minus 1/8 times the difference between the 90th and 10th percentile + * `'auto[X]row'`: each row of the figure has the same vmin/vmax, which are + computed using the auto[X] methods described above Eg. + `'auto1row'`, `'auto2row'`, or `'auto3row'`Ie. min/max, + mean minus/plus 2 std dev, or percentile statistics are + computed across all images in a given row, and those + values are used as the vmin/vmax for all images in that row. + High pass and low pass residuals have independent vmin/vmax + based on min/max of the residual image itself. + * `'auto[X]col'`: each column of the figure has the same vmin/vmax, which are + computed using the auto[X] methods described above Eg. + `'auto1col'`, `'auto2col'`, or `'auto3col'`. * `'indep1'`: each image has an independent vmin/vmax, which are their minimum/maximum values * `'indep2'`: each image has an independent vmin/vmax, which is their mean minus/plus 2 @@ -943,22 +1019,80 @@ def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1 # not sure about scope here, so we make sure to copy the # pyr_coeffs dictionary. imgs, highpass, lowpass = convert_pyr_coeffs_to_pyr(pyr_coeffs.copy()) + imgs = [i.squeeze() for i in imgs] + + if is_complex: + # Make sure image is a list, do some preliminary checks + image_converted = _convert_signal_to_list(imgs) + + # want to do this check before converting title to a list (at which + # point `title is None` will always be False). we do it here instad + # of checking whether the first item of title is None because it's + # conceivable that the user passed `title=[None, 'important + # title']`, and in that case we do want the space for the title + titles, vert_pct = _convert_title_to_list('', imgs) + + plot_complex = kwargs.get("plot_complex", "rectangular") + imgs, titles, _ = _process_signal(image_converted, titles, plot_complex) + + if 'row' in vrange: + vrange_list = [] + for i in range(math.ceil(len(imgs) / num_orientations)): + vr, _ = colormap_range( + imgs[num_orientations * i : num_orientations * (i + 1)], vrange.split('row')[0] + ) + vrange_list.extend(vr) + + + ## If complex need to collect both imaginary and real parts of the coefficients for each "column" + # (i.e. each orientation) to compute the colormap range across both real and imaginary parts, + # so we loop through orientations and collect the corresponding real and imaginary parts + # of the coefficients for each orientation together to compute the colormap range for that column. + # If not complex, then we just loop through orientations and collect the coefficients for each orientation + # together to compute the colormap range for that column. + + + elif 'col' in vrange: + vrange_list = [None] * len(imgs) + if not is_complex: + for j in range(num_orientations): + col_images = [imgs[i] for i in range(j, len(imgs), num_orientations)] + vr, _ = colormap_range(col_images, vrange.split('col')[0]) + for k, i in enumerate(range(j, len(imgs), num_orientations)): + vrange_list[i] = vr[k] + else: + for j in range(0, num_orientations * 2, 2): + col_images = [] + for i in range(j, len(imgs), num_orientations * 2): + col_images.extend([imgs[i], imgs[i + 1]]) + vr, _ = colormap_range(col_images, vrange.split('col')[0]) + for k, i in enumerate(range(j, len(imgs), num_orientations * 2)): + vrange_list[i] = vr[k] + vrange_list[i + 1] = vr[k] + # we can similarly grab the labels for height and band # from the keys in this pyramid coefficients dictionary pyr_coeffs_keys = [k for k in pyr_coeffs.keys() if isinstance(k, tuple)] - titles = ["height %02d, band %02d" % (h, b) for h, b in sorted(pyr_coeffs_keys)] + if not is_complex: + titles = ["height %02d, band %02d" % (h, b) for h, b in sorted(pyr_coeffs_keys)] if show_residuals: if highpass is not None: titles += ["residual highpass"] imgs.append(highpass) + if 'row' in vrange or 'col' in vrange: + vrange_list.append([highpass.min(), highpass.max()]) if lowpass is not None: titles += ["residual lowpass"] imgs.append(lowpass) + if 'row' in vrange or 'col' in vrange: + vrange_list.append([lowpass.min(), lowpass.max()]) if col_wrap_new is not None and col_wrap_new != 1: if col_wrap is None: col_wrap = col_wrap_new # if these are really 1d (i.e., have shape (1, x) or (x, 1)), then we want them to be 1d + imgs = [i.squeeze() for i in imgs] + if imgs[0].ndim == 1: # then we just want to plot each of the bands in a different subplot, no need to be fancy. if col_wrap is not None: @@ -995,4 +1129,7 @@ def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1 "times, where this number is the height of the " f"pyramid{residual_err_msg}. " f"Instead, found:\n{err_msg}") - return imshow(imgs, vrange=vrange, col_wrap=col_wrap, zoom=zoom, title=titles, **kwargs) + + if 'col' in vrange or 'row' in vrange: + vrange = vrange_list + return imshow(imgs, vrange=vrange, col_wrap=col_wrap, zoom=zoom, title=titles, **kwargs) \ No newline at end of file