diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index c4603c506..ef1748eb7 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -56,6 +56,7 @@ rcsetup, warnings, ) +from ..ultralayout import KIWI_AVAILABLE, ColorbarLayoutSolver from ..utils import _fontsize_to_pt, edges, units try: @@ -1231,6 +1232,7 @@ def _add_colorbar( loc=loc, labelloc=labelloc, labelrotation=labelrotation, + labelsize=labelsize, pad=pad, **kwargs, ) # noqa: E501 @@ -1417,6 +1419,12 @@ def _add_colorbar( longaxis = obj.long_axis for label in longaxis.get_ticklabels(): label.update(kw_ticklabels) + if KIWI_AVAILABLE and getattr(cax, "_inset_colorbar_layout", None): + _reflow_inset_colorbar_frame(obj, labelloc=labelloc, ticklen=ticklen) + cax._inset_colorbar_obj = obj + cax._inset_colorbar_labelloc = labelloc + cax._inset_colorbar_ticklen = ticklen + _register_inset_colorbar_reflow(self.figure) kw_outline = {"edgecolor": color, "linewidth": linewidth} if obj.outline is not None: obj.outline.update(kw_outline) @@ -1870,6 +1878,16 @@ def _get_size_inches(self): bbox = self.get_position() width = width * abs(bbox.width) height = height * abs(bbox.height) + dpi = getattr(self.figure, "dpi", None) + if dpi: + width = round(width * dpi) / dpi + height = round(height * dpi) / dpi + fig = self.figure + if fig is not None and getattr(fig, "_refnum", None) == self.number: + if getattr(fig, "_refwidth", None) is not None: + width = fig._refwidth + if getattr(fig, "_refheight", None) is not None: + height = fig._refheight return np.array([width, height]) def _get_topmost_axes(self): @@ -2193,6 +2211,7 @@ def _parse_colorbar_inset( frame=None, frameon=None, label=None, + labelsize=None, pad=None, tickloc=None, ticklocation=None, @@ -2211,6 +2230,9 @@ def _parse_colorbar_inset( ) # noqa: E501 width = _not_none(width, rc["colorbar.insetwidth"]) pad = _not_none(pad, rc["colorbar.insetpad"]) + length_raw = length + width_raw = width + pad_raw = pad orientation = _not_none(orientation, "horizontal") ticklocation = _not_none( tickloc, ticklocation, "bottom" if orientation == "horizontal" else "right" @@ -2220,158 +2242,43 @@ def _parse_colorbar_inset( xpad = units(pad, "em", "ax", axes=self, width=True) ypad = units(pad, "em", "ax", axes=self, width=False) - # Calculate space requirements for labels and ticks - labspace = rc["xtick.major.size"] / 72 - fontsize = rc["xtick.labelsize"] - fontsize = _fontsize_to_pt(fontsize) - scale = 1.2 - if orientation == "vertical" and labelloc in ("left", "right"): - scale = 2 # we need a little more room - if label is not None: - labspace += 2 * scale * fontsize / 72 - else: - labspace += scale * fontsize / 72 + tick_fontsize = _fontsize_to_pt(rc["xtick.labelsize"]) + label_fontsize = _fontsize_to_pt(_not_none(labelsize, rc["axes.labelsize"])) + bounds_inset = None + bounds_frame = None - # Convert to axes-relative coordinates - if orientation == "horizontal": - labspace /= self._get_size_inches()[1] + if KIWI_AVAILABLE: + bounds_inset, bounds_frame = _solve_inset_colorbar_bounds( + axes=self, + loc=loc, + orientation=orientation, + length=length, + width=width, + xpad=xpad, + ypad=ypad, + ticklocation=ticklocation, + labelloc=labelloc, + label=label, + labelrotation=labelrotation, + tick_fontsize=tick_fontsize, + label_fontsize=label_fontsize, + ) else: - labspace /= self._get_size_inches()[0] - - # Initial frame dimensions (will be adjusted based on label position) - if orientation == "horizontal": - frame_width = 2 * xpad + length - frame_height = 2 * ypad + width + labspace - else: # vertical - frame_width = 2 * xpad + width + labspace - frame_height = 2 * ypad + length - - # Initialize frame position and colorbar position - xframe = yframe = 0 # frame lower left corner - if loc == "upper right": - xframe = 1 - frame_width - yframe = 1 - frame_height - cb_x = xframe + xpad - cb_y = yframe + ypad - elif loc == "upper left": - yframe = 1 - frame_height - cb_x = xpad - cb_y = yframe + ypad - elif loc == "lower left": - cb_x = xpad - cb_y = ypad - else: # lower right - xframe = 1 - frame_width - cb_x = xframe + xpad - cb_y = ypad - - # Adjust frame and colorbar position based on label location - label_offset = 0.5 * labspace - - # Account for label rotation if specified - labelrotation = _not_none(labelrotation, 0) # default to 0 degrees - if labelrotation != 0 and label is not None: - # Estimate label text dimensions - import math - - # Rough estimate of text width (characters * font size * 0.6) - estimated_text_width = len(str(label)) * fontsize * 0.6 / 72 - text_height = fontsize / 72 - - # Convert rotation to radians - angle_rad = math.radians(abs(labelrotation)) - - # Calculate rotated dimensions - rotated_width = estimated_text_width * math.cos( - angle_rad - ) + text_height * math.sin(angle_rad) - rotated_height = estimated_text_width * math.sin( - angle_rad - ) + text_height * math.cos(angle_rad) - - # Convert back to axes-relative coordinates - if orientation == "horizontal": - # For horizontal colorbars, rotation affects vertical space - rotation_offset = rotated_height / self._get_size_inches()[1] - else: - # For vertical colorbars, rotation affects horizontal space - rotation_offset = rotated_width / self._get_size_inches()[0] - - # Use the larger of the original offset or rotation-adjusted offset - label_offset = max(label_offset, rotation_offset) - - if orientation == "vertical": - if labelloc == "left": - # Move colorbar right to make room for left labels - cb_x += label_offset - - elif labelloc == "top": - # Center colorbar horizontally and extend frame for top labels - cb_x += label_offset - if "upper" in loc: - # Upper positions: extend frame downward - cb_y -= label_offset - yframe -= label_offset - frame_height += label_offset - frame_width += label_offset - if "right" in loc: - xframe -= label_offset - cb_x -= label_offset - elif "lower" in loc: - # Lower positions: extend frame upward - frame_height += label_offset - frame_width += label_offset - if "right" in loc: - xframe -= label_offset - cb_x -= label_offset - - elif labelloc == "bottom": - # Extend frame for bottom labels - if "left" in loc: - cb_x += label_offset - frame_width += label_offset - else: # right - xframe -= label_offset - frame_width += label_offset - - if "lower" in loc: - cb_y += label_offset - frame_height += label_offset - elif "upper" in loc: - yframe -= label_offset - frame_height += label_offset - - elif orientation == "horizontal": - # Base vertical adjustment for horizontal colorbars - cb_y += 2 * label_offset - - if labelloc == "bottom": - if "upper" in loc: - yframe -= label_offset - frame_height += label_offset - elif "lower" in loc: - frame_height += label_offset - cb_y += 0.5 * label_offset - - elif labelloc == "top": - if "upper" in loc: - cb_y -= 1.5 * label_offset - yframe -= label_offset - frame_height += label_offset - elif "lower" in loc: - frame_height += label_offset - cb_y -= 0.5 * label_offset - - # Set final bounds - bounds_inset = [cb_x, cb_y] - bounds_frame = [xframe, yframe] - - if orientation == "horizontal": - bounds_inset.extend((length, width)) - else: # vertical - bounds_inset.extend((width, length)) - - bounds_frame.extend((frame_width, frame_height)) + bounds_inset, bounds_frame = _legacy_inset_colorbar_bounds( + axes=self, + loc=loc, + orientation=orientation, + length=length, + width=width, + xpad=xpad, + ypad=ypad, + ticklocation=ticklocation, + labelloc=labelloc, + label=label, + labelrotation=labelrotation, + tick_fontsize=tick_fontsize, + label_fontsize=label_fontsize, + ) # Create axes and frame cls = mproj.get_projection_class("ultraplot_cartesian") @@ -2382,7 +2289,23 @@ def _parse_colorbar_inset( self.add_child_axes(ax) kw_frame, kwargs = self._parse_frame("colorbar", **kwargs) if frame: - frame = self._add_guide_frame(*bounds_frame, fontsize=fontsize, **kw_frame) + frame = self._add_guide_frame( + *bounds_frame, fontsize=tick_fontsize, **kw_frame + ) + ax._inset_colorbar_layout = { + "loc": loc, + "orientation": orientation, + "length": length, + "width": width, + "xpad": xpad, + "ypad": ypad, + "ticklocation": ticklocation, + "length_raw": length_raw, + "width_raw": width_raw, + "pad_raw": pad_raw, + } + ax._inset_colorbar_parent = self + ax._inset_colorbar_frame = frame kwargs.update({"orientation": orientation, "ticklocation": ticklocation}) return ax, kwargs @@ -2791,6 +2714,79 @@ def _reposition_subplot(self): self.update_params() setter(self.figbox) # equivalent to above + # In UltraLayout, place panels relative to their parent axes, not the grid. + if ( + self._panel_parent + and self._panel_side + and self.figure.gridspec._use_ultra_layout + ): + gs = self.get_subplotspec().get_gridspec() + figwidth, figheight = self.figure.get_size_inches() + ss = self.get_subplotspec().get_topmost_subplotspec() + row1, row2, col1, col2 = ss._get_rows_columns(ncols=gs.ncols_total) + side = self._panel_side + parent_bbox = self._panel_parent.get_position() + panels = list(self._panel_parent._panel_dict.get(side, ())) + anchor_ax = self._panel_parent + if self in panels: + idx = panels.index(self) + if idx > 0: + anchor_ax = panels[idx - 1] + elif panels: + anchor_ax = panels[-1] + anchor_bbox = anchor_ax.get_position() + anchor_ss = anchor_ax.get_subplotspec().get_topmost_subplotspec() + a_row1, a_row2, a_col1, a_col2 = anchor_ss._get_rows_columns( + ncols=gs.ncols_total + ) + + if side in ("right", "left"): + boundary = None + width = sum(gs._wratios_total[col1 : col2 + 1]) / figwidth + if a_col2 < col1: + boundary = a_col2 + elif col2 < a_col1: + boundary = col2 + # Fall back to an interface adjacent to this panel + boundary = min( + max( + _not_none(boundary, a_col2 if side == "right" else col2), + 0, + ), + len(gs.wspace_total) - 1, + ) + pad = gs.wspace_total[boundary] / figwidth + if side == "right": + x0 = anchor_bbox.x1 + pad + else: + x0 = anchor_bbox.x0 - pad - width + bbox = mtransforms.Bbox.from_bounds( + x0, parent_bbox.y0, width, parent_bbox.height + ) + else: + boundary = None + height = sum(gs._hratios_total[row1 : row2 + 1]) / figheight + if a_row2 < row1: + boundary = a_row2 + elif row2 < a_row1: + boundary = row2 + boundary = min( + max( + _not_none(boundary, a_row2 if side == "top" else row2), + 0, + ), + len(gs.hspace_total) - 1, + ) + pad = gs.hspace_total[boundary] / figheight + if side == "top": + y0 = anchor_bbox.y1 + pad + else: + y0 = anchor_bbox.y0 - pad - height + bbox = mtransforms.Bbox.from_bounds( + parent_bbox.x0, y0, parent_bbox.width, height + ) + setter(bbox) + def _update_abc(self, **kwargs): """ Update the a-b-c label. @@ -3443,6 +3439,18 @@ def draw(self, renderer=None, *args, **kwargs): if self._inset_parent is not None and self._inset_zoom: self.indicate_inset_zoom() super().draw(renderer, *args, **kwargs) + if getattr(self, "_inset_colorbar_obj", None) and getattr( + self, "_inset_colorbar_needs_reflow", False + ): + self._inset_colorbar_needs_reflow = False + _reflow_inset_colorbar_frame( + self._inset_colorbar_obj, + labelloc=getattr(self, "_inset_colorbar_labelloc", None), + ticklen=getattr( + self, "_inset_colorbar_ticklen", units(rc["tick.len"], "pt") + ), + ) + self.figure.canvas.draw_idle() def get_tightbbox(self, renderer, *args, **kwargs): # Perform extra post-processing steps @@ -4072,3 +4080,525 @@ def _determine_label_rotation( f"Label rotation must be a number or 'auto', got {labelrotation!r}." ) kw_label.update({"rotation": labelrotation}) + + +def _resolve_label_rotation( + labelrotation: str | Number, + *, + labelloc: str, + orientation: str, +) -> float: + layout_rotation = _not_none(labelrotation, 0) + if layout_rotation == "auto": + kw_label = {} + _determine_label_rotation( + "auto", + labelloc=labelloc, + orientation=orientation, + kw_label=kw_label, + ) + layout_rotation = kw_label.get("rotation", 0) + if not isinstance(layout_rotation, (int, float)): + return 0.0 + return float(layout_rotation) + + +def _measure_label_points( + label: str, + rotation: float, + fontsize: float, + figure, +) -> Optional[Tuple[float, float]]: + try: + renderer = figure._get_renderer() + text = mtext.Text(0, 0, label, rotation=rotation, fontsize=fontsize) + text.set_figure(figure) + bbox = text.get_window_extent(renderer=renderer) + except Exception: + return None + dpi = figure.dpi + return (bbox.width * 72 / dpi, bbox.height * 72 / dpi) + + +def _measure_text_artist_points( + text: mtext.Text, figure +) -> Optional[Tuple[float, float]]: + try: + renderer = figure._get_renderer() + bbox = text.get_window_extent(renderer=renderer) + except Exception: + return None + dpi = figure.dpi + return (bbox.width * 72 / dpi, bbox.height * 72 / dpi) + + +def _measure_ticklabel_extent_points(axis, figure) -> Optional[Tuple[float, float]]: + try: + renderer = figure._get_renderer() + labels = axis.get_ticklabels() + except Exception: + return None + max_width = 0.0 + max_height = 0.0 + for label in labels: + if not label.get_visible() or not label.get_text(): + continue + extent = _measure_text_artist_points(label, figure) + if extent is None: + continue + width_pt, height_pt = extent + max_width = max(max_width, width_pt) + max_height = max(max_height, height_pt) + if max_width == 0.0 and max_height == 0.0: + return None + return (max_width, max_height) + + +def _measure_text_overhang_axes( + text: mtext.Text, axes +) -> Optional[Tuple[float, float, float, float]]: + try: + renderer = axes.figure._get_renderer() + bbox = text.get_window_extent(renderer=renderer) + inv = axes.transAxes.inverted() + (x0, y0) = inv.transform((bbox.x0, bbox.y0)) + (x1, y1) = inv.transform((bbox.x1, bbox.y1)) + except Exception: + return None + left = max(0.0, -x0) + right = max(0.0, x1 - 1.0) + bottom = max(0.0, -y0) + top = max(0.0, y1 - 1.0) + return (left, right, bottom, top) + + +def _measure_ticklabel_overhang_axes( + axis, axes +) -> Optional[Tuple[float, float, float, float]]: + try: + renderer = axes.figure._get_renderer() + inv = axes.transAxes.inverted() + labels = axis.get_ticklabels() + except Exception: + return None + min_x, max_x = 0.0, 1.0 + min_y, max_y = 0.0, 1.0 + found = False + for label in labels: + if not label.get_visible() or not label.get_text(): + continue + bbox = label.get_window_extent(renderer=renderer) + (x0, y0) = inv.transform((bbox.x0, bbox.y0)) + (x1, y1) = inv.transform((bbox.x1, bbox.y1)) + min_x = min(min_x, x0) + max_x = max(max_x, x1) + min_y = min(min_y, y0) + max_y = max(max_y, y1) + found = True + if not found: + return None + left = max(0.0, -min_x) + right = max(0.0, max_x - 1.0) + bottom = max(0.0, -min_y) + top = max(0.0, max_y - 1.0) + return (left, right, bottom, top) + + +def _get_colorbar_long_axis(colorbar): + if hasattr(colorbar, "_long_axis"): + return colorbar._long_axis() + return colorbar.long_axis + + +def _register_inset_colorbar_reflow(fig): + if getattr(fig, "_inset_colorbar_reflow_cid", None) is not None: + return + + def _on_resize(event): + axes = list(event.canvas.figure.axes) + i = 0 + seen = set() + while i < len(axes): + ax = axes[i] + i += 1 + ax_id = id(ax) + if ax_id in seen: + continue + seen.add(ax_id) + child_axes = getattr(ax, "child_axes", ()) + if child_axes: + axes.extend(child_axes) + if getattr(ax, "_inset_colorbar_obj", None) is None: + continue + ax._inset_colorbar_needs_reflow = True + event.canvas.draw_idle() + + fig._inset_colorbar_reflow_cid = fig.canvas.mpl_connect("resize_event", _on_resize) + + +def _solve_inset_colorbar_bounds( + *, + axes: "Axes", + loc: str, + orientation: str, + length: float, + width: float, + xpad: float, + ypad: float, + ticklocation: str, + labelloc: Optional[str], + label, + labelrotation: Union[str, float, None], + tick_fontsize: float, + label_fontsize: float, +) -> Tuple[list[float], list[float]]: + scale = 1.2 + labelloc_layout = labelloc if isinstance(labelloc, str) else ticklocation + if orientation == "vertical" and labelloc_layout in ("left", "right"): + scale = 2 + + tick_space_pt = rc["xtick.major.size"] + scale * tick_fontsize + label_space_pt = 0.0 + if label is not None: + label_space_pt = scale * label_fontsize + layout_rotation = _resolve_label_rotation( + labelrotation, labelloc=labelloc_layout, orientation=orientation + ) + extent = _measure_label_points( + str(label), layout_rotation, label_fontsize, axes.figure + ) + if extent is not None: + width_pt, height_pt = extent + if labelloc_layout in ("left", "right"): + label_space_pt = max(label_space_pt, width_pt) + else: + label_space_pt = max(label_space_pt, height_pt) + + fig_w, fig_h = axes._get_size_inches() + tick_space_x = ( + tick_space_pt / 72 / fig_w if ticklocation in ("left", "right") else 0 + ) + tick_space_y = ( + tick_space_pt / 72 / fig_h if ticklocation in ("top", "bottom") else 0 + ) + label_space_x = ( + label_space_pt / 72 / fig_w if labelloc_layout in ("left", "right") else 0 + ) + label_space_y = ( + label_space_pt / 72 / fig_h if labelloc_layout in ("top", "bottom") else 0 + ) + + pad_left = xpad + (tick_space_x if ticklocation == "left" else 0) + pad_left += label_space_x if labelloc_layout == "left" else 0 + pad_right = xpad + (tick_space_x if ticklocation == "right" else 0) + pad_right += label_space_x if labelloc_layout == "right" else 0 + pad_bottom = ypad + (tick_space_y if ticklocation == "bottom" else 0) + pad_bottom += label_space_y if labelloc_layout == "bottom" else 0 + pad_top = ypad + (tick_space_y if ticklocation == "top" else 0) + pad_top += label_space_y if labelloc_layout == "top" else 0 + + if orientation == "horizontal": + cb_width, cb_height = length, width + else: + cb_width, cb_height = width, length + solver = ColorbarLayoutSolver( + loc, + cb_width, + cb_height, + pad_left, + pad_right, + pad_bottom, + pad_top, + ) + layout = solver.solve() + return list(layout["inset"]), list(layout["frame"]) + + +def _legacy_inset_colorbar_bounds( + *, + axes: "Axes", + loc: str, + orientation: str, + length: float, + width: float, + xpad: float, + ypad: float, + ticklocation: str, + labelloc: Optional[str], + label, + labelrotation: Union[str, float, None], + tick_fontsize: float, + label_fontsize: float, +) -> Tuple[list[float], list[float]]: + labspace = rc["xtick.major.size"] / 72 + scale = 1.2 + if orientation == "vertical" and labelloc in ("left", "right"): + scale = 2 + if label is not None: + labspace += 2 * scale * label_fontsize / 72 + else: + labspace += scale * tick_fontsize / 72 + + if orientation == "horizontal": + labspace /= axes._get_size_inches()[1] + else: + labspace /= axes._get_size_inches()[0] + + if orientation == "horizontal": + frame_width = 2 * xpad + length + frame_height = 2 * ypad + width + labspace + else: + frame_width = 2 * xpad + width + labspace + frame_height = 2 * ypad + length + + xframe = yframe = 0 + if loc == "upper right": + xframe = 1 - frame_width + yframe = 1 - frame_height + cb_x = xframe + xpad + cb_y = yframe + ypad + elif loc == "upper left": + yframe = 1 - frame_height + cb_x = xpad + cb_y = yframe + ypad + elif loc == "lower left": + cb_x = xpad + cb_y = ypad + else: + xframe = 1 - frame_width + cb_x = xframe + xpad + cb_y = ypad + + label_offset = 0.5 * labspace + labelrotation = _not_none(labelrotation, 0) + if labelrotation == "auto": + kw_label = {} + _determine_label_rotation( + "auto", + labelloc=labelloc or ticklocation, + orientation=orientation, + kw_label=kw_label, + ) + labelrotation = kw_label.get("rotation", 0) + if not isinstance(labelrotation, (int, float)): + labelrotation = 0 + if labelrotation != 0 and label is not None: + import math + + estimated_text_width = len(str(label)) * label_fontsize * 0.6 / 72 + text_height = label_fontsize / 72 + angle_rad = math.radians(abs(labelrotation)) + rotated_width = estimated_text_width * math.cos( + angle_rad + ) + text_height * math.sin(angle_rad) + rotated_height = estimated_text_width * math.sin( + angle_rad + ) + text_height * math.cos(angle_rad) + + if orientation == "horizontal": + rotation_offset = rotated_height / axes._get_size_inches()[1] + else: + rotation_offset = rotated_width / axes._get_size_inches()[0] + + label_offset = max(label_offset, rotation_offset) + + if orientation == "vertical": + if labelloc == "left": + cb_x += label_offset + elif labelloc == "top": + cb_x += label_offset + if "upper" in loc: + cb_y -= label_offset + yframe -= label_offset + frame_height += label_offset + frame_width += label_offset + if "right" in loc: + xframe -= label_offset + cb_x -= label_offset + elif "lower" in loc: + frame_height += label_offset + frame_width += label_offset + if "right" in loc: + xframe -= label_offset + cb_x -= label_offset + elif labelloc == "bottom": + if "left" in loc: + cb_x += label_offset + frame_width += label_offset + else: + xframe -= label_offset + frame_width += label_offset + if "lower" in loc: + cb_y += label_offset + frame_height += label_offset + elif "upper" in loc: + yframe -= label_offset + frame_height += label_offset + elif orientation == "horizontal": + cb_y += 2 * label_offset + if labelloc == "bottom": + if "upper" in loc: + yframe -= label_offset + frame_height += label_offset + elif "lower" in loc: + frame_height += label_offset + cb_y += 0.5 * label_offset + elif labelloc == "top": + if "upper" in loc: + cb_y -= 1.5 * label_offset + yframe -= label_offset + frame_height += label_offset + elif "lower" in loc: + frame_height += label_offset + cb_y -= 0.5 * label_offset + + bounds_inset = [cb_x, cb_y] + bounds_frame = [xframe, yframe] + if orientation == "horizontal": + bounds_inset.extend((length, width)) + else: + bounds_inset.extend((width, length)) + bounds_frame.extend((frame_width, frame_height)) + return bounds_inset, bounds_frame + + +def _apply_inset_colorbar_layout( + axes: "Axes", + *, + bounds_inset: list[float], + bounds_frame: list[float], + frame: Optional[mpatches.FancyBboxPatch], +): + parent = getattr(axes, "_inset_colorbar_parent", None) + transform = parent.transAxes if parent is not None else axes.transAxes + locator = axes._make_inset_locator(bounds_inset, transform) + axes.set_axes_locator(locator) + axes.set_position(locator(axes, None).bounds) + axes._inset_colorbar_bounds = { + "inset": bounds_inset, + "frame": bounds_frame, + } + if frame is not None: + frame.set_bounds(*bounds_frame) + + +def _reflow_inset_colorbar_frame( + colorbar, + *, + labelloc: str, + ticklen: float, +): + cax = colorbar.ax + layout = getattr(cax, "_inset_colorbar_layout", None) + frame = getattr(cax, "_inset_colorbar_frame", None) + if not layout: + return + parent = getattr(cax, "_inset_colorbar_parent", None) + if parent is None: + return + orientation = layout["orientation"] + loc = layout["loc"] + ticklocation = layout["ticklocation"] + length_raw = layout.get("length_raw") + width_raw = layout.get("width_raw") + pad_raw = layout.get("pad_raw") + if length_raw is None or width_raw is None or pad_raw is None: + length = layout["length"] + width = layout["width"] + xpad = layout["xpad"] + ypad = layout["ypad"] + else: + length = units(length_raw, "em", "ax", axes=parent, width=True) + width = units(width_raw, "em", "ax", axes=parent, width=False) + xpad = units(pad_raw, "em", "ax", axes=parent, width=True) + ypad = units(pad_raw, "em", "ax", axes=parent, width=False) + layout["length"] = length + layout["width"] = width + layout["xpad"] = xpad + layout["ypad"] = ypad + labelloc_layout = labelloc if isinstance(labelloc, str) else ticklocation + if orientation == "horizontal": + cb_width = length + cb_height = width + else: + cb_width = width + cb_height = length + + renderer = cax.figure._get_renderer() + if hasattr(colorbar, "update_ticks"): + colorbar.update_ticks(manual_only=True) + bboxes = [] + longaxis = _get_colorbar_long_axis(colorbar) + try: + bbox = longaxis.get_tightbbox(renderer) + except Exception: + bbox = None + if bbox is not None: + bboxes.append(bbox) + label_axis = _get_axis_for( + labelloc_layout, loc, orientation=orientation, ax=colorbar + ) + if label_axis.label.get_text(): + try: + bboxes.append(label_axis.label.get_window_extent(renderer=renderer)) + except Exception: + pass + if colorbar.outline is not None: + try: + bboxes.append(colorbar.outline.get_window_extent(renderer=renderer)) + except Exception: + pass + if getattr(colorbar, "solids", None) is not None: + try: + bboxes.append(colorbar.solids.get_window_extent(renderer=renderer)) + except Exception: + pass + if getattr(colorbar, "dividers", None) is not None: + try: + bboxes.append(colorbar.dividers.get_window_extent(renderer=renderer)) + except Exception: + pass + if not bboxes: + return + x0 = min(b.x0 for b in bboxes) + y0 = min(b.y0 for b in bboxes) + x1 = max(b.x1 for b in bboxes) + y1 = max(b.y1 for b in bboxes) + inv_parent = parent.transAxes.inverted() + (px0, py0) = inv_parent.transform((x0, y0)) + (px1, py1) = inv_parent.transform((x1, y1)) + cax_bbox = cax.get_window_extent(renderer=renderer) + (cx0, cy0) = inv_parent.transform((cax_bbox.x0, cax_bbox.y0)) + (cx1, cy1) = inv_parent.transform((cax_bbox.x1, cax_bbox.y1)) + px0, px1 = sorted((px0, px1)) + py0, py1 = sorted((py0, py1)) + cx0, cx1 = sorted((cx0, cx1)) + cy0, cy1 = sorted((cy0, cy1)) + delta_left = max(0.0, cx0 - px0) + delta_right = max(0.0, px1 - cx1) + delta_bottom = max(0.0, cy0 - py0) + delta_top = max(0.0, py1 - cy1) + + pad_left = xpad + delta_left + pad_right = xpad + delta_right + pad_bottom = ypad + delta_bottom + pad_top = ypad + delta_top + try: + solver = ColorbarLayoutSolver( + loc, + cb_width, + cb_height, + pad_left, + pad_right, + pad_bottom, + pad_top, + ) + bounds = solver.solve() + except Exception: + return + _apply_inset_colorbar_layout( + cax, + bounds_inset=list(bounds["inset"]), + bounds_frame=list(bounds["frame"]), + frame=frame, + ) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index b24fd98c9..61fea8249 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -3287,9 +3287,9 @@ def _parse_level_lim( for z in zs: if z is None: # e.g. empty scatter color continue + z = inputs._to_numpy_array(z) if z.ndim > 2: # e.g. imshow data continue - z = inputs._to_numpy_array(z) if inbounds and x is not None and y is not None: # ignore if None coords z = self._inbounds_vlim(x, y, z, to_centers=to_centers) imin, imax = inputs._safe_range(z, pmin, pmax) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index b2612d6a3..273b71ecf 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -1841,6 +1841,8 @@ def _axes_dict(naxs, input, kw=False, default=None): # Create or update the gridspec and add subplots with subplotspecs # NOTE: The gridspec is added to the figure when we pass the subplotspec if gs is None: + if "layout_array" not in gridspec_kw: + gridspec_kw = {**gridspec_kw, "layout_array": array} gs = pgridspec.GridSpec(*array.shape, **gridspec_kw) else: gs.update(**gridspec_kw) diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 6f4c2d229..90b4da086 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -25,6 +25,14 @@ ) from .utils import _fontsize_to_pt, units +try: + from . import ultralayout + + ULTRA_AVAILABLE = True +except ImportError: + ultralayout = None + ULTRA_AVAILABLE = False + __all__ = ["GridSpec", "SubplotGrid"] @@ -228,6 +236,20 @@ def get_position(self, figure, return_all=False): nrows, ncols = gs.get_total_geometry() else: nrows, ncols = gs.get_geometry() + + # Check if we should use UltraLayout for this subplot + if isinstance(gs, GridSpec) and gs._use_ultra_layout: + bbox = gs._get_ultra_position(self.num1, figure) + if bbox is not None: + if return_all: + rows, cols = np.unravel_index( + [self.num1, self.num2], (nrows, ncols) + ) + return bbox, rows[0], cols[0], nrows, ncols + else: + return bbox + + # Default behavior: use grid positions rows, cols = np.unravel_index([self.num1, self.num2], (nrows, ncols)) bottoms, tops, lefts, rights = gs.get_grid_positions(figure) bottom = bottoms[rows].min() @@ -267,7 +289,14 @@ def __getattr__(self, attr): super().__getattribute__(attr) # native error message @docstring._snippet_manager - def __init__(self, nrows=1, ncols=1, **kwargs): + def __init__( + self, + nrows=1, + ncols=1, + layout_array=None, + ultra_layout: Optional[bool] = None, + **kwargs, + ): """ Parameters ---------- @@ -275,6 +304,14 @@ def __init__(self, nrows=1, ncols=1, **kwargs): The number of rows in the subplot grid. ncols : int, optional The number of columns in the subplot grid. + layout_array : array-like, optional + 2D array specifying the subplot layout, where each unique integer + represents a subplot and 0 represents empty space. When provided, + enables UltraLayout constraint-based positioning (requires + kiwisolver package). + ultra_layout : bool, optional + Whether to use the UltraLayout constraint solver. Defaults to True + when kiwisolver is available. Set to False to use the legacy solver. Other parameters ---------------- @@ -304,6 +341,27 @@ def __init__(self, nrows=1, ncols=1, **kwargs): manually and want the same geometry for multiple figures, you must create a copy with `GridSpec.copy` before working on the subsequent figure). """ + # Layout array for UltraLayout + self._layout_array = ( + np.array(layout_array) if layout_array is not None else None + ) + self._ultra_positions = None # Cache for UltraLayout-computed positions + self._ultra_layout_array = None # Cache for expanded UltraLayout array + self._use_ultra_layout = False # Flag to enable UltraLayout + + # Check if we should use UltraLayout + if ultra_layout is not None: + self._use_ultra_layout = bool(ultra_layout) and ULTRA_AVAILABLE + elif ULTRA_AVAILABLE: + self._use_ultra_layout = True + if ultra_layout and not ULTRA_AVAILABLE: + warnings._warn_ultraplot( + "ultra_layout=True requested but kiwisolver is not available. " + "Falling back to the legacy layout solver." + ) + if self._use_ultra_layout and self._layout_array is None: + self._layout_array = np.arange(1, nrows * ncols + 1).reshape(nrows, ncols) + # Fundamental GridSpec properties self._nrows_total = nrows self._ncols_total = ncols @@ -366,6 +424,162 @@ def __init__(self, nrows=1, ncols=1, **kwargs): } self._update_params(pad=pad, **kwargs) + def _get_ultra_position(self, subplot_num, figure): + """ + Get the position of a subplot using UltraLayout constraint-based positioning. + + Parameters + ---------- + subplot_num : int + The subplot number (in total geometry indexing) + figure : Figure + The matplotlib figure instance + + Returns + ------- + bbox : Bbox or None + The bounding box for the subplot, or None if kiwi layout fails + """ + if not self._use_ultra_layout or self._layout_array is None: + return None + + # Ensure figure is set + if not self.figure: + self._figure = figure + if not self.figure: + return None + + # Compute or retrieve cached UltraLayout positions + if self._ultra_positions is None: + self._compute_ultra_positions() + if self._ultra_positions is None: + return None + layout_array = self._get_ultra_layout_array() + if layout_array is None: + return None + + # Find which subplot number in the layout array corresponds to this subplot_num + # We need to map from the gridspec cell index to the layout array subplot number + nrows, ncols = layout_array.shape + + # Decode the subplot_num to find which layout number it corresponds to + # This is a bit tricky because subplot_num is in total geometry space + # We need to find which unique number in the layout_array this corresponds to + + # Get the cell position from subplot_num + if (nrows, ncols) == self.get_total_geometry(): + row, col = divmod(subplot_num, self.ncols_total) + else: + decoded = self._decode_indices(subplot_num) + row, col = divmod(decoded, ncols) + + # Check if this is within the layout array bounds + if row >= nrows or col >= ncols: + return None + + # Get the layout number at this position + layout_num = layout_array[row, col] + + if layout_num == 0 or layout_num not in self._ultra_positions: + return None + + # Return the cached position + left, bottom, width, height = self._ultra_positions[layout_num] + bbox = mtransforms.Bbox.from_bounds(left, bottom, width, height) + return bbox + + def _compute_ultra_positions(self): + """ + Compute subplot positions using UltraLayout and cache them. + """ + if not ULTRA_AVAILABLE or self._layout_array is None: + return + layout_array = self._get_ultra_layout_array() + if layout_array is None: + return + + # Get figure size + if not self.figure: + return + + figwidth, figheight = self.figure.get_size_inches() + + # Convert spacing to inches (including default ticklabel sizes). + wspace_inches = list(self.wspace_total) + hspace_inches = list(self.hspace_total) + + # Get margins + left = self.left + right = self.right + top = self.top + bottom = self.bottom + + # Compute positions using UltraLayout + try: + self._ultra_positions = ultralayout.compute_ultra_positions( + layout_array, + figwidth=figwidth, + figheight=figheight, + wspace=wspace_inches, + hspace=hspace_inches, + left=left, + right=right, + top=top, + bottom=bottom, + wratios=self._wratios_total, + hratios=self._hratios_total, + wpanels=[bool(val) for val in self._wpanels], + hpanels=[bool(val) for val in self._hpanels], + ) + except Exception as e: + warnings._warn_ultraplot( + f"Failed to compute UltraLayout: {e}. " + "Falling back to default grid layout." + ) + self._use_ultra_layout = False + self._ultra_positions = None + + def _get_ultra_layout_array(self): + """ + Return the layout array expanded to total geometry to include panels. + """ + if self._layout_array is None: + return None + if self._ultra_layout_array is not None: + return self._ultra_layout_array + + nrows_total, ncols_total = self.get_total_geometry() + layout = self._layout_array + if layout.shape == (nrows_total, ncols_total): + self._ultra_layout_array = layout + return layout + + nrows, ncols = self.get_geometry() + if layout.shape != (nrows, ncols): + warnings._warn_ultraplot( + "Layout array shape does not match gridspec geometry; " + "using the original layout array for UltraLayout." + ) + self._ultra_layout_array = layout + return layout + + row_idxs = self._get_indices("h", panel=False) + col_idxs = self._get_indices("w", panel=False) + if len(row_idxs) != nrows or len(col_idxs) != ncols: + warnings._warn_ultraplot( + "Layout array shape does not match non-panel gridspec geometry; " + "using the original layout array for UltraLayout." + ) + self._ultra_layout_array = layout + return layout + + expanded = np.zeros((nrows_total, ncols_total), dtype=layout.dtype) + for i, row_idx in enumerate(row_idxs): + for j, col_idx in enumerate(col_idxs): + expanded[row_idx, col_idx] = layout[i, j] + self._ultra_layout_array = expanded + return expanded + def __getitem__(self, key): """ Get a `~matplotlib.gridspec.SubplotSpec`. "Hidden" slots allocated for axes @@ -425,14 +639,11 @@ def _encode_indices(self, *args, which=None, panel=False): nums = [] idxs = self._get_indices(which=which, panel=panel) for arg in args: - if isinstance(arg, (list, np.ndarray)): - try: - nums.append([idxs[int(i)] for i in arg]) - except (IndexError, TypeError): - raise ValueError(f"Invalid gridspec index {arg}.") - continue try: - nums.append(idxs[arg]) + if isinstance(arg, (list, np.ndarray)): + nums.append([idxs[i] for i in list(arg)]) + else: + nums.append(idxs[arg]) except (IndexError, TypeError): raise ValueError(f"Invalid gridspec index {arg}.") return nums[0] if len(nums) == 1 else nums @@ -495,6 +706,9 @@ def _modify_subplot_geometry(self, newrow=None, newcol=None): """ Update the axes subplot specs by inserting rows and columns as specified. """ + if self._use_ultra_layout: + self._ultra_positions = None + self._ultra_layout_array = None fig = self.figure ncols = self._ncols_total - int(newcol is not None) # previous columns inserts = (newrow, newrow, newcol, newcol) @@ -970,8 +1184,11 @@ def _auto_layout_aspect(self): # Update the layout figsize = self._update_figsize() - if not fig._is_same_size(figsize): - fig.set_size_inches(figsize, internal=True) + eps = 0.01 + if fig._refwidth is not None or fig._refheight is not None: + eps = 0 + if not fig._is_same_size(figsize, eps=eps): + fig.set_size_inches(figsize, internal=True, eps=0) def _auto_layout_tight(self, renderer): """ @@ -1029,8 +1246,11 @@ def _auto_layout_tight(self, renderer): # spaces (necessary since native position coordinates are figure-relative) # and to enforce fixed panel ratios. So only self.update() if we skip resize. figsize = self._update_figsize() - if not fig._is_same_size(figsize): - fig.set_size_inches(figsize, internal=True) + eps = 0.01 + if fig._refwidth is not None or fig._refheight is not None: + eps = 0 # force resize when explicit reference sizing is requested + if not fig._is_same_size(figsize, eps=eps): + fig.set_size_inches(figsize, internal=True, eps=0) else: self.update() @@ -1047,14 +1267,14 @@ def _update_figsize(self): return ss = ax.get_subplotspec().get_topmost_subplotspec() y1, y2, x1, x2 = ss._get_rows_columns() - refhspace = sum(self.hspace_total[y1:y2]) - refwspace = sum(self.wspace_total[x1:x2]) - refhpanel = sum( - self.hratios_total[i] for i in range(y1, y2 + 1) if self._hpanels[i] - ) # noqa: E501 - refwpanel = sum( - self.wratios_total[i] for i in range(x1, x2 + 1) if self._wpanels[i] - ) # noqa: E501 + # NOTE: Reference width/height should correspond to the span of the *axes* + # themselves. Spaces between rows/columns and adjacent panel slots should + # not reduce the target size; those are accounted for separately when the + # full figure size is rebuilt below. + refhspace = 0 + refwspace = 0 + refhpanel = 0 + refwpanel = 0 refhsubplot = sum( self.hratios_total[i] for i in range(y1, y2 + 1) if not self._hpanels[i] ) # noqa: E501 @@ -1066,6 +1286,10 @@ def _update_figsize(self): # NOTE: The sizing arguments should have been normalized already figwidth, figheight = fig._figwidth, fig._figheight refwidth, refheight = fig._refwidth, fig._refheight + if refwidth is not None: + figwidth = None # prefer explicit reference sizing over preset fig size + if refheight is not None: + figheight = None refaspect = _not_none(fig._refaspect, fig._refaspect_default) if refheight is None and figheight is None: if figwidth is not None: @@ -1096,6 +1320,15 @@ def _update_figsize(self): gridwidth = refwidth * self.gridwidth / refwsubplot figwidth = gridwidth + self.spacewidth + self.panelwidth + # Snap explicit reference-driven sizes to the pixel grid to avoid + # rounding the axes width below the requested reference size. + if fig and (fig._refwidth is not None or fig._refheight is not None): + dpi = _not_none(getattr(fig, "dpi", None), 72) + if figwidth is not None: + figwidth = np.ceil(figwidth * dpi) / dpi + if figheight is not None: + figheight = np.ceil(figheight * dpi) / dpi + # Return the figure size figsize = (figwidth, figheight) if all(np.isfinite(figsize)): @@ -1106,6 +1339,7 @@ def _update_figsize(self): def _update_params( self, *, + ultra_layout=None, left=None, bottom=None, right=None, @@ -1133,6 +1367,20 @@ def _update_params( """ Update the user-specified properties. """ + if ultra_layout is not None: + self._use_ultra_layout = bool(ultra_layout) and ULTRA_AVAILABLE + if ultra_layout and not ULTRA_AVAILABLE: + warnings._warn_ultraplot( + "ultra_layout=True requested but kiwisolver is not available. " + "Falling back to the legacy layout solver." + ) + if self._use_ultra_layout and self._layout_array is None: + nrows, ncols = self.get_geometry() + self._layout_array = np.arange(1, nrows * ncols + 1).reshape( + nrows, ncols + ) + self._ultra_positions = None + self._ultra_layout_array = None # Assign scalar args # WARNING: The key signature here is critical! Used in ui.py to @@ -1225,7 +1473,12 @@ def copy(self, **kwargs): # WARNING: For some reason copy.copy() fails. Updating e.g. wpanels # and hpanels on the copy also updates this object. No idea why. nrows, ncols = self.get_geometry() - gs = GridSpec(nrows, ncols) + gs = GridSpec( + nrows, + ncols, + layout_array=self._layout_array, + ultra_layout=self._use_ultra_layout, + ) hidxs = self._get_indices("h") widxs = self._get_indices("w") gs._hratios_total = [self._hratios_total[i] for i in hidxs] @@ -1390,6 +1643,9 @@ def update(self, **kwargs): # Apply positions to all axes # NOTE: This uses the current figure size to fix panel widths # and determine physical grid spacing. + if self._use_ultra_layout: + self._ultra_positions = None + self._ultra_layout_array = None self._update_params(**kwargs) fig = self.figure if fig is None: @@ -1445,8 +1701,30 @@ def figure(self, fig): get_height_ratios = _disable_method("get_height_ratios") set_width_ratios = _disable_method("set_width_ratios") set_height_ratios = _disable_method("set_height_ratios") - get_subplot_params = _disable_method("get_subplot_params") - locally_modified_subplot_params = _disable_method("locally_modified_subplot_params") + + # Compat: some backends (e.g., Positron) call these for read-only checks. + # We return current margins/spaces without permitting mutation. + def get_subplot_params(self, figure=None): + from matplotlib.figure import SubplotParams + + fig = figure or self.figure + if fig is None: + raise RuntimeError("Figure must be assigned to gridspec.") + # Convert absolute margins to figure-relative floats + width, height = fig.get_size_inches() + left = self.left / width + right = 1 - self.right / width + bottom = self.bottom / height + top = 1 - self.top / height + wspace = sum(self.wspace_total) / width + hspace = sum(self.hspace_total) / height + return SubplotParams( + left=left, right=right, bottom=bottom, top=top, wspace=wspace, hspace=hspace + ) + + def locally_modified_subplot_params(self): + # Backend probe: report False/None semantics (no local mods to MPL params). + return False # Immutable helper properties used to calculate figure size and subplot positions # NOTE: The spaces are auto-filled with defaults wherever user left them unset @@ -1618,13 +1896,13 @@ def __getitem__(self, key): >>> axs[:, 0] # a SubplotGrid containing the subplots in the first column """ # Allow 1D list-like indexing - if isinstance(key, (Integral, np.integer)): + if isinstance(key, int): return list.__getitem__(self, key) elif isinstance(key, slice): return SubplotGrid(list.__getitem__(self, key)) elif isinstance(key, (list, np.ndarray)): - # NOTE: list.__getitem__ does not support numpy integers - return SubplotGrid([list.__getitem__(self, int(i)) for i in key]) + objs = [list.__getitem__(self, idx) for idx in list(key)] + return SubplotGrid(objs) # Allow 2D array-like indexing # NOTE: We assume this is a 2D array of subplots, because this is @@ -1767,14 +2045,14 @@ def format(self, **kwargs): all_axes = set(self.figure._subplot_dict.values()) is_subset = bool(axes) and all_axes and set(axes) != all_axes if len(self) > 1: - if not is_subset and share_xlabels is None and xlabel is not None: - self.figure._clear_share_label_groups(target="x") - if not is_subset and share_ylabels is None and ylabel is not None: - self.figure._clear_share_label_groups(target="y") if share_xlabels is False: self.figure._clear_share_label_groups(self, target="x") if share_ylabels is False: self.figure._clear_share_label_groups(self, target="y") + if not is_subset and share_xlabels is None and xlabel is not None: + self.figure._clear_share_label_groups(self, target="x") + if not is_subset and share_ylabels is None and ylabel is not None: + self.figure._clear_share_label_groups(self, target="y") if is_subset and share_xlabels is None and xlabel is not None: self.figure._register_share_label_group(self, target="x") if is_subset and share_ylabels is None and ylabel is not None: diff --git a/ultraplot/tests/test_base.py b/ultraplot/tests/test_base.py index e6bff68ee..7fb52224e 100644 --- a/ultraplot/tests/test_base.py +++ b/ultraplot/tests/test_base.py @@ -1,7 +1,11 @@ -import ultraplot as uplt, pytest, numpy as np from unittest import mock + +import numpy as np +import pytest from packaging import version +import ultraplot as uplt + @pytest.mark.parametrize( "mpl_version", @@ -119,3 +123,27 @@ def test_unshare_setting_share_x_or_y(): assert ax[0]._sharex is None assert ax[1]._sharex is None uplt.close(fig) + + +def test_get_size_inches_rounding_and_reference_override(): + """ + _get_size_inches should snap to pixel grid and respect reference sizing. + """ + fig = uplt.figure(figsize=(4, 3), dpi=101) + ax = fig.add_subplot(1, 1, 1) + ax.set_position([0.0, 0.0, 1 / 3, 0.5]) + + size = ax._get_size_inches() + expected_width = round((4 * (1 / 3)) * 101) / 101 + expected_height = round((3 * 0.5) * 101) / 101 + assert np.isclose(size[0], expected_width) + assert np.isclose(size[1], expected_height) + + fig._refnum = ax.number + fig._refwidth = 9.5 + fig._refheight = 7.25 + size = ax._get_size_inches() + assert size[0] == 9.5 + assert size[1] == 7.25 + + uplt.close(fig) diff --git a/ultraplot/tests/test_colorbar.py b/ultraplot/tests/test_colorbar.py index b4e42eb40..81118762f 100644 --- a/ultraplot/tests/test_colorbar.py +++ b/ultraplot/tests/test_colorbar.py @@ -46,6 +46,56 @@ def test_explicit_legend_with_handles_under_external_mode(): assert "LegendLabel" in labels +@pytest.mark.parametrize( + "orientation, labelloc", + [ + ("horizontal", "top"), + ("vertical", "left"), + ], +) +def test_inset_colorbar_frame_wraps_label(rng, orientation, labelloc): + """ + Ensure inset colorbar frame expands to include label after resize. + """ + from ultraplot.axes.base import _get_axis_for, _reflow_inset_colorbar_frame + + fig, ax = uplt.subplots() + data = rng.random((10, 10)) + m = ax.imshow(data) + cb = ax.colorbar( + m, + loc="ur", + label="test", + frameon=True, + orientation=orientation, + labelloc=labelloc, + ) + fig.canvas.draw() + fig.set_size_inches(7, 4.5) + fig.canvas.draw() + + labelloc = cb.ax._inset_colorbar_labelloc + ticklen = cb.ax._inset_colorbar_ticklen + _reflow_inset_colorbar_frame(cb, labelloc=labelloc, ticklen=ticklen) + fig.canvas.draw() + + frame = cb.ax._inset_colorbar_frame + assert frame is not None + renderer = fig.canvas.get_renderer() + frame_bbox = frame.get_window_extent(renderer) + layout = cb.ax._inset_colorbar_layout + labelloc_layout = labelloc if isinstance(labelloc, str) else layout["ticklocation"] + label_axis = _get_axis_for( + labelloc_layout, layout["loc"], orientation=layout["orientation"], ax=cb + ) + label_bbox = label_axis.label.get_window_extent(renderer) + tol = 1.0 + assert frame_bbox.x0 <= label_bbox.x0 + tol + assert frame_bbox.x1 >= label_bbox.x1 - tol + assert frame_bbox.y0 <= label_bbox.y0 + tol + assert frame_bbox.y1 >= label_bbox.y1 - tol + + from itertools import product diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 1bcb69684..8e747057d 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -74,6 +74,17 @@ def test_external_disables_autolabels_no_label(): assert (not labels) or (labels[0] in ("_no_label", "")) +def test_parse_level_lim_accepts_list_input(): + """ + Ensure list inputs are converted before checking ndim in _parse_level_lim. + """ + fig, ax = uplt.subplots() + vmin, vmax, _ = ax[0]._parse_level_lim([[1, 2], [3, 4]]) + assert vmin == 1 + assert vmax == 4 + uplt.close(fig) + + def test_error_shading_explicit_label_external(): """ Explicit label on fill_between should be preserved in legend entries. diff --git a/ultraplot/tests/test_ultralayout.py b/ultraplot/tests/test_ultralayout.py new file mode 100644 index 000000000..2e1244daa --- /dev/null +++ b/ultraplot/tests/test_ultralayout.py @@ -0,0 +1,357 @@ +import numpy as np +import pytest + +import ultraplot as uplt +from ultraplot import ultralayout +from ultraplot.gridspec import GridSpec +from ultraplot.internals.warnings import UltraPlotWarning + + +def test_is_orthogonal_layout_simple_grid(): + """Test orthogonal layout detection for simple grids.""" + # Simple 2x2 grid should be orthogonal + array = np.array([[1, 2], [3, 4]]) + assert ultralayout.is_orthogonal_layout(array) is True + + +def test_is_orthogonal_layout_non_orthogonal(): + """Test orthogonal layout detection for non-orthogonal layouts.""" + # Centered subplot with empty cells should be non-orthogonal + array = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + assert ultralayout.is_orthogonal_layout(array) is False + + +def test_is_orthogonal_layout_spanning(): + """Test orthogonal layout with spanning subplots that is still orthogonal.""" + # L-shape that maintains grid alignment + array = np.array([[1, 1], [1, 2]]) + assert ultralayout.is_orthogonal_layout(array) is True + + +def test_is_orthogonal_layout_with_gaps(): + """Test non-orthogonal layout with gaps.""" + array = np.array([[1, 1, 1], [2, 0, 3]]) + assert ultralayout.is_orthogonal_layout(array) is False + + +def test_is_orthogonal_layout_empty(): + """Test empty layout.""" + array = np.array([[0, 0], [0, 0]]) + assert ultralayout.is_orthogonal_layout(array) is True + + +def test_gridspec_with_orthogonal_layout(): + """Test that GridSpec activates UltraLayout for orthogonal layouts.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 2], [3, 4]]) + gs = GridSpec(2, 2, layout_array=layout) + assert gs._layout_array is not None + # Should use UltraLayout for orthogonal layouts + assert gs._use_ultra_layout is True + + +def test_gridspec_with_non_orthogonal_layout(): + """Test that GridSpec activates UltraLayout for non-orthogonal layouts.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + gs = GridSpec(2, 4, layout_array=layout) + assert gs._layout_array is not None + # Should use UltraLayout for non-orthogonal layouts + assert gs._use_ultra_layout is True + + +def test_gridspec_without_kiwisolver(monkeypatch): + """Test graceful fallback when kiwisolver is not available.""" + # Mock the ULTRA_AVAILABLE flag + import ultraplot.gridspec as gs_module + + monkeypatch.setattr(gs_module, "ULTRA_AVAILABLE", False) + + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + gs = GridSpec(2, 4, layout_array=layout) + # Should not activate UltraLayout if kiwisolver not available + assert gs._use_ultra_layout is False + + +def test_gridspec_ultralayout_opt_out(): + """Test that UltraLayout can be disabled explicitly.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 2], [3, 4]]) + gs = GridSpec(2, 2, layout_array=layout, ultra_layout=False) + assert gs._use_ultra_layout is False + + +def test_gridspec_default_layout_array_with_ultralayout(): + """Test that UltraLayout initializes a default layout array.""" + pytest.importorskip("kiwisolver") + gs = GridSpec(2, 3) + assert gs._layout_array is not None + assert gs._layout_array.shape == (2, 3) + assert gs._use_ultra_layout is True + + +def test_ultralayout_layout_array_shape_mismatch_warns(): + """Test that mismatched layout arrays fall back to the original array.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 2, 3]]) + with pytest.warns(UltraPlotWarning): + gs = GridSpec(2, 2, layout_array=layout) + resolved = gs._get_ultra_layout_array() + assert resolved.shape == layout.shape + assert np.array_equal(resolved, layout) + + +def test_subplots_pass_layout_array_into_gridspec(): + """Test that subplots pass the layout array to GridSpec.""" + layout = [[1, 1, 2], [3, 4, 5]] + fig, axs = uplt.subplots(array=layout, figsize=(6, 4)) + assert np.array_equal(fig.gridspec._layout_array, np.array(layout)) + uplt.close(fig) + + +def test_ultralayout_solver_initialization(): + """Test UltraLayoutSolver can be initialized.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + solver = ultralayout.UltraLayoutSolver(layout, figwidth=10.0, figheight=6.0) + assert solver.array is not None + assert solver.nrows == 2 + assert solver.ncols == 4 + + +def test_compute_ultra_positions(): + """Test computing positions with UltraLayout.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + positions = ultralayout.compute_ultra_positions( + layout, + figwidth=10.0, + figheight=6.0, + wspace=[0.2, 0.2, 0.2], + hspace=[0.2], + ) + + # Should return positions for 3 subplots + assert len(positions) == 3 + assert 1 in positions + assert 2 in positions + assert 3 in positions + + # Each position should be (left, bottom, width, height) + for num, pos in positions.items(): + assert len(pos) == 4 + left, bottom, width, height = pos + assert 0 <= left <= 1 + assert 0 <= bottom <= 1 + assert width > 0 + assert height > 0 + assert left + width <= 1.01 # Allow small numerical error + assert bottom + height <= 1.01 + + +def test_subplots_with_non_orthogonal_layout(): + """Test creating subplots with non-orthogonal layout.""" + pytest.importorskip("kiwisolver") + layout = [[1, 1, 2, 2], [0, 3, 3, 0]] + fig, axs = uplt.subplots(array=layout, figsize=(10, 6)) + + # Should create 3 subplots + assert len(axs) == 3 + + # Check that positions are valid + for ax in axs: + pos = ax.get_position() + assert pos.width > 0 + assert pos.height > 0 + assert 0 <= pos.x0 <= 1 + assert 0 <= pos.y0 <= 1 + + +def test_ultralayout_panel_alignment_matches_parent(): + """Test panel axes stay aligned with parent axes under UltraLayout.""" + pytest.importorskip("kiwisolver") + layout = [[1, 1, 2, 2], [0, 3, 3, 0]] + fig, axs = uplt.subplots(array=layout, figsize=(8, 5)) + parent = axs[0] + panel = parent.panel_axes("right", width=0.4) + fig.auto_layout() + + parent_pos = parent.get_position() + panel_pos = panel.get_position() + assert np.isclose(panel_pos.y0, parent_pos.y0) + assert np.isclose(panel_pos.height, parent_pos.height) + assert panel_pos.x0 >= parent_pos.x1 + uplt.close(fig) + + +def test_subplots_with_orthogonal_layout(): + """Test creating subplots with orthogonal layout (should work as before).""" + layout = [[1, 2], [3, 4]] + fig, axs = uplt.subplots(array=layout, figsize=(8, 6)) + + # Should create 4 subplots + assert len(axs) == 4 + + # Check that positions are valid + for ax in axs: + pos = ax.get_position() + assert pos.width > 0 + assert pos.height > 0 + + +def test_ultralayout_respects_spacing(): + """Test that UltraLayout respects spacing parameters.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + + # Compute with different spacing + positions1 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, wspace=[0.1, 0.1, 0.1], hspace=[0.1] + ) + positions2 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, wspace=[0.5, 0.5, 0.5], hspace=[0.5] + ) + + # Subplots should be smaller with more spacing + for num in [1, 2, 3]: + _, _, width1, height1 = positions1[num] + _, _, width2, height2 = positions2[num] + # With more spacing, subplots should be smaller + assert width2 < width1 or height2 < height1 + + +def test_ultralayout_respects_ratios(): + """Test that UltraLayout respects width/height ratios.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 2], [3, 4]]) + + # Equal ratios + positions1 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, wratios=[1, 1], hratios=[1, 1] + ) + + # Unequal ratios + positions2 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, wratios=[1, 2], hratios=[1, 1] + ) + + # Subplot 2 should be wider than subplot 1 with unequal ratios + _, _, width1_1, _ = positions1[1] + _, _, width1_2, _ = positions1[2] + _, _, width2_1, _ = positions2[1] + _, _, width2_2, _ = positions2[2] + + # With equal ratios, widths should be similar + assert abs(width1_1 - width1_2) < 0.01 + # With 1:2 ratio, second should be roughly twice as wide + assert width2_2 > width2_1 + + +def test_ultralayout_with_panels_uses_total_geometry(): + """Test UltraLayout accounts for panel slots in total geometry.""" + pytest.importorskip("kiwisolver") + layout = [[1, 1, 2, 2], [0, 3, 3, 0]] + fig, axs = uplt.subplots(array=layout, figsize=(8, 6)) + + # Add a colorbar to introduce panel slots + mappable = axs[0].imshow([[0, 1], [2, 3]]) + fig.colorbar(mappable, loc="r") + + gs = fig.gridspec + gs._compute_ultra_positions() + assert gs._ultra_layout_array.shape == gs.get_total_geometry() + + row_idxs = gs._get_indices("h", panel=False) + col_idxs = gs._get_indices("w", panel=False) + for i, row_idx in enumerate(row_idxs): + for j, col_idx in enumerate(col_idxs): + assert gs._ultra_layout_array[row_idx, col_idx] == gs._layout_array[i, j] + + ss = axs[0].get_subplotspec() + assert gs._get_ultra_position(ss.num1, fig) is not None + + +def test_ultralayout_cached_positions(): + """Test that UltraLayout positions are cached in GridSpec.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + gs = GridSpec(2, 4, layout_array=layout) + + # Positions should not be computed yet + assert gs._ultra_positions is None + + # Create a figure to trigger position computation + fig = uplt.figure() + gs._figure = fig + + # Access a position (this should trigger computation) + ss = gs[0, 0] + pos = ss.get_position(fig) + + # Positions should now be cached + assert gs._ultra_positions is not None + assert len(gs._ultra_positions) == 3 + + +def test_ultralayout_with_margins(): + """Test that UltraLayout respects margin parameters.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 2]]) + + # Small margins + positions1 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, left=0.1, right=0.1, top=0.1, bottom=0.1 + ) + + # Large margins + positions2 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, left=1.0, right=1.0, top=1.0, bottom=1.0 + ) + + # With larger margins, subplots should be smaller + for num in [1, 2]: + _, _, width1, height1 = positions1[num] + _, _, width2, height2 = positions2[num] + assert width2 < width1 + assert height2 < height1 + + +def test_complex_non_orthogonal_layout(): + """Test a more complex non-orthogonal layout.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 1, 2], [3, 3, 0, 2], [4, 5, 5, 5]]) + + positions = ultralayout.compute_ultra_positions( + layout, figwidth=12.0, figheight=9.0 + ) + + # Should have 5 subplots + assert len(positions) == 5 + + # All positions should be valid + for num in range(1, 6): + assert num in positions + left, bottom, width, height = positions[num] + assert 0 <= left <= 1 + assert 0 <= bottom <= 1 + assert width > 0 + assert height > 0 + + +def test_ultralayout_module_exports(): + """Test that ultralayout module exports expected symbols.""" + assert hasattr(ultralayout, "UltraLayoutSolver") + assert hasattr(ultralayout, "compute_ultra_positions") + assert hasattr(ultralayout, "is_orthogonal_layout") + assert hasattr(ultralayout, "get_grid_positions_ultra") + + +def test_gridspec_copy_preserves_layout_array(): + """Test that copying a GridSpec preserves the layout array.""" + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + gs1 = GridSpec(2, 4, layout_array=layout) + gs2 = gs1.copy() + + assert gs2._layout_array is not None + assert np.array_equal(gs1._layout_array, gs2._layout_array) + assert gs1._use_ultra_layout == gs2._use_ultra_layout diff --git a/ultraplot/ultralayout.py b/ultraplot/ultralayout.py new file mode 100644 index 000000000..75aa9cd18 --- /dev/null +++ b/ultraplot/ultralayout.py @@ -0,0 +1,634 @@ +#!/usr/bin/env python3 +""" +UltraLayout: Advanced constraint-based layout system for non-orthogonal subplot arrangements. + +This module provides UltraPlot's constraint-based layout computation for subplot grids +that don't follow simple orthogonal patterns, such as [[1, 1, 2, 2], [0, 3, 3, 0]] +where subplot 3 should be nicely centered between subplots 1 and 2. +""" + +from typing import Dict, List, Optional, Tuple + +import numpy as np + +try: + from kiwisolver import Solver, Variable + + KIWI_AVAILABLE = True +except ImportError: + KIWI_AVAILABLE = False + Variable = None + Solver = None + + +__all__ = [ + "ColorbarLayoutSolver", + "UltraLayoutSolver", + "compute_ultra_positions", + "get_grid_positions_ultra", + "is_orthogonal_layout", +] + + +def is_orthogonal_layout(array: np.ndarray) -> bool: + """ + Check if a subplot array follows an orthogonal (grid-aligned) layout. + + An orthogonal layout is one where every subplot's edges align with + other subplots' edges, forming a simple grid. + + Parameters + ---------- + array : np.ndarray + 2D array of subplot numbers (with 0 for empty cells) + + Returns + ------- + bool + True if layout is orthogonal, False otherwise + """ + if array.size == 0: + return True + + # Get unique subplot numbers (excluding 0) + subplot_nums = np.unique(array[array != 0]) + + if len(subplot_nums) == 0: + return True + + # Reject layouts with interior gaps (zeros surrounded by non-zero rows/cols). + row_has = np.any(array != 0, axis=1) + col_has = np.any(array != 0, axis=0) + if np.any((array == 0) & row_has[:, None] & col_has[None, :]): + return False + + # For each subplot, get its bounding box + bboxes = {} + for num in subplot_nums: + rows, cols = np.where(array == num) + bboxes[num] = { + "row_min": rows.min(), + "row_max": rows.max(), + "col_min": cols.min(), + "col_max": cols.max(), + } + + # Check if layout is orthogonal by verifying that all vertical and + # horizontal edges align with cell boundaries + # A more sophisticated check: for each row/col boundary, check if + # all subplots either cross it or are completely on one side + + # Collect all unique row and column boundaries + row_boundaries = set() + col_boundaries = set() + + for bbox in bboxes.values(): + row_boundaries.add(bbox["row_min"]) + row_boundaries.add(bbox["row_max"] + 1) + col_boundaries.add(bbox["col_min"]) + col_boundaries.add(bbox["col_max"] + 1) + + # Check if these boundaries create a consistent grid + # For orthogonal layout, we should be able to split the grid + # using these boundaries such that each subplot is a union of cells + + row_boundaries = sorted(row_boundaries) + col_boundaries = sorted(col_boundaries) + + # Create a refined grid + refined_rows = len(row_boundaries) - 1 + refined_cols = len(col_boundaries) - 1 + + if refined_rows == 0 or refined_cols == 0: + return True + + # Map each subplot to refined grid cells + for num in subplot_nums: + rows, cols = np.where(array == num) + + # Check if this subplot occupies a rectangular region in the refined grid + refined_row_indices = set() + refined_col_indices = set() + + for r in rows: + for i, (r_start, r_end) in enumerate( + zip(row_boundaries[:-1], row_boundaries[1:]) + ): + if r_start <= r < r_end: + refined_row_indices.add(i) + + for c in cols: + for i, (c_start, c_end) in enumerate( + zip(col_boundaries[:-1], col_boundaries[1:]) + ): + if c_start <= c < c_end: + refined_col_indices.add(i) + + # Check if indices form a rectangle + if refined_row_indices and refined_col_indices: + r_min, r_max = min(refined_row_indices), max(refined_row_indices) + c_min, c_max = min(refined_col_indices), max(refined_col_indices) + + expected_cells = (r_max - r_min + 1) * (c_max - c_min + 1) + actual_cells = len(refined_row_indices) * len(refined_col_indices) + + if expected_cells != actual_cells: + return False + + return True + + +class UltraLayoutSolver: + """ + UltraLayout: Constraint-based layout solver using kiwisolver for subplot positioning. + + This solver computes aesthetically pleasing positions for subplots in + non-orthogonal arrangements by using constraint satisfaction, providing + a superior layout experience for complex subplot arrangements. + """ + + def __init__( + self, + array: np.ndarray, + figwidth: float = 10.0, + figheight: float = 8.0, + wspace: Optional[List[float]] = None, + hspace: Optional[List[float]] = None, + left: float = 0.125, + right: float = 0.125, + top: float = 0.125, + bottom: float = 0.125, + wratios: Optional[List[float]] = None, + hratios: Optional[List[float]] = None, + wpanels: Optional[List[bool]] = None, + hpanels: Optional[List[bool]] = None, + ): + """ + Initialize the UltraLayout solver. + + Parameters + ---------- + array : np.ndarray + 2D array of subplot numbers (with 0 for empty cells) + figwidth, figheight : float + Figure dimensions in inches + wspace, hspace : list of float, optional + Spacing between columns and rows in inches + left, right, top, bottom : float + Margins in inches + wratios, hratios : list of float, optional + Width and height ratios for columns and rows + wpanels, hpanels : list of bool, optional + Flags indicating panel columns or rows with fixed widths/heights. + """ + if not KIWI_AVAILABLE: + raise ImportError( + "kiwisolver is required for non-orthogonal layouts. " + "Install it with: pip install kiwisolver" + ) + + self.array = array + self.nrows, self.ncols = array.shape + self.figwidth = figwidth + self.figheight = figheight + self.left_margin = left + self.right_margin = right + self.top_margin = top + self.bottom_margin = bottom + + # Get subplot numbers + self.subplot_nums = sorted(np.unique(array[array != 0])) + + # Set up spacing + if wspace is None: + self.wspace = [0.2] * (self.ncols - 1) if self.ncols > 1 else [] + else: + self.wspace = list(wspace) + + if hspace is None: + self.hspace = [0.2] * (self.nrows - 1) if self.nrows > 1 else [] + else: + self.hspace = list(hspace) + + # Set up ratios + if wratios is None: + self.wratios = [1.0] * self.ncols + else: + self.wratios = list(wratios) + + if hratios is None: + self.hratios = [1.0] * self.nrows + else: + self.hratios = list(hratios) + + # Set up panel flags (True for fixed-width panel slots). + if wpanels is None: + self.wpanels = [False] * self.ncols + else: + if len(wpanels) != self.ncols: + raise ValueError("wpanels length must match number of columns.") + self.wpanels = [bool(val) for val in wpanels] + if hpanels is None: + self.hpanels = [False] * self.nrows + else: + if len(hpanels) != self.nrows: + raise ValueError("hpanels length must match number of rows.") + self.hpanels = [bool(val) for val in hpanels] + + # Initialize solver + self.solver = Solver() + self._setup_variables() + self._setup_constraints() + + def _setup_variables(self): + """Create kiwisolver variables for all grid lines.""" + # Vertical lines (left edges of columns + right edge of last column) + self.col_lefts = [Variable(f"col_{i}_left") for i in range(self.ncols)] + self.col_rights = [Variable(f"col_{i}_right") for i in range(self.ncols)] + + # Horizontal lines (top edges of rows + bottom edge of last row) + # Note: in figure coordinates, top is higher value + self.row_tops = [Variable(f"row_{i}_top") for i in range(self.nrows)] + self.row_bottoms = [Variable(f"row_{i}_bottom") for i in range(self.nrows)] + + def _setup_constraints(self): + """Set up all constraints for the layout.""" + # 1. Figure boundary constraints + self.solver.addConstraint(self.col_lefts[0] == self.left_margin / self.figwidth) + self.solver.addConstraint( + self.col_rights[-1] == 1.0 - self.right_margin / self.figwidth + ) + self.solver.addConstraint( + self.row_bottoms[-1] == self.bottom_margin / self.figheight + ) + self.solver.addConstraint( + self.row_tops[0] == 1.0 - self.top_margin / self.figheight + ) + + # 2. Column continuity and spacing constraints + for i in range(self.ncols - 1): + # Right edge of column i connects to left edge of column i+1 with spacing + spacing = self.wspace[i] / self.figwidth if i < len(self.wspace) else 0 + self.solver.addConstraint( + self.col_rights[i] + spacing == self.col_lefts[i + 1] + ) + + # 3. Row continuity and spacing constraints + for i in range(self.nrows - 1): + # Bottom edge of row i connects to top edge of row i+1 with spacing + spacing = self.hspace[i] / self.figheight if i < len(self.hspace) else 0 + self.solver.addConstraint( + self.row_bottoms[i] == self.row_tops[i + 1] + spacing + ) + + # 4. Width constraints (panel slots are fixed, remaining slots use ratios) + total_width = 1.0 - (self.left_margin + self.right_margin) / self.figwidth + if self.ncols > 1: + spacing_total = sum(self.wspace) / self.figwidth + else: + spacing_total = 0 + available_width = total_width - spacing_total + fixed_width = 0.0 + ratio_sum = 0.0 + for i in range(self.ncols): + if self.wpanels[i]: + fixed_width += self.wratios[i] / self.figwidth + else: + ratio_sum += self.wratios[i] + remaining_width = max(0.0, available_width - fixed_width) + if ratio_sum == 0: + ratio_sum = 1.0 + + for i in range(self.ncols): + if self.wpanels[i]: + width = self.wratios[i] / self.figwidth + else: + width = remaining_width * self.wratios[i] / ratio_sum + self.solver.addConstraint(self.col_rights[i] == self.col_lefts[i] + width) + + # 5. Height constraints (panel slots are fixed, remaining slots use ratios) + total_height = 1.0 - (self.top_margin + self.bottom_margin) / self.figheight + if self.nrows > 1: + spacing_total = sum(self.hspace) / self.figheight + else: + spacing_total = 0 + available_height = total_height - spacing_total + fixed_height = 0.0 + ratio_sum = 0.0 + for i in range(self.nrows): + if self.hpanels[i]: + fixed_height += self.hratios[i] / self.figheight + else: + ratio_sum += self.hratios[i] + remaining_height = max(0.0, available_height - fixed_height) + if ratio_sum == 0: + ratio_sum = 1.0 + + for i in range(self.nrows): + if self.hpanels[i]: + height = self.hratios[i] / self.figheight + else: + height = remaining_height * self.hratios[i] / ratio_sum + self.solver.addConstraint(self.row_tops[i] == self.row_bottoms[i] + height) + + def solve(self) -> Dict[int, Tuple[float, float, float, float]]: + """ + Solve the constraint system and return subplot positions. + + Returns + ------- + dict + Dictionary mapping subplot numbers to (left, bottom, width, height) + in figure-relative coordinates [0, 1] + """ + # Solve the constraint system + self.solver.updateVariables() + + # Extract positions for each subplot + positions = {} + col_lefts = [v.value() for v in self.col_lefts] + col_rights = [v.value() for v in self.col_rights] + row_tops = [v.value() for v in self.row_tops] + row_bottoms = [v.value() for v in self.row_bottoms] + col_widths = [right - left for left, right in zip(col_lefts, col_rights)] + row_heights = [top - bottom for top, bottom in zip(row_tops, row_bottoms)] + + base_wgap = None + for i in range(self.ncols - 1): + if not self.wpanels[i] and not self.wpanels[i + 1]: + gap = col_lefts[i + 1] - col_rights[i] + if base_wgap is None or gap < base_wgap: + base_wgap = gap + if base_wgap is None: + base_wgap = 0.0 + + base_hgap = None + for i in range(self.nrows - 1): + if not self.hpanels[i] and not self.hpanels[i + 1]: + gap = row_bottoms[i] - row_tops[i + 1] + if base_hgap is None or gap < base_hgap: + base_hgap = gap + if base_hgap is None: + base_hgap = 0.0 + + def _adjust_span( + spans: List[int], + start: float, + end: float, + sizes: List[float], + panels: List[bool], + base_gap: float, + ) -> Tuple[float, float]: + effective = [i for i in spans if not panels[i]] + if len(effective) <= 1: + return start, end + desired = sum(sizes[i] for i in effective) + # Collapse inter-column/row gaps inside spans to keep widths consistent. + # This avoids widening subplots that cross internal panel slots. + full = end - start + if desired < full: + offset = 0.5 * (full - desired) + start = start + offset + end = start + desired + return start, end + + for num in self.subplot_nums: + rows, cols = np.where(self.array == num) + row_min, row_max = rows.min(), rows.max() + col_min, col_max = cols.min(), cols.max() + + # Get the bounding box from the grid lines + left = col_lefts[col_min] + right = col_rights[col_max] + bottom = row_bottoms[row_max] + top = row_tops[row_min] + + span_cols = list(range(col_min, col_max + 1)) + span_rows = list(range(row_min, row_max + 1)) + + left, right = _adjust_span( + span_cols, + left, + right, + col_widths, + self.wpanels, + base_wgap, + ) + top, bottom = _adjust_span( + span_rows, + top, + bottom, + row_heights, + self.hpanels, + base_hgap, + ) + + width = right - left + height = top - bottom + + positions[num] = (left, bottom, width, height) + + return positions + + +class ColorbarLayoutSolver: + """ + Constraint-based solver for inset colorbar frame alignment. + """ + + def __init__( + self, + loc: str, + cb_width: float, + cb_height: float, + pad_left: float, + pad_right: float, + pad_bottom: float, + pad_top: float, + ): + if not KIWI_AVAILABLE: + raise ImportError( + "kiwisolver is required for constraint-based colorbar layout. " + "Install it with: pip install kiwisolver" + ) + self.loc = loc + self.cb_width = cb_width + self.cb_height = cb_height + self.pad_left = pad_left + self.pad_right = pad_right + self.pad_bottom = pad_bottom + self.pad_top = pad_top + self.frame_width = pad_left + cb_width + pad_right + self.frame_height = pad_bottom + cb_height + pad_top + + self.solver = Solver() + self.xframe = Variable("cb_frame_x") + self.yframe = Variable("cb_frame_y") + self.cb_x = Variable("cb_x") + self.cb_y = Variable("cb_y") + self._setup_constraints() + + def _setup_constraints(self): + self.solver.addConstraint(self.cb_x == self.xframe + self.pad_left) + self.solver.addConstraint(self.cb_y == self.yframe + self.pad_bottom) + self.solver.addConstraint(self.xframe >= 0) + self.solver.addConstraint(self.yframe >= 0) + self.solver.addConstraint(self.xframe + self.frame_width <= 1) + self.solver.addConstraint(self.yframe + self.frame_height <= 1) + + loc = self.loc or "lower right" + if loc not in ("upper right", "upper left", "lower left", "lower right"): + loc = "lower right" + if "left" in loc: + self.solver.addConstraint(self.xframe == 0) + elif "right" in loc: + self.solver.addConstraint(self.xframe + self.frame_width == 1) + if "upper" in loc: + self.solver.addConstraint(self.yframe + self.frame_height == 1) + elif "lower" in loc: + self.solver.addConstraint(self.yframe == 0) + + def solve(self) -> Dict[str, Tuple[float, float, float, float]]: + """ + Solve the constraint system and return inset and frame bounds. + """ + self.solver.updateVariables() + xframe = self.xframe.value() + yframe = self.yframe.value() + cb_x = self.cb_x.value() + cb_y = self.cb_y.value() + return { + "frame": (xframe, yframe, self.frame_width, self.frame_height), + "inset": (cb_x, cb_y, self.cb_width, self.cb_height), + } + + +def compute_ultra_positions( + array: np.ndarray, + figwidth: float = 10.0, + figheight: float = 8.0, + wspace: Optional[List[float]] = None, + hspace: Optional[List[float]] = None, + left: float = 0.125, + right: float = 0.125, + top: float = 0.125, + bottom: float = 0.125, + wratios: Optional[List[float]] = None, + hratios: Optional[List[float]] = None, + wpanels: Optional[List[bool]] = None, + hpanels: Optional[List[bool]] = None, +) -> Dict[int, Tuple[float, float, float, float]]: + """ + Compute subplot positions using UltraLayout for non-orthogonal layouts. + + Parameters + ---------- + array : np.ndarray + 2D array of subplot numbers (with 0 for empty cells) + figwidth, figheight : float + Figure dimensions in inches + wspace, hspace : list of float, optional + Spacing between columns and rows in inches + left, right, top, bottom : float + Margins in inches + wratios, hratios : list of float, optional + Width and height ratios for columns and rows + wpanels, hpanels : list of bool, optional + Flags indicating panel columns or rows with fixed widths/heights. + + Returns + ------- + dict + Dictionary mapping subplot numbers to (left, bottom, width, height) + in figure-relative coordinates [0, 1] + + Examples + -------- + >>> array = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + >>> positions = compute_ultra_positions(array) + >>> positions[3] # Position of subplot 3 + (0.25, 0.125, 0.5, 0.35) + """ + solver = UltraLayoutSolver( + array, + figwidth, + figheight, + wspace, + hspace, + left, + right, + top, + bottom, + wratios, + hratios, + wpanels, + hpanels, + ) + return solver.solve() + + +def get_grid_positions_ultra( + array: np.ndarray, + figwidth: float, + figheight: float, + wspace: Optional[List[float]] = None, + hspace: Optional[List[float]] = None, + left: float = 0.125, + right: float = 0.125, + top: float = 0.125, + bottom: float = 0.125, + wratios: Optional[List[float]] = None, + hratios: Optional[List[float]] = None, + wpanels: Optional[List[bool]] = None, + hpanels: Optional[List[bool]] = None, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Get grid line positions using UltraLayout. + + This returns arrays of grid line positions similar to GridSpec.get_grid_positions(), + but computed using UltraLayout's constraint satisfaction for better handling of non-orthogonal layouts. + + Parameters + ---------- + array : np.ndarray + 2D array of subplot numbers + figwidth, figheight : float + Figure dimensions in inches + wspace, hspace : list of float, optional + Spacing between columns and rows in inches + left, right, top, bottom : float + Margins in inches + wratios, hratios : list of float, optional + Width and height ratios for columns and rows + wpanels, hpanels : list of bool, optional + Flags indicating panel columns or rows with fixed widths/heights. + + Returns + ------- + bottoms, tops, lefts, rights : np.ndarray + Arrays of grid line positions for each cell + """ + solver = UltraLayoutSolver( + array, + figwidth, + figheight, + wspace, + hspace, + left, + right, + top, + bottom, + wratios, + hratios, + wpanels, + hpanels, + ) + solver.solver.updateVariables() + + # Extract grid line positions + lefts = np.array([v.value() for v in solver.col_lefts]) + rights = np.array([v.value() for v in solver.col_rights]) + tops = np.array([v.value() for v in solver.row_tops]) + bottoms = np.array([v.value() for v in solver.row_bottoms]) + + return bottoms, tops, lefts, rights