diff --git a/graphqomb/pauli_frame.py b/graphqomb/pauli_frame.py index b43135f3..db0b4bc1 100644 --- a/graphqomb/pauli_frame.py +++ b/graphqomb/pauli_frame.py @@ -192,6 +192,8 @@ def logical_observables_group(self, target_nodes: Collection[int]) -> set[int]: def _collect_dependent_chain(self, node: int) -> set[int]: r"""Generalized dependent-chain collector that respects measurement planes. + Uses recursive memoization to correctly XOR nodes reached via multiple paths. + Parameters ---------- node : `int` @@ -211,38 +213,26 @@ def _collect_dependent_chain(self, node: int) -> set[int]: if node in self._chain_cache: return set(self._chain_cache[node]) - chain: set[int] = set() - untracked = {node} - tracked: set[int] = set() - - while untracked: - current = untracked.pop() - - # Optimized XOR operation: toggle membership - if current in chain: - chain.remove(current) - else: - chain.add(current) - - # Use pre-computed Pauli axis from cache - axis = self._pauli_axis_cache[current] - - # NOTE: might have to support plane instead of axis - if axis == Axis.X: - # Use defaultdict direct access (no need for .get with default) - parents = self.inv_zflow[current] - elif axis == Axis.Y: - # Optimized symmetric difference for Y axis - parents = self.inv_xflow[current].symmetric_difference(self.inv_zflow[current]) - elif axis == Axis.Z: - parents = self.inv_xflow[current] - else: - msg = f"Unexpected measurement axis: {axis}" - raise ValueError(msg) - - # Add untracked parents in bulk - untracked.update(p for p in parents if p not in tracked) - tracked.add(current) + chain: set[int] = {node} + + # Use pre-computed Pauli axis from cache + axis = self._pauli_axis_cache[node] + + # NOTE: might have to support plane instead of axis + if axis == Axis.X: + parents = self.inv_zflow[node] + elif axis == Axis.Y: + parents = self.inv_xflow[node].symmetric_difference(self.inv_zflow[node]) + elif axis == Axis.Z: + parents = self.inv_xflow[node] + else: + msg = f"Unexpected measurement axis: {axis}" + raise ValueError(msg) + + # Recursively collect and XOR parent chains + for parent in parents: + parent_chain = self._collect_dependent_chain(parent) + chain ^= parent_chain # Store result in cache for future calls self._chain_cache[node] = frozenset(chain) diff --git a/tests/test_pauli_frame.py b/tests/test_pauli_frame.py index 09ecfb9e..864c3afb 100644 --- a/tests/test_pauli_frame.py +++ b/tests/test_pauli_frame.py @@ -425,3 +425,61 @@ def test_collect_dependent_chain_cache_hit() -> None: # Results should be identical assert chain1 == chain2 assert set(cached_result) == chain1 + + +def test_collect_dependent_chain_diamond_cancellation() -> None: + """Test that nodes reached via multiple paths are correctly XOR'd. + + Diamond graph structure (5 nodes with n4 as output): + n0 → n1, n0 → n2, n1 → n3, n2 → n3, n3 → n4 + + When collecting dependent chain for n3: + - chain(n0) = {n0} + - chain(n1) = {n1} ^ chain(n0) = {n0, n1} + - chain(n2) = {n2} ^ chain(n0) = {n0, n2} + - chain(n3) = {n3} ^ chain(n1) ^ chain(n2) = {n3} ^ {n0, n1} ^ {n0, n2} = {n1, n2, n3} + + Node n0 should be canceled out because it's reached via two paths. + """ + graph = GraphState() + n0 = graph.add_physical_node() + n1 = graph.add_physical_node() + n2 = graph.add_physical_node() + n3 = graph.add_physical_node() + n4 = graph.add_physical_node() + + graph.register_input(n0, 0) + graph.register_output(n4, 0) + + # Diamond edges + edge to output + graph.add_physical_edge(n0, n1) + graph.add_physical_edge(n0, n2) + graph.add_physical_edge(n1, n3) + graph.add_physical_edge(n2, n3) + graph.add_physical_edge(n3, n4) + + # All Z measurements (XZ plane, angle 0) so parents come from inv_xflow + graph.assign_meas_basis(n0, PlannerMeasBasis(Plane.XZ, 0.0)) + graph.assign_meas_basis(n1, PlannerMeasBasis(Plane.XZ, 0.0)) + graph.assign_meas_basis(n2, PlannerMeasBasis(Plane.XZ, 0.0)) + graph.assign_meas_basis(n3, PlannerMeasBasis(Plane.XZ, 0.0)) + + # xflow: n0 → {n1, n2}, n1 → {n3}, n2 → {n3}, n3 → {n4} + xflow = {n0: {n1, n2}, n1: {n3}, n2: {n3}, n3: {n4}} + zflow: dict[int, set[int]] = {} + + pframe = PauliFrame(graph, xflow, zflow) + + # Verify the chain for n3 + chain_n3 = pframe._collect_dependent_chain(n3) + + # n0 should be canceled out (reached via n1 and n2) + assert n0 not in chain_n3, f"n0 should be canceled out but chain is {chain_n3}" + assert chain_n3 == {n1, n2, n3}, f"Expected {{n1, n2, n3}} but got {chain_n3}" + + # Also verify intermediate chains + chain_n1 = pframe._collect_dependent_chain(n1) + assert chain_n1 == {n0, n1}, f"Expected {{n0, n1}} but got {chain_n1}" + + chain_n2 = pframe._collect_dependent_chain(n2) + assert chain_n2 == {n0, n2}, f"Expected {{n0, n2}} but got {chain_n2}"