From c9e59dbe578903f0321fa83c8187f595840c2b73 Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Thu, 15 Jan 2026 18:27:34 +0900 Subject: [PATCH] Fix XOR cancellation bug in _collect_dependent_chain The previous iterative BFS implementation tracked visited nodes with a `tracked` set, which prevented nodes reached via multiple paths from being correctly XOR'd (canceled out). For example, in a diamond graph where node A reaches node D via both B and C, node A should cancel out in the dependent chain for D. Changed to recursive memoization approach where each node's chain is computed by XORing the node with all parent chains. This correctly handles multi-path cancellation while maintaining efficiency through memoization. Added test case for diamond-shaped graph to verify XOR cancellation. Co-Authored-By: Claude Opus 4.5 --- graphqomb/pauli_frame.py | 54 +++++++++++++++--------------------- tests/test_pauli_frame.py | 58 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 32 deletions(-) diff --git a/graphqomb/pauli_frame.py b/graphqomb/pauli_frame.py index b43135f3f..db0b4bc1c 100644 --- a/graphqomb/pauli_frame.py +++ b/graphqomb/pauli_frame.py @@ -192,6 +192,8 @@ def logical_observables_group(self, target_nodes: Collection[int]) -> set[int]: def _collect_dependent_chain(self, node: int) -> set[int]: r"""Generalized dependent-chain collector that respects measurement planes. + Uses recursive memoization to correctly XOR nodes reached via multiple paths. + Parameters ---------- node : `int` @@ -211,38 +213,26 @@ def _collect_dependent_chain(self, node: int) -> set[int]: if node in self._chain_cache: return set(self._chain_cache[node]) - chain: set[int] = set() - untracked = {node} - tracked: set[int] = set() - - while untracked: - current = untracked.pop() - - # Optimized XOR operation: toggle membership - if current in chain: - chain.remove(current) - else: - chain.add(current) - - # Use pre-computed Pauli axis from cache - axis = self._pauli_axis_cache[current] - - # NOTE: might have to support plane instead of axis - if axis == Axis.X: - # Use defaultdict direct access (no need for .get with default) - parents = self.inv_zflow[current] - elif axis == Axis.Y: - # Optimized symmetric difference for Y axis - parents = self.inv_xflow[current].symmetric_difference(self.inv_zflow[current]) - elif axis == Axis.Z: - parents = self.inv_xflow[current] - else: - msg = f"Unexpected measurement axis: {axis}" - raise ValueError(msg) - - # Add untracked parents in bulk - untracked.update(p for p in parents if p not in tracked) - tracked.add(current) + chain: set[int] = {node} + + # Use pre-computed Pauli axis from cache + axis = self._pauli_axis_cache[node] + + # NOTE: might have to support plane instead of axis + if axis == Axis.X: + parents = self.inv_zflow[node] + elif axis == Axis.Y: + parents = self.inv_xflow[node].symmetric_difference(self.inv_zflow[node]) + elif axis == Axis.Z: + parents = self.inv_xflow[node] + else: + msg = f"Unexpected measurement axis: {axis}" + raise ValueError(msg) + + # Recursively collect and XOR parent chains + for parent in parents: + parent_chain = self._collect_dependent_chain(parent) + chain ^= parent_chain # Store result in cache for future calls self._chain_cache[node] = frozenset(chain) diff --git a/tests/test_pauli_frame.py b/tests/test_pauli_frame.py index 09ecfb9e8..864c3afb7 100644 --- a/tests/test_pauli_frame.py +++ b/tests/test_pauli_frame.py @@ -425,3 +425,61 @@ def test_collect_dependent_chain_cache_hit() -> None: # Results should be identical assert chain1 == chain2 assert set(cached_result) == chain1 + + +def test_collect_dependent_chain_diamond_cancellation() -> None: + """Test that nodes reached via multiple paths are correctly XOR'd. + + Diamond graph structure (5 nodes with n4 as output): + n0 → n1, n0 → n2, n1 → n3, n2 → n3, n3 → n4 + + When collecting dependent chain for n3: + - chain(n0) = {n0} + - chain(n1) = {n1} ^ chain(n0) = {n0, n1} + - chain(n2) = {n2} ^ chain(n0) = {n0, n2} + - chain(n3) = {n3} ^ chain(n1) ^ chain(n2) = {n3} ^ {n0, n1} ^ {n0, n2} = {n1, n2, n3} + + Node n0 should be canceled out because it's reached via two paths. + """ + graph = GraphState() + n0 = graph.add_physical_node() + n1 = graph.add_physical_node() + n2 = graph.add_physical_node() + n3 = graph.add_physical_node() + n4 = graph.add_physical_node() + + graph.register_input(n0, 0) + graph.register_output(n4, 0) + + # Diamond edges + edge to output + graph.add_physical_edge(n0, n1) + graph.add_physical_edge(n0, n2) + graph.add_physical_edge(n1, n3) + graph.add_physical_edge(n2, n3) + graph.add_physical_edge(n3, n4) + + # All Z measurements (XZ plane, angle 0) so parents come from inv_xflow + graph.assign_meas_basis(n0, PlannerMeasBasis(Plane.XZ, 0.0)) + graph.assign_meas_basis(n1, PlannerMeasBasis(Plane.XZ, 0.0)) + graph.assign_meas_basis(n2, PlannerMeasBasis(Plane.XZ, 0.0)) + graph.assign_meas_basis(n3, PlannerMeasBasis(Plane.XZ, 0.0)) + + # xflow: n0 → {n1, n2}, n1 → {n3}, n2 → {n3}, n3 → {n4} + xflow = {n0: {n1, n2}, n1: {n3}, n2: {n3}, n3: {n4}} + zflow: dict[int, set[int]] = {} + + pframe = PauliFrame(graph, xflow, zflow) + + # Verify the chain for n3 + chain_n3 = pframe._collect_dependent_chain(n3) + + # n0 should be canceled out (reached via n1 and n2) + assert n0 not in chain_n3, f"n0 should be canceled out but chain is {chain_n3}" + assert chain_n3 == {n1, n2, n3}, f"Expected {{n1, n2, n3}} but got {chain_n3}" + + # Also verify intermediate chains + chain_n1 = pframe._collect_dependent_chain(n1) + assert chain_n1 == {n0, n1}, f"Expected {{n0, n1}} but got {chain_n1}" + + chain_n2 = pframe._collect_dependent_chain(n2) + assert chain_n2 == {n0, n2}, f"Expected {{n0, n2}} but got {chain_n2}"