diff --git a/CHANGELOG.md b/CHANGELOG.md index 3210ebb31..13c53ba87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,32 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- **Circuit Conversion**: Added circuit-derived pre-scheduling support in `circuit2graph()`. + - Added `CircuitScheduleStrategy` with `PARALLEL` and `MINIMIZE_SPACE`. + - Added `schedule_strategy` argument to `circuit2graph()`. + - `circuit2graph()` now returns `(graph, gflow, scheduler)` and pre-populates `Scheduler` via manual scheduling. + +### Changed + +- **Graph State**: Made `meas_bases` read-only by returning `MappingProxyType` to avoid external mutation. +- **Graph State**: Added caching for `physical_nodes` snapshots and proper cache invalidation on node add/remove. +- **Docs/Examples**: Updated circuit conversion usage in README and `examples/pattern_from_circuit.py` for the new `circuit2graph()` return signature. + +### Fixed + +- **Feedforward**: Fixed self-loop removal in `dag_from_flow()` by correcting operator precedence so self-loops are removed from combined `xflow`/`zflow` dependencies. +- **Pauli Frame**: Initialize `_pauli_axis_cache` only when FTQC parity-check groups are provided, avoiding unnecessary cache creation in non-FTQC usage. + +### Tests + +- **Circuit Conversion**: Expanded scheduling tests in `tests/test_circuit.py`, including scheduler return contract, J/CZ/phase-gadget timing behavior, schedule validation, and `MINIMIZE_SPACE` behavior. +- **Integration**: Added circuit-level integration tests for `signal_shifting()` and `pauli_simplification()` with circuit-vs-pattern statevector equivalence checks. +- **Stim Compiler / Pauli Frame**: Updated tests to explicitly pass parity-check groups where logical-observable and cache initialization paths are exercised. + ## [0.2.1] - 2026-01-16 ### Added diff --git a/README.md b/README.md index e843d8fda..a501454d8 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ circuit = Circuit(2) circuit.apply_macro_gate(H(0)) circuit.apply_macro_gate(CNOT((0, 1))) -graph, feedforward = circuit2graph(circuit) +graph, feedforward, scheduler = circuit2graph(circuit) # Compile into pattern pattern = qompile(graph, feedforward) diff --git a/examples/pattern_from_circuit.py b/examples/pattern_from_circuit.py index 1d50154f2..ff1787e96 100644 --- a/examples/pattern_from_circuit.py +++ b/examples/pattern_from_circuit.py @@ -29,8 +29,8 @@ circuit.cz(1, 2) # %% -# convert circuit to graph and flow -graphstate, gflow = circuit2graph(circuit) +# convert circuit to graph, flow, and scheduler +graphstate, gflow, scheduler = circuit2graph(circuit) # first, qompile it to standardized pattern pattern = qompile(graphstate, gflow) diff --git a/graphqomb/circuit.py b/graphqomb/circuit.py index 520436a32..c612db635 100644 --- a/graphqomb/circuit.py +++ b/graphqomb/circuit.py @@ -5,14 +5,17 @@ - `BaseCircuit`: An abstract base class for quantum circuits. - `MBQCCircuit`: A circuit class composed solely of a unit gate set. - `Circuit`: A class for circuits that include macro instructions. -- `circuit2graph`: A function that converts a circuit to a graph state and gflow. +- `CircuitScheduleStrategy`: Scheduling strategies for circuit conversion. +- `circuit2graph`: A function that converts a circuit to a graph state, gflow, and scheduler. """ from __future__ import annotations import copy +import enum import itertools from abc import ABC, abstractmethod +from enum import Enum from typing import TYPE_CHECKING import typing_extensions @@ -20,11 +23,119 @@ from graphqomb.common import Plane, PlannerMeasBasis from graphqomb.gates import CZ, Gate, J, PhaseGadget, UnitGate from graphqomb.graphstate import GraphState +from graphqomb.scheduler import Scheduler if TYPE_CHECKING: from collections.abc import Sequence +class CircuitScheduleStrategy(Enum): + """Enumeration for manual scheduling strategies derived from circuit structure.""" + + PARALLEL = enum.auto() + MINIMIZE_SPACE = enum.auto() + + +class _Circuit2GraphContext: + """Internal helper for converting circuits with a given scheduling strategy.""" + + graph: GraphState + gflow: dict[int, set[int]] + qindex2front_nodes: dict[int, int] + qindex2timestep: dict[int, int] + prepare_time: dict[int, int] + measure_time: dict[int, int] + minimize_qubits: bool + current_time: int + + def __init__(self, graph: GraphState, strategy: CircuitScheduleStrategy) -> None: + if strategy == CircuitScheduleStrategy.PARALLEL: + self.minimize_qubits = False + elif strategy == CircuitScheduleStrategy.MINIMIZE_SPACE: + self.minimize_qubits = True + else: + msg = f"Invalid schedule strategy: {strategy}" + raise ValueError(msg) + + self.graph = graph + self.gflow = {} + self.qindex2front_nodes = {} + self.qindex2timestep = {} + self.prepare_time = {} + self.measure_time = {} + self.current_time = 0 + + def apply_instruction(self, instruction: UnitGate) -> None: + """Apply a unit gate to the graph conversion context. + + Raises + ------ + TypeError + If the instruction type is not supported. + """ + if isinstance(instruction, J): + self._apply_j(instruction) + return + if isinstance(instruction, CZ): + self._apply_cz(instruction) + return + if isinstance(instruction, PhaseGadget): + self._apply_phase_gadget(instruction) + return + msg = f"Invalid instruction: {instruction}" + raise TypeError(msg) + + def _apply_j(self, instruction: J) -> None: + new_node = self.graph.add_physical_node() + self.graph.add_physical_edge(self.qindex2front_nodes[instruction.qubit], new_node) + self.graph.assign_meas_basis( + self.qindex2front_nodes[instruction.qubit], + PlannerMeasBasis(Plane.XY, -instruction.angle), + ) + + timestep = self.qindex2timestep[instruction.qubit] + if self.minimize_qubits: + timestep = max(self.current_time, timestep) + self.prepare_time[new_node] = timestep + self.measure_time[self.qindex2front_nodes[instruction.qubit]] = timestep + 1 + self.qindex2timestep[instruction.qubit] = timestep + 1 + if self.minimize_qubits: + self.current_time = timestep + 1 + + self.gflow[self.qindex2front_nodes[instruction.qubit]] = {new_node} + self.qindex2front_nodes[instruction.qubit] = new_node + + def _apply_cz(self, instruction: CZ) -> None: + self.graph.add_physical_edge( + self.qindex2front_nodes[instruction.qubits[0]], + self.qindex2front_nodes[instruction.qubits[1]], + ) + + aligned_time = max(self.qindex2timestep[instruction.qubits[0]], self.qindex2timestep[instruction.qubits[1]]) + if self.minimize_qubits: + aligned_time = max(self.current_time, aligned_time) + self.current_time = aligned_time + self.qindex2timestep[instruction.qubits[0]] = aligned_time + self.qindex2timestep[instruction.qubits[1]] = aligned_time + + def _apply_phase_gadget(self, instruction: PhaseGadget) -> None: + new_node = self.graph.add_physical_node() + self.graph.assign_meas_basis(new_node, PlannerMeasBasis(Plane.YZ, instruction.angle)) + for qubit in instruction.qubits: + self.graph.add_physical_edge(self.qindex2front_nodes[qubit], new_node) + + self.gflow[new_node] = {new_node} + + max_timestep = max(self.qindex2timestep[qubit] for qubit in instruction.qubits) + if self.minimize_qubits: + max_timestep = max(self.current_time, max_timestep) + self.current_time = max_timestep + 1 + self.prepare_time[new_node] = max_timestep + self.measure_time[new_node] = max_timestep + 1 + for qubit in instruction.qubits: + self.qindex2timestep[qubit] = max_timestep + 1 + + class BaseCircuit(ABC): """ Abstract base class for quantum circuits. @@ -208,64 +319,49 @@ def apply_macro_gate(self, gate: Gate) -> None: self.__macro_gate_instructions.append(gate) -def circuit2graph(circuit: BaseCircuit) -> tuple[GraphState, dict[int, set[int]]]: - r"""Convert a circuit to a graph state and gflow. +def circuit2graph( + circuit: BaseCircuit, + schedule_strategy: CircuitScheduleStrategy = CircuitScheduleStrategy.PARALLEL, +) -> tuple[GraphState, dict[int, set[int]], Scheduler]: + r"""Convert a circuit to a graph state, gflow, and scheduler. Parameters ---------- circuit : `BaseCircuit` The quantum circuit to convert. + schedule_strategy : `CircuitScheduleStrategy`, optional + Strategy for scheduling preparation and measurement times derived from the circuit, + by default `CircuitScheduleStrategy.PARALLEL`. + The strategies are: + + - `CircuitScheduleStrategy.PARALLEL`: schedule each qubit independently to reduce depth + - `CircuitScheduleStrategy.MINIMIZE_SPACE`: serialize operations to reduce prepared qubits Returns ------- - `tuple`\[`GraphState`, `dict`\[`int`, `set`\[`int`\]\]\] - The graph state and gflow converted from the circuit. + `tuple`\[`GraphState`, `dict`\[`int`, `set`\[`int`\]\], `Scheduler`\] + The graph state, gflow, and scheduler converted from the circuit. + The scheduler is configured with automatic time scheduling derived from circuit structure. - Raises - ------ - TypeError - If the circuit contains an invalid instruction. """ graph = GraphState() - gflow: dict[int, set[int]] = {} - - qindex2front_nodes: dict[int, int] = {} + context = _Circuit2GraphContext(graph, schedule_strategy) # input nodes for i in range(circuit.num_qubits): node = graph.add_physical_node() graph.register_input(node, i) - qindex2front_nodes[i] = node + context.qindex2front_nodes[i] = node + context.qindex2timestep[i] = 0 for instruction in circuit.unit_instructions(): - if isinstance(instruction, J): - new_node = graph.add_physical_node() - graph.add_physical_edge(qindex2front_nodes[instruction.qubit], new_node) - graph.assign_meas_basis( - qindex2front_nodes[instruction.qubit], - PlannerMeasBasis(Plane.XY, -instruction.angle), - ) - - gflow[qindex2front_nodes[instruction.qubit]] = {new_node} - qindex2front_nodes[instruction.qubit] = new_node - - elif isinstance(instruction, CZ): - graph.add_physical_edge( - qindex2front_nodes[instruction.qubits[0]], - qindex2front_nodes[instruction.qubits[1]], - ) - elif isinstance(instruction, PhaseGadget): - new_node = graph.add_physical_node() - graph.assign_meas_basis(new_node, PlannerMeasBasis(Plane.YZ, instruction.angle)) - for qubit in instruction.qubits: - graph.add_physical_edge(qindex2front_nodes[qubit], new_node) - - gflow[new_node] = {new_node} - else: - msg = f"Invalid instruction: {instruction}" - raise TypeError(msg) + context.apply_instruction(instruction) - for qindex, node in qindex2front_nodes.items(): + for qindex, node in context.qindex2front_nodes.items(): graph.register_output(node, qindex) - return graph, gflow + # manually schedule + scheduler = Scheduler(graph, context.gflow) + scheduler.manual_schedule(context.prepare_time, context.measure_time) + + return graph, context.gflow, scheduler diff --git a/graphqomb/feedforward.py b/graphqomb/feedforward.py index e1f558ef4..a32357256 100644 --- a/graphqomb/feedforward.py +++ b/graphqomb/feedforward.py @@ -101,7 +101,7 @@ def dag_from_flow( msg = "Invalid zflow object" raise TypeError(msg) for node in non_output_nodes: - target_nodes = xflow.get(node, set()) | zflow.get(node, set()) - {node} # remove self-loops + target_nodes = (xflow.get(node, set()) | zflow.get(node, set())) - {node} # remove self-loops dag[node] = target_nodes for output in output_nodes: dag[output] = set() diff --git a/graphqomb/graphstate.py b/graphqomb/graphstate.py index 9405f56f7..b7314ef21 100644 --- a/graphqomb/graphstate.py +++ b/graphqomb/graphstate.py @@ -21,6 +21,7 @@ from abc import ABC from collections.abc import Hashable, Iterable, Mapping, Sequence from collections.abc import Set as AbstractSet +from types import MappingProxyType from typing import TYPE_CHECKING, NamedTuple, TypeVar import typing_extensions @@ -83,12 +84,12 @@ def physical_edges(self) -> set[tuple[int, int]]: @property @abc.abstractmethod - def meas_bases(self) -> dict[int, MeasBasis]: + def meas_bases(self) -> MappingProxyType[int, MeasBasis]: r"""Return measurement bases. Returns ------- - `dict`\[`int`, `MeasBasis`\] + `types.MappingProxyType`\[`int`, `MeasBasis`\] measurement bases of each physical node. """ @@ -199,6 +200,8 @@ class GraphState(BaseGraphState): __node_counter: int + _cached_physical_nodes: frozenset[int] | None = None + def __init__(self) -> None: self.__input_node_indices = {} self.__output_node_indices = {} @@ -244,7 +247,9 @@ def physical_nodes(self) -> set[int]: `set`\[`int`\] set of physical nodes. """ - return self.__physical_nodes.copy() + if self._cached_physical_nodes is None: + self._cached_physical_nodes = frozenset(self.__physical_nodes) + return set(self._cached_physical_nodes) @property @typing_extensions.override @@ -265,15 +270,15 @@ def physical_edges(self) -> set[tuple[int, int]]: @property @typing_extensions.override - def meas_bases(self) -> dict[int, MeasBasis]: + def meas_bases(self) -> MappingProxyType[int, MeasBasis]: r"""Return measurement bases. Returns ------- - `dict`\[`int`, `MeasBasis`\] + `types.MappingProxyType`\[`int`, `MeasBasis`\] measurement bases of each physical node. """ - return self.__meas_bases.copy() + return MappingProxyType(self.__meas_bases) @property def local_cliffords(self) -> dict[int, LocalClifford]: @@ -356,6 +361,7 @@ def add_physical_node(self, coordinate: tuple[float, ...] | None = None) -> int: if coordinate is not None: self.__coordinates[node] = coordinate self.__node_counter += 1 + self._cached_physical_nodes = None return node @@ -416,6 +422,8 @@ def remove_physical_node(self, node: int) -> None: self.__local_cliffords.pop(node, None) self.__coordinates.pop(node, None) + self._cached_physical_nodes = None + def remove_physical_edge(self, node1: int, node2: int) -> None: """Remove a physical edge from the graph state. @@ -561,7 +569,7 @@ def check_canonical_form(self) -> None: if self.__local_cliffords: msg = "Clifford operators are applied." raise ValueError(msg) - for node in self.physical_nodes - set(self.output_node_indices): + for node in self.physical_nodes - self.output_node_indices.keys(): if self.meas_bases.get(node) is None: msg = "All non-output nodes must have measurement basis." raise ValueError(msg) diff --git a/graphqomb/pauli_frame.py b/graphqomb/pauli_frame.py index db0b4bc1c..84057fb08 100644 --- a/graphqomb/pauli_frame.py +++ b/graphqomb/pauli_frame.py @@ -83,9 +83,11 @@ def __init__( # Pre-compute Pauli axes for performance optimization # Only cache nodes that have measurement bases # NOTE: if non-Pauli measurements are involved, the stim_compile func will error out earlier - self._pauli_axis_cache = { - node: determine_pauli_axis(meas_basis) for node, meas_basis in graphstate.meas_bases.items() - } + self._pauli_axis_cache = ( + {node: determine_pauli_axis(meas_basis) for node, meas_basis in graphstate.meas_bases.items()} + if parity_check_group + else {} + ) # only necessary for FTQC # Cache for memoization of dependent chains self._chain_cache = {} diff --git a/tests/test_circuit.py b/tests/test_circuit.py index b52521531..11f303111 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -7,8 +7,9 @@ import numpy as np import pytest -from graphqomb.circuit import BaseCircuit, Circuit, MBQCCircuit, circuit2graph +from graphqomb.circuit import BaseCircuit, Circuit, CircuitScheduleStrategy, MBQCCircuit, circuit2graph from graphqomb.common import Plane, PlannerMeasBasis +from graphqomb.feedforward import pauli_simplification, signal_shifting from graphqomb.gates import ( CNOT, CZ, @@ -23,6 +24,10 @@ Y, Z, ) +from graphqomb.qompiler import qompile +from graphqomb.schedule_solver import ScheduleConfig, Strategy +from graphqomb.scheduler import Scheduler +from graphqomb.simulator import CircuitSimulator, PatternSimulator, SimulatorBackend # MBQCCircuit tests @@ -245,7 +250,7 @@ def test_circuit2graph_simple_circuit() -> None: circuit.cz(qubit1=0, qubit2=1) circuit.j(qubit=1, angle=-0.3) - graph, gflow = circuit2graph(circuit) + graph, gflow, scheduler = circuit2graph(circuit) # Check graph properties assert len(graph.input_node_indices) == 2 @@ -256,13 +261,16 @@ def test_circuit2graph_simple_circuit() -> None: # Check gflow assert len(gflow) == 2 # Two J gates should have gflow + # Check scheduler + assert isinstance(scheduler, Scheduler) + def test_circuit2graph_phase_gadget_circuit() -> None: """Test conversion with phase gadget.""" circuit = MBQCCircuit(num_qubits=3) circuit.phase_gadget(qubits=[0, 1, 2], angle=0.25) - graph, _ = circuit2graph(circuit) + graph, _, _ = circuit2graph(circuit) # Check graph properties assert len(graph.input_node_indices) == 3 @@ -283,7 +291,7 @@ def test_circuit2graph_phase_gadget_circuit() -> None: def test_circuit2graph_empty_circuit() -> None: """Test conversion of empty circuit.""" circuit = MBQCCircuit(num_qubits=2) - graph, gflow = circuit2graph(circuit) + graph, gflow, scheduler = circuit2graph(circuit) # Check graph properties assert len(graph.input_node_indices) == 2 @@ -291,6 +299,9 @@ def test_circuit2graph_empty_circuit() -> None: assert len(graph.physical_nodes) == 2 # Only input/output nodes assert len(gflow) == 0 # No gflow for empty circuit + # Check scheduler + assert isinstance(scheduler, Scheduler) + def test_circuit2graph_invalid_instruction() -> None: """Test that invalid instruction raises TypeError.""" @@ -326,7 +337,7 @@ def test_circuit2graph_complex_circuit() -> None: circuit.j(qubit=2, angle=-np.pi / 4) circuit.j(qubit=3, angle=np.pi) - graph, gflow = circuit2graph(circuit) + graph, gflow, scheduler = circuit2graph(circuit) # Check basic properties assert len(graph.input_node_indices) == 4 @@ -338,6 +349,9 @@ def test_circuit2graph_complex_circuit() -> None: # Check gflow: 4 J gates + 1 phase gadget = 5 entries assert len(gflow) == 5 + # Check scheduler + assert isinstance(scheduler, Scheduler) + def test_circuit2graph_measurement_basis_assignment() -> None: """Test that measurement bases are correctly assigned.""" @@ -345,7 +359,7 @@ def test_circuit2graph_measurement_basis_assignment() -> None: circuit.j(qubit=0, angle=0.7) circuit.j(qubit=1, angle=-1.2) - graph, _ = circuit2graph(circuit) + graph, _, _ = circuit2graph(circuit) # Find non-output nodes with measurement basis (J gates are applied to input nodes) measured_nodes = [ @@ -368,10 +382,259 @@ def test_circuit2graph_circuit_with_macro_gates() -> None: circuit.apply_macro_gate(H(qubit=0)) circuit.apply_macro_gate(CNOT(qubits=(0, 1))) - graph, _ = circuit2graph(circuit) + graph, _, _ = circuit2graph(circuit) # Check that macro gates are properly expanded assert len(graph.input_node_indices) == 2 assert len(graph.output_node_indices) == 2 # H expands to 1 J, CNOT expands to 2 J + 1 CZ = 3 nodes total assert len(graph.physical_nodes) == 5 # 2 inputs + 3 new nodes + + +# circuit2graph scheduling tests + + +def test_circuit2graph_returns_scheduler() -> None: + """Test that circuit2graph returns a valid Scheduler object.""" + circuit = MBQCCircuit(num_qubits=2) + circuit.j(qubit=0, angle=0.5) + circuit.cz(qubit1=0, qubit2=1) + + graph, _gflow, scheduler = circuit2graph(circuit) + + assert isinstance(scheduler, Scheduler) + assert scheduler.graph is graph + + +def test_circuit2graph_j_gate_timing() -> None: + """Test that J gates are scheduled sequentially on the same qubit.""" + circuit = MBQCCircuit(num_qubits=1) + circuit.j(qubit=0, angle=0.5) + circuit.j(qubit=0, angle=0.3) + circuit.j(qubit=0, angle=0.1) + + _graph, _gflow, scheduler = circuit2graph(circuit) + + # Check that measurement times are unique and ordered + measure_times = [t for t in scheduler.measure_time.values() if t is not None] + assert measure_times == sorted(measure_times) + assert len(set(measure_times)) == len(measure_times) # All unique + + +def test_circuit2graph_minimize_qubits_strategy_serializes() -> None: + """Test that MINIMIZE_SPACE strategy serializes independent J gates.""" + circuit = MBQCCircuit(num_qubits=2) + circuit.j(qubit=0, angle=0.1) + circuit.j(qubit=1, angle=0.2) + + graph_parallel, _gflow_parallel, scheduler_parallel = circuit2graph(circuit) + graph_min, _gflow_min, scheduler_min = circuit2graph( + circuit, + schedule_strategy=CircuitScheduleStrategy.MINIMIZE_SPACE, + ) + + parallel_input_nodes = list(graph_parallel.input_node_indices.keys()) + parallel_meas_times = [scheduler_parallel.measure_time[node] for node in parallel_input_nodes] + assert all(time is not None for time in parallel_meas_times) + parallel_meas_times_int = [time for time in parallel_meas_times if time is not None] + assert sorted(parallel_meas_times_int) == [1, 1] + + min_input_nodes = list(graph_min.input_node_indices.keys()) + min_meas_times = [scheduler_min.measure_time[node] for node in min_input_nodes] + assert all(time is not None for time in min_meas_times) + min_meas_times_int = [time for time in min_meas_times if time is not None] + assert sorted(min_meas_times_int) == [1, 2] + + scheduler_min.validate_schedule() + + +def test_circuit2graph_cz_timestep_alignment() -> None: + """Test that CZ gates align timesteps of interacting qubits.""" + circuit = MBQCCircuit(num_qubits=2) + circuit.j(qubit=0, angle=0.5) # qubit 0 at timestep 1 + circuit.j(qubit=0, angle=0.3) # qubit 0 at timestep 2 + circuit.cz(qubit1=0, qubit2=1) # Should align qubit 1 to timestep 2 + circuit.j(qubit=1, angle=0.1) # Now qubit 1 at timestep 3 + + _graph, _gflow, scheduler = circuit2graph(circuit) + + # Validate schedule respects DAG constraints + scheduler.validate_schedule() + + +def test_circuit2graph_phase_gadget_timing() -> None: + """Test that phase gadget has valid timing.""" + circuit = MBQCCircuit(num_qubits=3) + circuit.j(qubit=0, angle=0.5) # qubit 0 at timestep 1 + circuit.j(qubit=0, angle=0.3) # qubit 0 at timestep 2 + circuit.phase_gadget(qubits=[0, 1, 2], angle=0.25) + + graph, _gflow, scheduler = circuit2graph(circuit) + + # Check that phase gadget node has valid timing + pg_nodes = [n for n in graph.physical_nodes if graph.meas_bases.get(n) and graph.meas_bases[n].plane == Plane.YZ] + assert len(pg_nodes) == 1 + assert scheduler.prepare_time.get(pg_nodes[0]) is not None + assert scheduler.measure_time.get(pg_nodes[0]) is not None + + # Phase gadget should be prepared at max timestep of involved qubits + assert scheduler.prepare_time.get(pg_nodes[0]) == 2 # qubit 0 at timestep 2 + + +def test_circuit2graph_schedule_is_valid() -> None: + """Test that generated schedule passes validation.""" + circuit = MBQCCircuit(num_qubits=3) + circuit.j(qubit=0, angle=0.5) + circuit.cz(qubit1=0, qubit2=1) + circuit.j(qubit=1, angle=0.3) + circuit.cz(qubit1=1, qubit2=2) + circuit.j(qubit=2, angle=0.1) + + _graph, _gflow, scheduler = circuit2graph(circuit) + + # This should not raise any exceptions + scheduler.validate_schedule() + + +def test_signal_shifting_circuit_integration() -> None: + """Test signal_shifting integration with circuit compilation and simulation.""" + # Create a simple quantum circuit + circuit = MBQCCircuit(3) + circuit.j(0, 0.5 * np.pi) + circuit.cz(0, 1) + circuit.cz(0, 2) + circuit.j(1, 0.75 * np.pi) + circuit.j(2, 0.25 * np.pi) + circuit.cz(0, 2) + circuit.cz(1, 2) + + # Convert circuit to graph and gflow + graphstate, gflow, _ = circuit2graph(circuit) + + # Apply signal shifting + xflow, zflow = signal_shifting(graphstate, gflow) + + # Compile to pattern + pattern = qompile(graphstate, xflow, zflow) + + # Verify pattern is runnable + assert pattern is not None + assert pattern.max_space >= 0 + assert pattern.depth >= 0 + + # Simulate the pattern + simulator = PatternSimulator(pattern, SimulatorBackend.StateVector) + simulator.simulate() + state = simulator.state + statevec = state.state() + + # Compare with circuit simulator + circ_simulator = CircuitSimulator(circuit, SimulatorBackend.StateVector) + circ_simulator.simulate() + circ_state = circ_simulator.state.state() + inner_product = np.vdot(statevec, circ_state) + + # Verify that the results match (inner product should be close to 1) + assert np.isclose(np.abs(inner_product), 1.0) + + +def test_pauli_simplification_circuit_integration() -> None: + """Test pauli_simplification integration with circuit compilation and simulation.""" + # Create a quantum circuit (using j for rotations, cz for entanglement) + circuit = MBQCCircuit(2) + circuit.j(0, 0.5 * np.pi) # Rotation on qubit 0 + circuit.cz(0, 1) + circuit.j(1, 0.25 * np.pi) # Rotation on qubit 1 + + # Convert circuit to graph and gflow + graphstate, gflow, _ = circuit2graph(circuit) + + # Apply pauli simplification + xflow, zflow = pauli_simplification(graphstate, gflow) + + # Compile to pattern + pattern = qompile(graphstate, xflow, zflow) + + # Verify pattern is runnable + assert pattern is not None + assert pattern.max_space >= 0 + + # Simulate the pattern + simulator = PatternSimulator(pattern, SimulatorBackend.StateVector) + simulator.simulate() + state = simulator.state + statevec = state.state() + + # Compare with circuit simulator + circ_simulator = CircuitSimulator(circuit, SimulatorBackend.StateVector) + circ_simulator.simulate() + circ_state = circ_simulator.state.state() + inner_product = np.vdot(statevec, circ_state) + + # Verify that the results match (inner product should be close to 1) + assert np.isclose(np.abs(inner_product), 1.0) + + +def test_circuit2graph_single_qubit_no_gates() -> None: + """Test single qubit circuit with no gates.""" + circuit = MBQCCircuit(num_qubits=1) + + graph, gflow, scheduler = circuit2graph(circuit) + + assert len(graph.physical_nodes) == 1 + assert len(gflow) == 0 + assert isinstance(scheduler, Scheduler) + + +def test_circuit2graph_multiple_parallel_qubits() -> None: + """Test circuit with operations on multiple independent qubits.""" + circuit = MBQCCircuit(num_qubits=4) + circuit.j(qubit=0, angle=0.1) + circuit.j(qubit=1, angle=0.2) + circuit.j(qubit=2, angle=0.3) + circuit.j(qubit=3, angle=0.4) + + _graph, _gflow, scheduler = circuit2graph(circuit) + + # All qubits should have valid schedules + scheduler.validate_schedule() + + +def test_circuit2graph_deep_circuit() -> None: + """Test circuit with many sequential operations.""" + circuit = MBQCCircuit(num_qubits=2) + for i in range(10): + circuit.j(qubit=0, angle=0.1 * i) + circuit.cz(qubit1=0, qubit2=1) + circuit.j(qubit=1, angle=0.1 * i) + + graph, _gflow, scheduler = circuit2graph(circuit) + + # Verify schedule is valid + scheduler.validate_schedule() + + # Check expected number of nodes: 2 input + 20 J gates + assert len(graph.physical_nodes) == 22 + + +def test_circuit2graph_scheduler_can_resolve_with_different_strategy() -> None: + """Test that scheduler can be re-solved with different optimization strategy.""" + circuit = MBQCCircuit(num_qubits=3) + circuit.j(qubit=0, angle=0.5) + circuit.cz(qubit1=0, qubit2=1) + circuit.j(qubit=1, angle=0.3) + circuit.cz(qubit1=1, qubit2=2) + circuit.j(qubit=2, angle=0.1) + + _graph, _gflow, scheduler = circuit2graph(circuit) + + # Re-solve with different strategy + config = ScheduleConfig(Strategy.MINIMIZE_SPACE) + result = scheduler.solve_schedule(config) + + # solve_schedule should return True on success + assert result is True + + # The schedule should still have valid prepare/measure times + assert all(t is not None for t in scheduler.prepare_time.values()) + assert all(t is not None for t in scheduler.measure_time.values()) diff --git a/tests/test_feedforward.py b/tests/test_feedforward.py index b49beb644..482f34bc1 100644 --- a/tests/test_feedforward.py +++ b/tests/test_feedforward.py @@ -3,7 +3,6 @@ import numpy as np import pytest -from graphqomb.circuit import MBQCCircuit, circuit2graph from graphqomb.common import Axis, AxisMeasBasis, Plane, PlannerMeasBasis, Sign from graphqomb.feedforward import ( _is_flow, @@ -16,8 +15,6 @@ signal_shifting, ) from graphqomb.graphstate import GraphState -from graphqomb.qompiler import qompile -from graphqomb.simulator import CircuitSimulator, PatternSimulator, SimulatorBackend def two_node_graph() -> tuple[GraphState, int, int]: @@ -297,48 +294,6 @@ def test_signal_shifting_zflow_none() -> None: assert isinstance(new_zflow, dict) -def test_signal_shifting_circuit_integration() -> None: - """Test signal_shifting integration with circuit compilation and simulation.""" - # Create a simple quantum circuit - circuit = MBQCCircuit(3) - circuit.j(0, 0.5 * np.pi) - circuit.cz(0, 1) - circuit.cz(0, 2) - circuit.j(1, 0.75 * np.pi) - circuit.j(2, 0.25 * np.pi) - circuit.cz(0, 2) - circuit.cz(1, 2) - - # Convert circuit to graph and gflow - graphstate, gflow = circuit2graph(circuit) - - # Apply signal shifting - xflow, zflow = signal_shifting(graphstate, gflow) - - # Compile to pattern - pattern = qompile(graphstate, xflow, zflow) - - # Verify pattern is runnable - assert pattern is not None - assert pattern.max_space >= 0 - assert pattern.depth >= 0 - - # Simulate the pattern - simulator = PatternSimulator(pattern, SimulatorBackend.StateVector) - simulator.simulate() - state = simulator.state - statevec = state.state() - - # Compare with circuit simulator - circ_simulator = CircuitSimulator(circuit, SimulatorBackend.StateVector) - circ_simulator.simulate() - circ_state = circ_simulator.state.state() - inner_product = np.vdot(statevec, circ_state) - - # Verify that the results match (inner product should be close to 1) - assert np.isclose(np.abs(inner_product), 1.0) - - # Tests for pauli_simplification @@ -514,40 +469,3 @@ def test_pauli_simplification_preserves_original_flows() -> None: # Original flows should be unchanged assert xflow[parent] == original_xflow_parent assert zflow[parent] == original_zflow_parent - - -def test_pauli_simplification_circuit_integration() -> None: - """Test pauli_simplification integration with circuit compilation and simulation.""" - # Create a quantum circuit (using j for rotations, cz for entanglement) - circuit = MBQCCircuit(2) - circuit.j(0, 0.5 * np.pi) # Rotation on qubit 0 - circuit.cz(0, 1) - circuit.j(1, 0.25 * np.pi) # Rotation on qubit 1 - - # Convert circuit to graph and gflow - graphstate, gflow = circuit2graph(circuit) - - # Apply pauli simplification - xflow, zflow = pauli_simplification(graphstate, gflow) - - # Compile to pattern - pattern = qompile(graphstate, xflow, zflow) - - # Verify pattern is runnable - assert pattern is not None - assert pattern.max_space >= 0 - - # Simulate the pattern - simulator = PatternSimulator(pattern, SimulatorBackend.StateVector) - simulator.simulate() - state = simulator.state - statevec = state.state() - - # Compare with circuit simulator - circ_simulator = CircuitSimulator(circuit, SimulatorBackend.StateVector) - circ_simulator.simulate() - circ_state = circ_simulator.state.state() - inner_product = np.vdot(statevec, circ_state) - - # Verify that the results match (inner product should be close to 1) - assert np.isclose(np.abs(inner_product), 1.0) diff --git a/tests/test_pauli_frame.py b/tests/test_pauli_frame.py index 864c3afb7..6d80d94a4 100644 --- a/tests/test_pauli_frame.py +++ b/tests/test_pauli_frame.py @@ -59,7 +59,9 @@ def simple_pauli_frame( A simple PauliFrame instance """ graph, xflow, zflow = simple_graph_with_flows - return PauliFrame(graph, xflow, zflow) + # Provide parity_check_group to enable _pauli_axis_cache initialization + parity_check_group = [set(graph.physical_nodes)] + return PauliFrame(graph, xflow, zflow, parity_check_group) @pytest.fixture @@ -107,7 +109,9 @@ def x_axis_pauli_frame() -> PauliFrame: xflow = {n0: {n1}, n1: {n2}} zflow: dict[int, set[int]] = {} - return PauliFrame(graph, xflow, zflow) + # Provide parity_check_group to enable _pauli_axis_cache initialization + parity_check_group = [set(graph.physical_nodes)] + return PauliFrame(graph, xflow, zflow, parity_check_group) @pytest.fixture @@ -137,7 +141,9 @@ def y_axis_pauli_frame() -> PauliFrame: xflow = {n0: {n1}, n1: {n2}} zflow = {n0: {n0}} - return PauliFrame(graph, xflow, zflow) + # Provide parity_check_group to enable _pauli_axis_cache initialization + parity_check_group = [set(graph.physical_nodes)] + return PauliFrame(graph, xflow, zflow, parity_check_group) @pytest.fixture @@ -167,7 +173,9 @@ def z_axis_pauli_frame() -> PauliFrame: xflow = {n0: {n1}, n1: {n2}} zflow: dict[int, set[int]] = {} - return PauliFrame(graph, xflow, zflow) + # Provide parity_check_group to enable _pauli_axis_cache initialization + parity_check_group = [set(graph.physical_nodes)] + return PauliFrame(graph, xflow, zflow, parity_check_group) def test_x_flip(simple_pauli_frame: PauliFrame, simple_nodes: list[int]) -> None: @@ -378,7 +386,9 @@ def test_logical_observables_group() -> None: xflow = {n0: {n1}, n1: {n2}} zflow: dict[int, set[int]] = {n0: {n0, n2}} - pframe = PauliFrame(graph, xflow, zflow) + # Provide parity_check_group to enable _pauli_axis_cache initialization + parity_check_group = [set(graph.physical_nodes)] + pframe = PauliFrame(graph, xflow, zflow, parity_check_group) # Get logical observables group target_nodes = [n2] @@ -410,7 +420,9 @@ def test_collect_dependent_chain_cache_hit() -> None: xflow = {n0: {n1}, n1: {n2}, n2: {n3}} zflow = {n0: {n0}} - pframe = PauliFrame(graph, xflow, zflow) + # Provide parity_check_group to enable _pauli_axis_cache initialization + parity_check_group = [set(graph.physical_nodes)] + pframe = PauliFrame(graph, xflow, zflow, parity_check_group) # First call to n2 chain1 = pframe._collect_dependent_chain(n2) @@ -468,7 +480,9 @@ def test_collect_dependent_chain_diamond_cancellation() -> None: xflow = {n0: {n1, n2}, n1: {n3}, n2: {n3}, n3: {n4}} zflow: dict[int, set[int]] = {} - pframe = PauliFrame(graph, xflow, zflow) + # Provide parity_check_group to enable _pauli_axis_cache initialization + parity_check_group = [set(graph.physical_nodes)] + pframe = PauliFrame(graph, xflow, zflow, parity_check_group) # Verify the chain for n3 chain_n3 = pframe._collect_dependent_chain(n3) diff --git a/tests/test_stim_compiler.py b/tests/test_stim_compiler.py index 7c3d783f1..55b59eaab 100644 --- a/tests/test_stim_compiler.py +++ b/tests/test_stim_compiler.py @@ -241,7 +241,29 @@ def test_stim_compile_with_detectors() -> None: def test_stim_compile_with_logical_observables() -> None: """Test OBSERVABLE_INCLUDE generation.""" - pattern, meas_node, _ = create_simple_pattern_x_measurement() + # Create pattern with parity_check_group for logical observables support + graph = GraphState() + in_node = graph.add_physical_node() + meas_node = graph.add_physical_node() + out_node = graph.add_physical_node() + + q_idx = 0 + graph.register_input(in_node, q_idx) + graph.register_output(out_node, q_idx) + + graph.add_physical_edge(in_node, meas_node) + graph.add_physical_edge(meas_node, out_node) + + # X measurement: XY plane with angle 0 + graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(meas_node, PlannerMeasBasis(Plane.XY, 0.0)) + + xflow = {in_node: {meas_node}, meas_node: {out_node}} + # Provide parity_check_group to enable _pauli_axis_cache for logical observables + # Only include measured nodes (exclude output nodes which don't have measurement bases) + measured_nodes = {in_node, meas_node} + parity_check_group = [measured_nodes] + pattern = qompile(graph, xflow, parity_check_group=parity_check_group) # Define logical observables logical_observables = {0: [meas_node]}