Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 145 additions & 1 deletion TESTS/unitTests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pyrtools as pt
from pyrtools.pyramids.pyramid import Pyramid
from pyrtools.tools.display import colormap_range

import scipy.io
import os
Expand Down Expand Up @@ -1485,7 +1486,150 @@ def test_pyrshow_2d_shape_err(self):
with self.assertRaises(ValueError):
pt.pyrshow(pyr.pyr_coeffs)


def _get_clims(fig):
"""get vmin vmax for each image in fig as list of tuples (vmin, vmax)"""
return [ax.images[0].get_clim() for ax in fig.axes if ax.images]

# define test images such that each image has a different range of values,
# so we can test that the correct vrange is applied to each one
IMAGES = [np.arange(4 * i, 4 * i + 4, dtype=float).reshape(2, 2) for i in range(4)]

class TestVrange(unittest.TestCase):

def tearDown(self):
plt.close("all")

def _imshow(self, vrange):
return pt.imshow(IMAGES, vrange=vrange, zoom=1, col_wrap=2)

def _expected_clims(self, vrange):
clims, _ = colormap_range(image=IMAGES, vrange=vrange, cmap=None, n_rows=2, n_cols=2)
return clims

def test_global_vrange_all_images_share_clim(self):
for mode in range(4):
with self.subTest(mode=mode):
clims = _get_clims(self._imshow(f"auto{mode}"))
self.assertTrue(all(c == clims[0] for c in clims))

def test_global_vrange_vmin(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
vmin, _ = _get_clims(self._imshow(f"auto{mode}"))[img_idx]
exp_vmin, _ = self._expected_clims(f"auto{mode}")[img_idx]
self.assertTrue(np.isclose(vmin, exp_vmin, atol=1e-6))

def test_global_vrange_vmax(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
_, vmax = _get_clims(self._imshow(f"auto{mode}"))[img_idx]
_, exp_vmax = self._expected_clims(f"auto{mode}")[img_idx]
self.assertTrue(np.isclose(vmax, exp_vmax, atol=1e-6))

def test_global_vrange_title_matches_clim(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
fig = self._imshow(f"auto{mode}")
clim_vmin, clim_vmax = _get_clims(fig)[img_idx]
title_vmin, title_vmax = _get_title_clims(fig)[img_idx]
self.assertEqual("{:.1e}".format(clim_vmin), "{:.1e}".format(title_vmin))
self.assertEqual("{:.1e}".format(clim_vmax), "{:.1e}".format(title_vmax))

def test_row_vrange_same_row_shares_clim(self):
for mode in range(4):
with self.subTest(mode=mode):
clims = _get_clims(self._imshow(f"auto{mode}row"))
self.assertEqual(clims[0], clims[1], "row 0 images differ")
self.assertEqual(clims[2], clims[3], "row 1 images differ")

def test_row_vrange_vmin(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
vmin, _ = _get_clims(self._imshow(f"auto{mode}row"))[img_idx]
exp_vmin, _ = self._expected_clims(f"auto{mode}row")[img_idx]
self.assertTrue(np.isclose(vmin, exp_vmin, atol=1e-6))

def test_row_vrange_vmax(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
_, vmax = _get_clims(self._imshow(f"auto{mode}row"))[img_idx]
_, exp_vmax = self._expected_clims(f"auto{mode}row")[img_idx]
self.assertTrue(np.isclose(vmax, exp_vmax, atol=1e-6))

def test_row_vrange_title_matches_clim(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
fig = self._imshow(f"auto{mode}row")
clim_vmin, clim_vmax = _get_clims(fig)[img_idx]
title_vmin, title_vmax = _get_title_clims(fig)[img_idx]
self.assertEqual("{:.1e}".format(clim_vmin), "{:.1e}".format(title_vmin))
self.assertEqual("{:.1e}".format(clim_vmax), "{:.1e}".format(title_vmax))

def test_col_vrange_same_col_shares_clim(self):
for mode in range(4):
with self.subTest(mode=mode):
clims = _get_clims(self._imshow(f"auto{mode}col"))
self.assertEqual(clims[0], clims[2], "col 0 images differ")
self.assertEqual(clims[1], clims[3], "col 1 images differ")

def test_col_vrange_vmin(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
vmin, _ = _get_clims(self._imshow(f"auto{mode}col"))[img_idx]
exp_vmin, _ = self._expected_clims(f"auto{mode}col")[img_idx]
self.assertTrue(np.isclose(vmin, exp_vmin, atol=1e-6))

def test_col_vrange_vmax(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
_, vmax = _get_clims(self._imshow(f"auto{mode}col"))[img_idx]
_, exp_vmax = self._expected_clims(f"auto{mode}col")[img_idx]
self.assertTrue(np.isclose(vmax, exp_vmax, atol=1e-6))

def test_col_vrange_title_matches_clim(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
fig = self._imshow(f"auto{mode}col")
clim_vmin, clim_vmax = _get_clims(fig)[img_idx]
title_vmin, title_vmax = _get_title_clims(fig)[img_idx]
self.assertEqual("{:.1e}".format(clim_vmin), "{:.1e}".format(title_vmin))
self.assertEqual("{:.1e}".format(clim_vmax), "{:.1e}".format(title_vmax))

def test_indep_vrange_vmin(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
vmin, _ = _get_clims(self._imshow(f"indep{mode}"))[img_idx]
exp_vmin, _ = self._expected_clims(f"indep{mode}")[img_idx]
self.assertTrue(np.isclose(vmin, exp_vmin, atol=1e-6))

def test_indep_vrange_vmax(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
_, vmax = _get_clims(self._imshow(f"indep{mode}"))[img_idx]
_, exp_vmax = self._expected_clims(f"indep{mode}")[img_idx]
self.assertTrue(np.isclose(vmax, exp_vmax, atol=1e-6))

def test_indep_vrange_title_matches_clim(self):
for mode in range(4):
for img_idx in range(4):
with self.subTest(mode=mode, img_idx=img_idx):
fig = self._imshow(f"indep{mode}")
clim_vmin, clim_vmax = _get_clims(fig)[img_idx]
title_vmin, title_vmax = _get_title_clims(fig)[img_idx]
self.assertEqual("{:.1e}".format(clim_vmin), "{:.1e}".format(title_vmin))
self.assertEqual("{:.1e}".format(clim_vmax), "{:.1e}".format(title_vmax))

def main():
unittest.main()

Expand Down
Loading