From 71ad2f3a86c3746c40a367875498798dc81384cd Mon Sep 17 00:00:00 2001 From: Francesco D'Eugenio Date: Wed, 6 Nov 2024 12:05:34 +0000 Subject: [PATCH 1/2] fig keyword allows m^2 fig.axes, if m>=n sample dims. --- src/corner/core.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/corner/core.py b/src/corner/core.py index a08b3b8..2972082 100644 --- a/src/corner/core.py +++ b/src/corner/core.py @@ -904,16 +904,20 @@ def _parse_input(xs): def _get_fig_axes(fig, K): if not fig.axes: return fig.subplots(K, K), True - try: + + axarr = np.array(fig.axes) + axarr_size = axarr.size + if np.sqrt(axarr_size)!=int(np.sqrt(axarr_size)): + raise ValueError( + f"Provided figure has {axarr_size} axes. Must be a square number") + if axarr.size==K**2: axarr = np.array(fig.axes).reshape((K, K)) return axarr.item() if axarr.size == 1 else axarr.squeeze(), False - except ValueError: - raise ValueError( - ( - "Provided figure has {0} axes, but data has " - "dimensions K={1}" - ).format(len(fig.axes), K) - ) + if axarr.size>K**2: + axarr_ndim = int(np.sqrt(axarr_size)) + axarr = axarr.reshape((axarr_ndim, axarr_ndim)) # Reshape to square + axarr = axarr[:K, :K] + return axarr.squeeze(), False def _set_xlim(force, new_fig, ax, new_xlim): From a9a6633dbcec4933f7081ab1221f39c4d175117b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Nov 2024 12:11:43 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/corner/core.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/corner/core.py b/src/corner/core.py index 2972082..7a63023 100644 --- a/src/corner/core.py +++ b/src/corner/core.py @@ -907,15 +907,16 @@ def _get_fig_axes(fig, K): axarr = np.array(fig.axes) axarr_size = axarr.size - if np.sqrt(axarr_size)!=int(np.sqrt(axarr_size)): + if np.sqrt(axarr_size) != int(np.sqrt(axarr_size)): raise ValueError( - f"Provided figure has {axarr_size} axes. Must be a square number") - if axarr.size==K**2: + f"Provided figure has {axarr_size} axes. Must be a square number" + ) + if axarr.size == K**2: axarr = np.array(fig.axes).reshape((K, K)) return axarr.item() if axarr.size == 1 else axarr.squeeze(), False - if axarr.size>K**2: + if axarr.size > K**2: axarr_ndim = int(np.sqrt(axarr_size)) - axarr = axarr.reshape((axarr_ndim, axarr_ndim)) # Reshape to square + axarr = axarr.reshape((axarr_ndim, axarr_ndim)) # Reshape to square axarr = axarr[:K, :K] return axarr.squeeze(), False