diff --git a/CHANGELOG.md b/CHANGELOG.md index 82dc15bb..f66d8f9e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - #386, #433: Added `Statevec.fidelity` and `Statevec.isclose` methods for pure-state fidelity computation and equality check up to global phase. +- #387, #444: Improved `Pattern.draw_graph` visualization: MBQC literature node shapes (squares for inputs, filled/empty circles for measured/output), solid gray edges, measurement order arrow, `show_measurements` and `show_legend` parameters. - #447: `Pattern.perform_pauli_pushing` which calls `StandardizedPattern.perform_pauli_pushing`. diff --git a/examples/visualization.py b/examples/visualization.py index 2c08fbeb..28668caa 100644 --- a/examples/visualization.py +++ b/examples/visualization.py @@ -33,13 +33,13 @@ pattern = circuit.transpile().pattern # note that this visualization is not always consistent with the correction set of pattern, # since we find the correction sets with flow-finding algorithms. -pattern.draw_graph(flow_from_pattern=False, show_measurement_planes=True) +pattern.draw_graph(flow_from_pattern=False, show_measurements=True) # %% # next, show the gflow: pattern.remove_input_nodes() pattern.perform_pauli_measurements() -pattern.draw_graph(flow_from_pattern=False, show_measurement_planes=True, node_distance=(1, 0.6)) +pattern.draw_graph(flow_from_pattern=False, show_measurements=True, node_distance=(1, 0.6)) # %% @@ -49,7 +49,7 @@ # # node_distance argument specifies the scale of the node arrangement in x and y directions. -pattern.draw_graph(flow_from_pattern=True, show_measurement_planes=True, node_distance=(0.7, 0.6)) +pattern.draw_graph(flow_from_pattern=True, show_measurements=True, node_distance=(0.7, 0.6)) # %% # Instead of the measurement planes, we can show the local Clifford of the resource graph. @@ -75,7 +75,7 @@ measurements = {node: Measurement.XY(0) for node in graph.nodes() if node not in outputs} og = OpenGraph(graph, inputs, outputs, measurements) vis = GraphVisualizer(og) -vis.visualize(show_measurement_planes=True) +vis.visualize(show_measurements=True) # %% @@ -91,6 +91,6 @@ } og = OpenGraph(graph, inputs, outputs, measurements) vis = GraphVisualizer(og) -vis.visualize(show_measurement_planes=True) +vis.visualize(show_measurements=True) # %% diff --git a/graphix/pattern.py b/graphix/pattern.py index 5e26933d..61f693ad 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -1450,7 +1450,9 @@ def draw_graph( flow_from_pattern: bool = True, show_pauli_measurement: bool = True, show_local_clifford: bool = False, - show_measurement_planes: bool = False, + show_measurements: bool = False, + show_legend: bool = False, + show_measurement_order: bool = True, show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, @@ -1458,16 +1460,23 @@ def draw_graph( ) -> None: """Visualize the underlying graph of the pattern with flow or gflow structure. + Nodes are drawn following MBQC literature conventions: inputs as squares, + measured nodes as filled circles, and outputs as empty circles. + Parameters ---------- flow_from_pattern : bool If True, the command sequence of the pattern is used to derive flow or gflow structure. If False, only the underlying graph is used. show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. + If True, Pauli-measured nodes are filled with blue instead of black. show_local_clifford : bool If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, measurement planes are displayed adjacent to the nodes. + show_measurements : bool + If True, measurement labels are displayed adjacent to the nodes. + show_legend : bool + If True, a legend is displayed indicating node types and edge meanings. + show_measurement_order : bool + If True, layer labels and a measurement order arrow are displayed below the graph. show_loop : bool whether or not to show loops for graphs with gflow. defaulted to True. node_distance : tuple @@ -1488,7 +1497,9 @@ def draw_graph( pattern=self.copy(), show_pauli_measurement=show_pauli_measurement, show_local_clifford=show_local_clifford, - show_measurement_planes=show_measurement_planes, + show_measurements=show_measurements, + show_legend=show_legend, + show_measurement_order=show_measurement_order, show_loop=show_loop, node_distance=node_distance, figsize=figsize, @@ -1498,7 +1509,9 @@ def draw_graph( vis.visualize( show_pauli_measurement=show_pauli_measurement, show_local_clifford=show_local_clifford, - show_measurement_planes=show_measurement_planes, + show_measurements=show_measurements, + show_legend=show_legend, + show_measurement_order=show_measurement_order, show_loop=show_loop, node_distance=node_distance, figsize=figsize, diff --git a/graphix/visualization.py b/graphix/visualization.py index 4322b82b..a48b3977 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -9,13 +9,16 @@ import networkx as nx import numpy as np from matplotlib import pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.transforms import offset_copy from graphix.flow.exceptions import FlowError -from graphix.measurements import Measurement, PauliMeasurement +from graphix.measurements import BlochMeasurement, Measurement, PauliMeasurement # OpenGraph is needed for dataclass from graphix.opengraph import OpenGraph # noqa: TC001 from graphix.optimization import StandardizedPattern +from graphix.pretty_print import OutputFormat, angle_to_str if TYPE_CHECKING: from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence @@ -56,14 +59,15 @@ def visualize( self, show_pauli_measurement: bool = True, show_local_clifford: bool = False, - show_measurement_planes: bool = False, + show_measurements: bool = False, + show_legend: bool = False, + show_measurement_order: bool = True, show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, filename: Path | None = None, ) -> None: - """ - Visualize the graph with flow or gflow structure. + """Visualize the graph with flow or gflow structure. If there exists a flow structure, then the graph is visualized with the flow structure. If flow structure is not found and there exists a gflow structure, then the graph is visualized @@ -73,11 +77,15 @@ def visualize( Parameters ---------- show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. + If True, Pauli-measured nodes are filled with blue instead of black. show_local_clifford : bool If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. + show_measurements : bool + If True, measurement labels are displayed adjacent to the nodes. + show_legend : bool + If True, a legend is displayed indicating node types and edge meanings. + show_measurement_order : bool + If True, layer labels and a measurement order arrow are displayed below the graph. show_loop : bool whether or not to show loops for graphs with gflow. defaulted to True. node_distance : tuple @@ -138,7 +146,9 @@ def place_paths( None, show_pauli_measurement, show_local_clifford, - show_measurement_planes, + show_measurements, + show_legend, + show_measurement_order, show_loop, node_distance, figsize, @@ -150,14 +160,15 @@ def visualize_from_pattern( pattern: Pattern, show_pauli_measurement: bool = True, show_local_clifford: bool = False, - show_measurement_planes: bool = False, + show_measurements: bool = False, + show_legend: bool = False, + show_measurement_order: bool = True, show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, filename: Path | None = None, ) -> None: - """ - Visualize the graph with flow or gflow structure found from the given pattern. + """Visualize the graph with flow or gflow structure found from the given pattern. If pattern sequence is consistent with flow structure, then the graph is visualized with the flow structure. If it is not consistent with flow structure and consistent with gflow structure, then the graph is visualized @@ -168,11 +179,15 @@ def visualize_from_pattern( pattern : Pattern pattern to be visualized show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. + If True, Pauli-measured nodes are filled with blue instead of black. show_local_clifford : bool If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. + show_measurements : bool + If True, measurement labels are displayed adjacent to the nodes. + show_legend : bool + If True, a legend is displayed indicating node types and edge meanings. + show_measurement_order : bool + If True, layer labels and a measurement order arrow are displayed below the graph. show_loop : bool whether or not to show loops for graphs with gflow. defaulted to True. node_distance : tuple @@ -235,7 +250,9 @@ def place_paths( corrections, show_pauli_measurement, show_local_clifford, - show_measurement_planes, + show_measurements, + show_legend, + show_measurement_order, show_loop, node_distance, figsize, @@ -252,35 +269,87 @@ def _shorten_path(path: Sequence[_Point]) -> list[_Point]: new_path[-1] = last_edge return new_path - def _draw_labels(self, pos: Mapping[int, _Point]) -> None: + def _draw_labels(self, pos: Mapping[int, _Point], font_color: Mapping[int, str] | str = "black") -> None: + """Draw node number labels centered inside their nodes. + + Parameters + ---------- + pos : Mapping[int, tuple[float, float]] + Dictionary of node positions. + font_color : Mapping[int, str] | str + Font color for node labels. Can be a single color string or a mapping from node to color. + """ fontsize = 12 if max(self.og.graph.nodes(), default=0) >= 100: fontsize = int(fontsize * 2 / len(str(max(self.og.graph.nodes())))) - nx.draw_networkx_labels(self.og.graph, pos, font_size=fontsize) - - def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: bool = False) -> None: - """ - Draw the nodes with different colors based on their role (input, output, or other). + ax = plt.gca() + # Shift text down by 4 points to compensate for matplotlib centering + # the full glyph bounding box (including descender space) rather than + # the visible digit area. + label_transform = offset_copy(ax.transData, fig=plt.gcf(), y=-4.0, units="points") + for node in self.og.graph.nodes(): + x, y = pos[node] + color = font_color.get(node, "black") if isinstance(font_color, dict) else font_color + plt.text( + x, + y, + str(node), + fontsize=fontsize, + color=color, + ha="center", + va="center", + zorder=3, + transform=label_transform, + ) + + def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: bool = False) -> dict[int, str]: + """Draw the nodes with shapes and fills following MBQC literature conventions. + + Input nodes are drawn as squares, measured (non-output) nodes as filled circles, + and output nodes as empty circles. Pauli-measured nodes are optionally distinguished + with a blue fill. Parameters ---------- pos : Mapping[int, tuple[float, float]] - dictionary of node positions. + Dictionary of node positions. show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. + If True, Pauli-measured nodes are filled with blue instead of black. + + Returns + ------- + dict[int, str] + Mapping from node index to font color for label rendering. """ + font_colors: dict[int, str] = {} + for node in self.og.graph.nodes(): - color = "black" # default color for 'other' nodes - inner_color = "white" - if node in self.og.input_nodes: - color = "red" + marker = "s" if node in self.og.input_nodes else "o" + if node in self.og.output_nodes: - inner_color = "lightgray" - elif show_pauli_measurement and isinstance(self.og.measurements[node], PauliMeasurement): - inner_color = "lightblue" + facecolor = "white" + elif ( + show_pauli_measurement + and node in self.og.measurements + and isinstance(self.og.measurements[node], PauliMeasurement) + ): + facecolor = "#4292c6" + else: + facecolor = "black" + + font_colors[node] = "white" if facecolor == "black" else "black" + plt.scatter( - *pos[node], edgecolor=color, facecolor=inner_color, s=350, zorder=2 - ) # Draw the nodes manually with scatter() + *pos[node], + marker=marker, + edgecolor="black", + facecolor=facecolor, + s=350, + zorder=2, + linewidths=1.5, + ) + + return font_colors def visualize_graph( self, @@ -292,19 +361,19 @@ def visualize_graph( corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None, show_pauli_measurement: bool = True, show_local_clifford: bool = False, - show_measurement_planes: bool = False, + show_measurements: bool = False, + show_legend: bool = False, + show_measurement_order: bool = True, show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: _Point | None = None, filename: Path | None = None, ) -> None: - """ - Visualizes the graph. + """Visualize the graph. - Nodes are colored based on their role (input, output, or other) and edges are depicted as arrows - or dashed lines depending on whether they are in the flow mapping. Vertical dashed lines separate - different layers of the graph. This function does not return anything but plots the graph - using matplotlib's pyplot. + Nodes are drawn following MBQC literature conventions: inputs as squares, + measured nodes as filled circles, and outputs as empty circles. Graph edges + are dashed lines and flow arrows indicate corrections. Parameters ---------- @@ -319,11 +388,15 @@ def visualize_graph( corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None X and Z corrections if any. show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. + If True, Pauli-measured nodes are filled with blue instead of black. show_local_clifford : bool If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. + show_measurements : bool + If True, measurement labels are displayed adjacent to the nodes. + show_legend : bool + If True, a legend is displayed indicating node types and edge meanings. + show_measurement_order : bool + If True, layer labels and a measurement order arrow are displayed below the graph. show_loop : bool whether or not to show loops for graphs with gflow. defaulted to True. node_distance : tuple @@ -341,7 +414,7 @@ def visualize_graph( edge_path, arrow_path = place_paths(pos) - if corrections is not None: + if show_legend or corrections is not None: # add some padding to the right for the legend figsize = (figsize[0] + 3.0, figsize[1]) @@ -349,10 +422,10 @@ def visualize_graph( for edge, path in edge_path.items(): if len(path) == 2: - nx.draw_networkx_edges(self.og.graph, pos, edgelist=[edge], style="dashed", alpha=0.7) + nx.draw_networkx_edges(self.og.graph, pos, edgelist=[edge], style="dashed", alpha=0.6) # type: ignore[no-untyped-call] else: curve = self._bezier_curve_linspace(path) - plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) + plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.6) if arrow_path is not None: for arrow, path in arrow_path.items(): @@ -378,8 +451,14 @@ def visualize_graph( ) elif len(path) == 2: # straight line nx.draw_networkx_edges( - self.og.graph, pos, edgelist=[arrow], edge_color=color, arrowstyle="->", arrows=True - ) + self.og.graph, + pos, + edgelist=[arrow], + edge_color=color, + arrowstyle="->", + arrows=True, + node_size=350, + ) # type: ignore[no-untyped-call] else: new_path = GraphVisualizer._shorten_path(path) curve = self._bezier_curve_linspace(new_path) @@ -391,19 +470,21 @@ def visualize_graph( arrowprops={"arrowstyle": "->", "color": color, "lw": 1}, ) - self.__draw_nodes_role(pos, show_pauli_measurement) + font_colors = self.__draw_nodes_role(pos, show_pauli_measurement) if show_local_clifford: self.__draw_local_clifford(pos) - if show_measurement_planes: - self.__draw_measurement_planes(pos) + if show_measurements: + self.__draw_measurement_labels(pos) - self._draw_labels(pos) + self._draw_labels(pos, font_colors) - if corrections is not None: - # legend for arrow colors - plt.plot([], [], "k--", alpha=0.7, label="graph edge") + if show_legend: + self.__draw_legend(show_pauli_measurement, corrections, arrow_path is not None) + elif corrections is not None: + # backward-compatible minimal legend for correction arrows + plt.plot([], [], "k--", alpha=0.6, label="graph edge") plt.plot([], [], color="tab:red", label="xflow") plt.plot([], [], color="tab:green", label="zflow") plt.plot([], [], color="tab:brown", label="xflow and zflow") @@ -414,26 +495,54 @@ def visualize_graph( y_min = min((pos[node][1] for node in self.og.graph.nodes()), default=0) # Get the minimum y coordinate y_max = max((pos[node][1] for node in self.og.graph.nodes()), default=0) # Get the maximum y coordinate - if l_k is not None and l_k: - # Draw the vertical lines to separate different layers - for layer in range(min(l_k.values()), max(l_k.values())): + has_layers = l_k is not None and len(l_k) > 0 + show_layers = show_measurement_order and has_layers + if show_layers and l_k is not None: + l_min_val = min(l_k.values()) + l_max_val = max(l_k.values()) + # Dotted vertical lines to separate layers (distinct from dashed graph edges) + for layer in range(l_min_val, l_max_val): plt.axvline( - x=(layer + 0.5) * node_distance[0], color="gray", linestyle="--", alpha=0.5 - ) # Draw line between layers - for layer in range(min(l_k.values()), max(l_k.values()) + 1): + x=(layer + 0.5) * node_distance[0], + color="lightgray", + linestyle=":", + alpha=0.7, + linewidth=0.8, + ) + # Draw layer numbers below nodes + for layer in range(l_min_val, l_max_val + 1): plt.text( - layer * node_distance[0], y_min - 0.5, f"L: {max(l_k.values()) - layer}", ha="center", va="top" - ) # Add layer label at bottom + layer * node_distance[0], + y_min - 0.4, + str(l_max_val - layer), + ha="center", + va="top", + fontsize=8, + color="gray", + ) + # Draw horizontal arrow indicating measurement order with "Layer" label below + if l_max_val > l_min_val: + arrow_y = y_min - 0.7 + plt.annotate( + "", + xy=(l_max_val * node_distance[0] + 0.3, arrow_y), + xytext=(l_min_val * node_distance[0] - 0.3, arrow_y), + arrowprops={"arrowstyle": "->", "color": "gray", "lw": 1.2}, + ) + mid_x = (l_min_val + l_max_val) / 2 * node_distance[0] + plt.text(mid_x, arrow_y - 0.15, "Layer", ha="center", va="top", fontsize=8, color="gray") - plt.xlim( - x_min - 0.5 * node_distance[0], x_max + 0.5 * node_distance[0] - ) # Add some padding to the left and right - plt.ylim(y_min - 1, y_max + 0.5) # Add some padding to the top and bottom + plt.gca().set_axis_off() + x_margin = 0.7 * node_distance[0] if show_measurements else 0.5 * node_distance[0] + plt.xlim(x_min - x_margin, x_max + x_margin) + top_margin = 0.7 if show_measurements else 0.5 + bottom_margin = 1.3 if show_layers else 0.5 + plt.ylim(y_min - bottom_margin, y_max + top_margin) if filename is None: plt.show() else: - plt.savefig(filename) + plt.savefig(filename, bbox_inches="tight") def __draw_local_clifford(self, pos: Mapping[int, _Point]) -> None: if self.local_clifford is not None: @@ -441,12 +550,174 @@ def __draw_local_clifford(self, pos: Mapping[int, _Point]) -> None: x, y = pos[node] + np.array([0.2, 0.2]) plt.text(x, y, f"{self.local_clifford[node]}", fontsize=10, zorder=3) - def __draw_measurement_planes(self, pos: Mapping[int, _Point]) -> None: + @staticmethod + def __draw_legend( + show_pauli_measurement: bool, + corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None, + has_arrows: bool, + ) -> None: + """Draw a legend indicating node types and edge meanings. + + Parameters + ---------- + show_pauli_measurement : bool + Whether Pauli-measured nodes are visually distinct. + corrections : tuple or None + X and Z corrections if any, to determine arrow legend entries. + has_arrows : bool + Whether flow arrows are present in the graph. + """ + elements: list[Line2D] = [ + Line2D( + [0], + [0], + marker="s", + color="w", + markerfacecolor="white", + markeredgecolor="black", + markersize=10, + label="Input", + ), + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor="black", + markeredgecolor="black", + markersize=10, + label="Measured", + ), + ] + if show_pauli_measurement: + elements.append( + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor="#4292c6", + markeredgecolor="black", + markersize=10, + label="Pauli-measured", + ) + ) + elements.extend( + [ + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor="white", + markeredgecolor="black", + markersize=10, + label="Output", + ), + Line2D([0], [0], color="black", linewidth=1, alpha=0.6, linestyle="dashed", label="Graph edge"), + ] + ) + + if corrections is not None: + elements.extend( + [ + Line2D([0], [0], color="tab:red", linewidth=1, label="X-correction"), + Line2D([0], [0], color="tab:green", linewidth=1, label="Z-correction"), + Line2D([0], [0], color="tab:brown", linewidth=1, label="X & Z-correction"), + ] + ) + elif has_arrows: + elements.append(Line2D([0], [0], color="black", linewidth=1, label="Flow")) + + plt.legend(handles=elements, loc="center left", fontsize=9, bbox_to_anchor=(1, 0.5)) + + def __draw_measurement_labels(self, pos: Mapping[int, _Point]) -> None: + """Draw measurement labels near measured nodes, adaptively placed to avoid overlap. + + For each measured node the method picks the direction (above, upper-right, + or upper-left) whose candidate position is farthest from every other node, + so labels do not sit on top of neighbouring nodes in dense layouts. + + Parameters + ---------- + pos : Mapping[int, tuple[float, float]] + Dictionary of node positions. + """ + # Candidate offsets: (dx, dy, horizontal-alignment, vertical-alignment) + candidates: list[tuple[float, float, str, str]] = [ + (0.15, -0.15, "left", "top"), # lower-right (preferred, like original) + (-0.15, -0.15, "right", "top"), # lower-left + (0, 0.22, "center", "bottom"), # above (fallback) + ] + all_positions = [pos[n] for n in self.og.graph.nodes()] + placed_labels: list[tuple[float, float]] = [] + + # Compute graph extent so labels don't get placed outside the plot boundary + all_x = [p[0] for p in all_positions] + x_lo = min(all_x) if all_x else 0.0 + x_hi = max(all_x) if all_x else 0.0 + for node, meas in self.og.measurements.items(): - x, y = pos[node] + np.array([0.22, -0.2]) - label = meas.to_plane_or_axis().name + label = self._format_measurement_label(meas) + if label is not None: + x, y = pos[node] + # Exclude candidates that push the label past the leftmost/rightmost + # node column — text would overflow the plot boundary there. + valid = [ + c for c in candidates if not (c[0] < 0 and x <= x_lo + 1e-9) and not (c[0] > 0 and x >= x_hi - 1e-9) + ] + if not valid: + valid = candidates # fallback: use all if none pass the filter + # Pick the direction farthest from other nodes AND already-placed labels + best_dx, best_dy, best_ha, best_va = valid[0] + best_min_dist = -1.0 + for dx, dy, ha, va in valid: + lx, ly = x + dx, y + dy + obstacles = [(ox, oy) for ox, oy in all_positions if (ox, oy) != (x, y)] + obstacles.extend(placed_labels) + other_dists = [((lx - ox) ** 2 + (ly - oy) ** 2) ** 0.5 for ox, oy in obstacles] + min_dist = min(other_dists) if other_dists else float("inf") + if min_dist > best_min_dist: + best_min_dist = min_dist + best_dx, best_dy, best_ha, best_va = dx, dy, ha, va + placed_labels.append((x + best_dx, y + best_dy)) + plt.text( + x + best_dx, + y + best_dy, + label, + fontsize=8, + ha=best_ha, + va=best_va, + zorder=3, + bbox={"boxstyle": "round,pad=0.1", "facecolor": "white", "edgecolor": "none", "alpha": 0.7}, + ) - plt.text(x, y, label, fontsize=9, zorder=3) + @staticmethod + def _format_measurement_label(meas: Measurement) -> str | None: + """Format a measurement label for display. + + Parameters + ---------- + meas : Measurement + The measurement to format. + + Returns + ------- + str | None + Formatted label string, or None if nothing to show. + """ + if isinstance(meas, PauliMeasurement): + return str(meas) + if isinstance(meas, BlochMeasurement): + if isinstance(meas.angle, (int, float)): + angle_str = angle_to_str(meas.angle, OutputFormat.Unicode) + # Fall back to compact notation for non-rational angles + if len(angle_str) > 30: + angle_str = f"{meas.angle:.2f}π" + return f"{meas.plane.name}({angle_str})" + angle_str = str(meas.angle) + return f"{meas.plane.name}({angle_str})" + return None def determine_figsize( self, @@ -474,9 +745,9 @@ def determine_figsize( if l_k is None: if pos is None: raise ValueError("Figure size can only be computed given a layer mapping (l_k) or node positions (pos)") - width = len({pos[node][0] for node in self.og.graph.nodes()}) * 0.8 + width = len({pos[node][0] for node in self.og.graph.nodes()}) * 1.0 else: - width = (max(l_k.values(), default=0) + 1) * 0.8 + width = (max(l_k.values(), default=0) + 1) * 1.0 height = len({pos[node][1] for node in self.og.graph.nodes()}) if pos is not None else len(self.og.output_nodes) return (width * node_distance[0], height * node_distance[1]) @@ -673,7 +944,7 @@ def place_pauli_flow(self, flow: PauliFlow[AbstractMeasurement]) -> dict[int, _P l_reverse = {node: l_max - layer_idx for layer_idx, layer in enumerate(layers) for node in layer} _set_node_attributes(g_prime, l_reverse, "subset") - pos = nx.multipartite_layout(g_prime) + pos = nx.multipartite_layout(g_prime) # type: ignore[no-untyped-call] vert = list({pos[node][1] for node in self.og.graph.nodes()}) vert.sort() @@ -775,7 +1046,7 @@ def place_without_structure(self) -> dict[int, _Point]: l_max = max(layers.values()) l_reverse = {v: l_max - l for v, l in layers.items()} _set_node_attributes(g_prime, l_reverse, "subset") - pos = nx.multipartite_layout(g_prime) + pos = nx.multipartite_layout(g_prime) # type: ignore[no-untyped-call] vert = list({pos[node][1] for node in self.og.graph.nodes()}) vert.sort() index = {y: i for i, y in enumerate(vert)} @@ -799,7 +1070,7 @@ def place_all_corrections(self, layers: Mapping[int, int]) -> dict[int, _Point]: g_prime.add_nodes_from(self.og.graph.nodes()) g_prime.add_edges_from(self.og.graph.edges()) _set_node_attributes(g_prime, layers, "subset") - layout = nx.multipartite_layout(g_prime) + layout = nx.multipartite_layout(g_prime) # type: ignore[no-untyped-call] vert = list({layout[node][1] for node in self.og.graph.nodes()}) vert.sort() index = {y: i for i, y in enumerate(vert)} @@ -903,4 +1174,4 @@ def _check_path(path: Iterable[_Point], target_node_pos: _Point | None = None) - def _set_node_attributes(graph: nx.Graph[_HashableT], attrs: Mapping[_HashableT, object], name: str) -> None: - nx.set_node_attributes(graph, attrs, name=name) # type: ignore[arg-type] + nx.set_node_attributes(graph, attrs, name=name) diff --git a/tests/baseline/test_draw_graph_reference_False.png b/tests/baseline/test_draw_graph_reference_False.png index aca8ad90..97969a24 100644 Binary files a/tests/baseline/test_draw_graph_reference_False.png and b/tests/baseline/test_draw_graph_reference_False.png differ diff --git a/tests/baseline/test_draw_graph_reference_True.png b/tests/baseline/test_draw_graph_reference_True.png index 91e28966..8d2c0c30 100644 Binary files a/tests/baseline/test_draw_graph_reference_True.png and b/tests/baseline/test_draw_graph_reference_True.png differ diff --git a/tests/baseline/test_draw_graph_with_labels.png b/tests/baseline/test_draw_graph_with_labels.png new file mode 100644 index 00000000..598c220b Binary files /dev/null and b/tests/baseline/test_draw_graph_with_labels.png differ diff --git a/tests/baseline/test_draw_graph_without_labels.png b/tests/baseline/test_draw_graph_without_labels.png new file mode 100644 index 00000000..7270834d Binary files /dev/null and b/tests/baseline/test_draw_graph_without_labels.png differ diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 030ada64..7e94935c 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -9,8 +9,8 @@ import pytest from graphix import Circuit, Pattern, command, visualization -from graphix.fundamentals import ANGLE_PI -from graphix.measurements import Measurement +from graphix.fundamentals import ANGLE_PI, Axis, Sign +from graphix.measurements import Measurement, PauliMeasurement from graphix.opengraph import OpenGraph, OpenGraphError from graphix.visualization import GraphVisualizer @@ -144,10 +144,10 @@ def test_draw_graph_show_local_clifford() -> None: @pytest.mark.usefixtures("mock_plot") -def test_draw_graph_show_measurement_planes(fx_rng: Generator) -> None: +def test_draw_graph_show_measurements_basic(fx_rng: Generator) -> None: pattern = example_pflow(fx_rng) pattern.draw_graph( - show_measurement_planes=True, + show_measurements=True, node_distance=(0.7, 0.6), ) @@ -247,6 +247,137 @@ def test_draw_graph_reference(flow_and_not_pauli_presimulate: bool) -> Figure: pattern.perform_pauli_measurements() pattern.standardize() pattern.draw_graph( - flow_from_pattern=flow_and_not_pauli_presimulate, node_distance=(0.7, 0.6), show_measurement_planes=True + flow_from_pattern=flow_and_not_pauli_presimulate, node_distance=(0.7, 0.6), show_measurements=True ) return plt.gcf() + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_measurements(fx_rng: Generator) -> None: + pattern = example_flow(fx_rng) + pattern.draw_graph( + show_measurements=True, + node_distance=(0.7, 0.6), + ) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_measurements_pflow(fx_rng: Generator) -> None: + pattern = example_pflow(fx_rng) + pattern.draw_graph( + show_measurements=True, + node_distance=(0.7, 0.6), + ) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_legend(fx_rng: Generator) -> None: + pattern = example_flow(fx_rng) + pattern.draw_graph( + show_legend=True, + node_distance=(0.7, 0.6), + ) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_legend_with_corrections(fx_rng: Generator) -> None: + pattern = example_flow(fx_rng) + pattern.draw_graph( + flow_from_pattern=True, + show_legend=True, + show_pauli_measurement=True, + node_distance=(0.7, 0.6), + ) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_hide_measurement_order(fx_rng: Generator) -> None: + pattern = example_flow(fx_rng) + pattern.draw_graph( + show_measurement_order=False, + node_distance=(0.7, 0.6), + ) + + +# Compare with baseline/test_draw_graph_with_labels.png +# Update baseline by running: pytest --mpl-generate-path=tests/baseline +@pytest.mark.usefixtures("mock_plot") +@pytest.mark.mpl_image_compare +def test_draw_graph_with_labels() -> Figure: + circuit = Circuit(3) + circuit.cnot(0, 1) + circuit.cnot(2, 1) + circuit.rx(0, ANGLE_PI / 3) + circuit.x(2) + circuit.cnot(2, 1) + pattern = circuit.transpile().pattern + pattern.standardize() + pattern.draw_graph( + flow_from_pattern=True, + show_measurements=True, + show_legend=True, + node_distance=(0.7, 0.6), + ) + return plt.gcf() + + +# Compare with baseline/test_draw_graph_without_labels.png +# Update baseline by running: pytest --mpl-generate-path=tests/baseline +@pytest.mark.usefixtures("mock_plot") +@pytest.mark.mpl_image_compare +def test_draw_graph_without_labels() -> Figure: + circuit = Circuit(3) + circuit.cnot(0, 1) + circuit.cnot(2, 1) + circuit.rx(0, ANGLE_PI / 3) + circuit.x(2) + circuit.cnot(2, 1) + pattern = circuit.transpile().pattern + pattern.standardize() + pattern.draw_graph( + flow_from_pattern=True, + show_measurements=False, + show_legend=False, + show_measurement_order=False, + node_distance=(0.7, 0.6), + ) + return plt.gcf() + + +def test_format_measurement_label_bloch() -> None: + bloch_xy = Measurement.XY(0.25) + label = GraphVisualizer._format_measurement_label(bloch_xy) + assert label is not None + assert "XY" in label + assert "π" in label # Unicode fraction format e.g. XY(π/4) + + +def test_format_measurement_label_bloch_zero() -> None: + bloch_zero = Measurement.XY(0) + label = GraphVisualizer._format_measurement_label(bloch_zero) + assert label is not None + assert "XY" in label + assert "0" in label + + +def test_format_measurement_label_bloch_xz() -> None: + bloch_xz = Measurement.XZ(0.5) + label = GraphVisualizer._format_measurement_label(bloch_xz) + assert label is not None + assert "XZ" in label + + +def test_format_measurement_label_pauli() -> None: + pauli_x = Measurement.X + label = GraphVisualizer._format_measurement_label(pauli_x) + assert label is not None + assert label == str(pauli_x) + assert "X" in label + + +def test_format_measurement_label_pauli_minus() -> None: + pauli_minus_z = PauliMeasurement(Axis.Z, Sign.MINUS) + label = GraphVisualizer._format_measurement_label(pauli_minus_z) + assert label is not None + assert label == str(pauli_minus_z) + assert "-Z" in label