diff --git a/CHANGELOG.md b/CHANGELOG.md index 260b03df..147ac603 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ - Added v2→v3 config migration with automatic upgrade support for the new serialization settings - Added support for channels as quantum components via multiple inheritance, enabling channel-level macros and operations (e.g., `class HybridChannel(SingleChannel, Qubit)`). This allows macros to be attached directly to channels instead of requiring a parent qubit component. - Added `skip_save` field metadata support to exclude specific dataclass fields from serialization while keeping them accessible at runtime. Use `field(metadata={"skip_save": True})` to mark fields that should not be saved to JSON +- Added transient-state recording APIs on `QuamRoot` (`record_transient()`, `get_transient_changes()`, `revert_transient()`), with change records reported as `path`, `original`, and `transient`. ### Changed @@ -48,6 +49,7 @@ All deprecated properties now show migration guidance with code examples. See [P ### Fixed +- `QuamRoot.save()` now warns and persists original values when transient changes are active, then clears the transient state - Added `exponential_dc_gain` and `high_pass_filter` fields to `LFFEMAnalogOutputPort` for QOP 3.5+ filter support; fixed validation so the two fields can coexist and `exponential_dc_gain` alone conflicts with `feedback_filter` - Clarified in documentation how kwargs and attributes differ for method macros: kwargs are per-call overrides, attributes are persistent calibrated values that are saved with the QUAM state - Improved error messages for inferred frequency properties (`inferred_RF_frequency`, `inferred_intermediate_frequency`, `inferred_LO_frequency`) in `_OutComplexChannel` (`IQChannel` and `MWChannel`): errors now clearly identify the specific field and whether it is `None` or an unresolved reference diff --git a/docs/features/index.md b/docs/features/index.md index 05cda724..9423dc16 100644 --- a/docs/features/index.md +++ b/docs/features/index.md @@ -8,6 +8,12 @@ QUAM provides comprehensive serialization capabilities to save and load quantum - **[Serialization Documentation](serialization.md)**: Learn how to save and load QUAM configurations, control default value inclusion, and exclude specific fields from serialization using the `skip_save` metadata. This feature is essential for managing machine state, version controlling configurations, and separating runtime data from persistent configuration. +## Transient State + +Transient state records temporary runtime mutations so they can affect normal QUAM behavior and config generation without being persisted to disk. + +- **[Transient State Documentation](transient-state.md)**: Learn how to record, inspect, revert, and save temporary changes with `record_transient()`, `get_transient_changes()`, `revert_transient()`, and transient-aware `save()` behavior. + ## Gate-Level Operations Gate-level operations provide an abstraction layer that transforms low-level pulse definitions into high-level quantum gate operations. This feature allows users to build circuit-level QUA programs by working with quantum components (qubits and qubit pairs) and applying macros that represent common quantum gates. diff --git a/docs/features/transient-state.md b/docs/features/transient-state.md new file mode 100644 index 00000000..1daa2bbf --- /dev/null +++ b/docs/features/transient-state.md @@ -0,0 +1,256 @@ +# Transient State + +Transient state lets you make temporary changes to a [QuamRoot][quam.core.quam_classes.QuamRoot] object for runtime use without saving those changes as calibrated state. + +This is useful in calibration routines. A routine may need to temporarily change a machine parameter before generating a QUA config or running a program. That temporary value should affect the experiment, but it should not be persisted unless the later analysis decides it is the right calibrated value. + +## Motivation + +Consider a readout calibration that sweeps readout power. Before generating the QUA config, the calibration may need to raise the readout pulse amplitude to the maximum value used in the sweep. This ensures that the generated config contains a pulse large enough for all amplitude-scale factors used by the program. + +That maximum amplitude is not necessarily the value you want to save. It is a temporary runtime value used to run the sweep. After analysis, the calibration may choose a different fitted amplitude as the value to keep. + +Transient state separates these two steps: + +1. Record and apply temporary values for config generation or execution. +2. Revert those temporary values. +3. Save only the final calibrated values selected by the analysis. + +## Complete Example + +The following script uses the superconducting-qubits example components from `quam.examples.superconducting_qubits`. It creates a small QUAM, adds readout pulses, temporarily increases the readout amplitudes for config generation, then reverts those temporary values before saving the fitted calibration result. + +```python +from quam.components import pulses +from quam.examples.superconducting_qubits.generate_superconducting_quam import ( + create_quam_superconducting_referenced, +) + + +machine = create_quam_superconducting_referenced(num_qubits=2) + +for qubit in machine.qubits.values(): + qubit.resonator.operations["readout"] = pulses.SquareReadoutPulse( + length=1000, + amplitude=0.05, + ) + +max_readout_amplitude = 0.2 + +with machine.record_transient(): + for qubit in machine.qubits.values(): + qubit.resonator.operations["readout"].amplitude = max_readout_amplitude + +config = machine.generate_config() +assert config["waveforms"]["IQ0.readout.wf.I"]["sample"] == max_readout_amplitude + +print(machine.get_transient_changes()) + +machine.revert_transient() +assert machine.qubits["q0"].resonator.operations["readout"].amplitude == 0.05 + +fitted_amplitudes = { + "q0": 0.08, + "q1": 0.07, +} + +for qubit_name, amplitude in fitted_amplitudes.items(): + machine.qubits[qubit_name].resonator.operations["readout"].amplitude = amplitude + +machine.save() +``` + +This is the main transient-state pattern: + +- Use transient values to run the experiment. +- Revert the transient values after they are no longer needed. +- Use normal assignments for the analyzed calibration result. +- Save only the values you intend to keep. + +## Recording Temporary Values + +The transient recording scope is the part of the script where temporary values are assigned: + +```python +with machine.record_transient(): + for qubit in machine.qubits.values(): + qubit.resonator.operations["readout"].amplitude = max_readout_amplitude +``` + +`record_transient()` records the original values before the writes happen. It does not revert the values when the `with` block exits. The temporary values remain live on the machine: + +```python +print(machine.qubits["q0"].resonator.operations["readout"].amplitude) +# 0.2 +``` + +This is the key behavior: the temporary values are available to normal QUAM access and config generation. + +## Generate a Config With Temporary Values + +After recording the temporary changes, generate the config as usual: + +```python +config = machine.generate_config() +``` + +The generated config sees the temporary readout amplitudes because they are still live on the QUAM object: + +```python +print(config["waveforms"]["IQ0.readout.wf.I"]["sample"]) +# 0.2 +``` + +This lets the calibration run with the values needed for the experiment without making those values permanent. + +## Inspect and Revert + +Use `get_transient_changes()` to see what is currently recorded: + +```python +changes = machine.get_transient_changes() +print(changes) +``` + +The output contains the QUAM path, the original value, and the current temporary value: + +```python +[ + { + "path": "#/qubits/q0/resonator/operations/readout/amplitude", + "original": 0.05, + "transient": 0.2, + }, + { + "path": "#/qubits/q1/resonator/operations/readout/amplitude", + "original": 0.05, + "transient": 0.2, + }, +] +``` + +When the temporary values are no longer needed, revert them: + +```python +machine.revert_transient() + +print(machine.qubits["q0"].resonator.operations["readout"].amplitude) +# 0.05 +print(machine.get_transient_changes()) +# [] +``` + +The machine is now back to the state it had before the temporary calibration changes. + +## Save Only the Calibration Result + +After the experiment and analysis, apply the values you actually want to persist using normal assignments: + +```python +machine.revert_transient() + +fitted_amplitudes = { + "q0": 0.08, + "q1": 0.07, +} + +for qubit_name, amplitude in fitted_amplitudes.items(): + machine.qubits[qubit_name].resonator.operations["readout"].amplitude = amplitude + +machine.save("state.json") +``` + +The distinction is important: + +- Transient values are for running the experiment. +- Normal assignments are for calibrated values you intend to keep. + +This pattern prevents temporary sweep setup from being accidentally saved as the machine's calibrated state. + +## Saving With Active Transient Changes + +`save()` also has a safety behavior. If transient changes are still active when you save, QUAM: + +1. Emits a `UserWarning` with the number of active transient changes. +2. Reverts the object to the original pre-transient values. +3. Saves those original values to disk. +4. Clears the transient records after a successful save. + +```python +with machine.record_transient(): + machine.qubits["q0"].resonator.operations["readout"].amplitude = 0.2 + +machine.save("state.json") +``` + +In this case, the saved state contains the original amplitude, not `0.2`. + +This behavior is a guardrail. In calibration code, it is usually clearer to call `revert_transient()` explicitly before applying and saving the final fitted values. + +If saving fails after QUAM has reverted the transient values, QUAM restores the transient live state and records before raising the original exception. + +## Additional Details + +### First Write Is Recorded + +Only the first write to a given attribute, dictionary key, or list is recorded. Later writes update the live value, but the `original` value remains the value that will be restored: + +```python +with machine.record_transient(): + readout = machine.qubits["q0"].resonator.operations["readout"] + readout.amplitude = 0.1 + readout.amplitude = 0.2 + +print(machine.get_transient_changes()) +# [{"path": ".../amplitude", "original": 0.05, "transient": 0.2}] +``` + +### Dictionaries and Lists + +Transient recording also tracks writes through QUAM dictionaries and lists. + +Dictionary mutations are recorded per key: + +```python +with machine.record_transient(): + machine.wiring["temporary_mode"] = "power_sweep" +``` + +For added or deleted dictionary keys, `original` or `transient` is the `MISSING` sentinel from `quam.core.transient`. + +List mutations are recorded as a snapshot of the whole list: + +```python +machine.wiring["active_qubits"] = ["q0", "q1"] + +with machine.record_transient(): + machine.wiring["active_qubits"].append("q2") + +print(machine.get_transient_changes()) +# [{"path": "#/wiring/active_qubits", "original": ["q0", "q1"], "transient": ["q0", "q1", "q2"]}] +``` + +List changes are tracked at list granularity, not per index. + +### Overwriting Outside the Recording Scope + +If a recorded transient value is overwritten outside a `record_transient()` scope, QUAM treats that as a permanent write. It warns and removes the transient record: + +```python +with machine.record_transient(): + readout = machine.qubits["q0"].resonator.operations["readout"] + readout.amplitude = 0.2 + +readout.amplitude = 0.15 # warns; this value is now permanent + +print(machine.get_transient_changes()) +# [] +``` + +After this happens, `revert_transient()` will not restore the old value for that path because the transient record has been removed. + +### Scope and Limitations + +- Nested `record_transient()` scopes are not supported and raise `RuntimeError`. +- Detached components are not recorded into the last instantiated root; only objects attached to the active root are tracked. +- Transient state is for runtime mutations. Use `skip_save` metadata for fields that should never be serialized, even when they are not transient. diff --git a/mkdocs.yml b/mkdocs.yml index 0ce826f4..dfb681b8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -58,6 +58,7 @@ nav: - "QUAM Features": - "features/index.md" - "features/serialization.md" + - "features/transient-state.md" - "features/gate-level-operations.md" - "features/quam-references.md" - migrating-to-quam.md diff --git a/quam/core/quam_classes.py b/quam/core/quam_classes.py index 53986dab..c89cd5c7 100644 --- a/quam/core/quam_classes.py +++ b/quam/core/quam_classes.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Iterable from collections import UserList +from contextlib import contextmanager import sys import warnings from pathlib import Path @@ -19,6 +20,7 @@ get_origin, get_args, Optional, + Hashable, ) from functools import partial from dataclasses import dataclass, fields, is_dataclass, MISSING @@ -41,6 +43,14 @@ from qm.type_hinting import DictQuaConfig +from .transient import ( + MISSING as TRANSIENT_MISSING, + TransientState, + _AttrRecord, + _DictRecord, + _ListRecord, +) + __all__ = [ "QuamBase", "QuamRoot", @@ -148,6 +158,235 @@ def sort_quam_components( return sorted_components +def _warn_transient_overwrite(transient_state: TransientState, token) -> None: + warnings.warn( + "A transient value is being permanently overwritten outside the recording " + "scope; the transient record was removed." + ) + transient_state.remove(token) + + +def _get_attached_root(obj: Any) -> Optional["QuamRoot"]: + current = obj + while current is not None: + if isinstance(current, QuamRoot): + return current + current = getattr(current, "parent", None) + return None + + +def _clear_transient_parent(value: Any) -> None: + if isinstance(value, QuamBase): + value.parent = None + + +def _restore_transient_parent(parent: Any, value: Any) -> None: + if isinstance(value, QuamBase) and value.parent is None: + value.parent = parent + + +def _transient_added_items(current: list[Any], snapshot: list[Any]) -> list[Any]: + remaining = list(snapshot) + added = [] + + for item in current: + for index, original in enumerate(remaining): + if item is original: + remaining.pop(index) + break + else: + added.append(item) + + return added + + +def _snapshot_transient_record_value(record: Any) -> Any: + if isinstance(record, _AttrRecord): + return getattr(record.obj, record.attr, TRANSIENT_MISSING) + if isinstance(record, _DictRecord): + return record.obj.data.get(record.key, TRANSIENT_MISSING) + if isinstance(record, _ListRecord): + return list(record.obj.data) + + raise TypeError(f"Unsupported transient record type: {type(record)}") + + +def _restore_transient_records_after_failed_save( + transient_state: TransientState, + record_snapshots: list[tuple[tuple[int, Hashable], Any, Any]], +) -> None: + for _, record, transient_value in record_snapshots: + if isinstance(record, _AttrRecord): + current = getattr(record.obj, record.attr, TRANSIENT_MISSING) + if transient_value is TRANSIENT_MISSING: + if current is not TRANSIENT_MISSING: + _clear_transient_parent(current) + object.__delattr__(record.obj, record.attr) + else: + if current is not TRANSIENT_MISSING and current is not transient_value: + _clear_transient_parent(current) + object.__setattr__(record.obj, record.attr, transient_value) + _restore_transient_parent(record.obj, transient_value) + continue + + if isinstance(record, _DictRecord): + current = record.obj.data.get(record.key, TRANSIENT_MISSING) + if transient_value is TRANSIENT_MISSING: + if current is not TRANSIENT_MISSING: + _clear_transient_parent(current) + del record.obj.data[record.key] + else: + if current is not TRANSIENT_MISSING and current is not transient_value: + _clear_transient_parent(current) + record.obj.data[record.key] = transient_value + _restore_transient_parent(record.obj, transient_value) + continue + + if isinstance(record, _ListRecord): + for item in _transient_added_items(list(record.obj.data), transient_value): + _clear_transient_parent(item) + + record.obj.data[:] = transient_value + for item in transient_value: + _restore_transient_parent(record.obj, item) + continue + + raise TypeError(f"Unsupported transient record type: {type(record)}") + + transient_state._records = [ + (token, record) for token, record, _ in record_snapshots + ] + transient_state._seen = {token for token, _, _ in record_snapshots} + transient_state._is_recording = False + + +def _is_transient_subtree_record(record_obj: Any, subtree_root: Any) -> bool: + if not isinstance(subtree_root, QuamBase) or not isinstance(record_obj, QuamBase): + return False + + current = record_obj + while current is not None: + if current is subtree_root: + return True + current = getattr(current, "parent", None) + return False + + +def _drop_transient_subtree_records( + transient_state: TransientState, subtree_root: Any +) -> bool: + kept_records = [] + removed = False + + for existing_token, record in transient_state._records: + if _is_transient_subtree_record(record.obj, subtree_root): + transient_state._seen.discard(existing_token) + removed = True + else: + kept_records.append((existing_token, record)) + + if removed: + transient_state._records = kept_records + + return removed + + +def _record_attr_write(obj: "QuamBase", name: str, value: Any) -> None: + if not getattr(obj, "_initialized", False): + return + + root = _get_attached_root(obj) + transient_state = getattr(root, "_transient_state", None) + if transient_state is None: + return + + try: + current = object.__getattribute__(obj, name) + except AttributeError: + current = TRANSIENT_MISSING + + if current is value: + return + + token = (id(obj), name) + if transient_state._is_recording: + if token in transient_state._seen: + return + + transient_state.record(_AttrRecord(obj, name, current), name) + return + + removed_subtree_records = False + if current is not TRANSIENT_MISSING and current is not value: + removed_subtree_records = _drop_transient_subtree_records( + transient_state, current + ) + + if token in transient_state._seen or removed_subtree_records: + _warn_transient_overwrite(transient_state, token) + + +def _get_transient_state_for_container( + container: "QuamBase", +) -> Optional[TransientState]: + if not getattr(container, "_initialized", False): + return None + + root = _get_attached_root(container) + return getattr(root, "_transient_state", None) + + +def _record_dict_write( + container: "QuamDict", key: Hashable, value: Any, *, allow_missing: bool +) -> None: + transient_state = _get_transient_state_for_container(container) + if transient_state is None: + return + + token = (id(container), key) + if transient_state._is_recording: + if token in transient_state._seen: + return + + if not allow_missing and key not in container.data: + return + + original = container.data[key] if key in container.data else TRANSIENT_MISSING + transient_state.record( + _DictRecord(container, key, original), + key, + ) + return + + current = container.data[key] if key in container.data else TRANSIENT_MISSING + removed_subtree_records = False + if current is not TRANSIENT_MISSING and current is not value: + removed_subtree_records = _drop_transient_subtree_records( + transient_state, current + ) + + if token in transient_state._seen or removed_subtree_records: + _warn_transient_overwrite(transient_state, token) + + +def _record_list_snapshot(container: "QuamList") -> None: + transient_state = _get_transient_state_for_container(container) + if transient_state is None: + return + + token = (id(container), "__list__") + if transient_state._is_recording: + if token in transient_state._seen: + return + + transient_state.record( + _ListRecord(container, container.data[:]), + "__list__", + ) + elif token in transient_state._seen: + _warn_transient_overwrite(transient_state, token) + + def _quam_dataclass(cls=None, **kwargs): """Dataclass for QUAM classes. @@ -289,8 +528,8 @@ def get_root(self) -> Optional[QuamRoot]: if self._last_instantiated_root is not None: warnings.warn( - f"This component is not part of any QuamRoot, using last " - f"instantiated QuamRoot. This is not recommended as it may lead to " + "This component is not part of any QuamRoot, using last " + "instantiated QuamRoot. This is not recommended as it may lead to " f"unexpected behaviour. Component: {self.__class__.__name__}" ) return self._last_instantiated_root @@ -338,7 +577,8 @@ def inferred_id(self) -> Union[str, int]: return self.id if self.parent is None: raise AttributeError( - f"Cannot infer id of {self.__class__.__name__} because it has no parent." + f"Cannot infer id of {self.__class__.__name__} because it has no" + " parent." ) return str(self.parent.get_attr_name(self)) @@ -634,7 +874,7 @@ def _get_referenced_value(self, reference: str) -> Any: root = self.get_root() if string_reference.is_absolute_reference(reference) and root is None: warnings.warn( - f"No QuamRoot initialized, cannot retrieve absolute reference " + "No QuamRoot initialized, cannot retrieve absolute reference " f"{reference} from {self.__class__.__name__}" ) return reference @@ -730,8 +970,9 @@ def _follow_reference_chain( max_depth = self._MAX_REFERENCE_DEPTH if max_depth <= 0: raise RecursionError( - f"Reference chain exceeded maximum depth of {self._MAX_REFERENCE_DEPTH}. " - f"Possible circular reference starting from {obj.get_attr_path()}" + "Reference chain exceeded maximum depth of" + f" {self._MAX_REFERENCE_DEPTH}. Possible circular reference starting" + f" from {obj.get_attr_path()}" ) # Handle list/dict index access specially @@ -770,9 +1011,7 @@ def _follow_reference_chain( # Recursively follow the chain return self._follow_reference_chain(parent_obj, parent_attr, max_depth - 1) - def set_at_reference( - self, attr: str, value: Any, allow_non_reference: bool = True - ): + def set_at_reference(self, attr: str, value: Any, allow_non_reference: bool = True): """Follow the reference of an attribute and set the value at the reference. This method follows reference chains recursively. If an attribute contains @@ -796,8 +1035,8 @@ def set_at_reference( if not string_reference.is_reference(raw_value): if not allow_non_reference: raise ValueError( - f"Cannot set at reference because attr '{attr}' is not a reference. " - f"'{attr}' = {raw_value}" + f"Cannot set at reference because attr '{attr}' is not a reference." + f" '{attr}' = {raw_value}" ) target_obj, target_attr = self, attr else: @@ -833,17 +1072,37 @@ class QuamRoot(QuamBase): """ def __post_init__(self): + self.__dict__["_transient_state"] = TransientState() QuamBase._last_instantiated_root = self self.serialiser = self.get_serialiser() super().__post_init__() def __setattr__(self, name, value): + _record_attr_write(self, name, value) converted_val = convert_dict_and_list(value, cls_or_obj=self, attr=name) super().__setattr__(name, converted_val) if isinstance(converted_val, QuamBase) and name != "parent": converted_val.parent = self + @contextmanager + def record_transient(self): + transient_state = self._transient_state + if transient_state._is_recording: + raise RuntimeError("Nested recording scopes are not supported.") + + transient_state._is_recording = True + try: + yield self + finally: + transient_state._is_recording = False + + def revert_transient(self) -> None: + self._transient_state.revert() + + def get_transient_changes(self) -> list[dict[str, Any]]: + return self._transient_state.describe() + @classmethod def get_serialiser(cls) -> AbstractSerialiser: """Get the serialiser for the QuamRoot class, which is the JSONSerialiser. @@ -894,6 +1153,34 @@ def save( value. ignore: A list of components to ignore. """ + if self._transient_state._records: + active_change_count = len(self._transient_state._records) + change_label = "change" if active_change_count == 1 else "changes" + transient_record_snapshots = [ + (token, record, _snapshot_transient_record_value(record)) + for token, record in self._transient_state._records + ] + warnings.warn( + f"{active_change_count} active transient {change_label}; save() " + "will revert them, persist the original pre-transient values, " + "and clear transient state." + ) + self._transient_state.revert() + try: + self.serialiser.save( + quam_obj=self, + path=path, + content_mapping=content_mapping, + include_defaults=include_defaults, + ignore=ignore, + ) + except Exception: + _restore_transient_records_after_failed_save( + self._transient_state, transient_record_snapshots + ) + raise + return + self.serialiser.save( quam_obj=self, path=path, @@ -980,6 +1267,7 @@ class QuamComponent(QuamBase): """ def __setattr__(self, name, value): + _record_attr_write(self, name, value) converted_val = convert_dict_and_list(value, cls_or_obj=self, attr=name) super().__setattr__(name, converted_val) @@ -1073,6 +1361,7 @@ def __getitem__(self, i): # Overriding methods from UserDict def __setitem__(self, key, value): + _record_dict_write(self, key, value, allow_missing=True) value = convert_dict_and_list(value) self._is_valid_setattr(key, value, error_on_False=True) super().__setitem__(key, value) @@ -1080,6 +1369,10 @@ def __setitem__(self, key, value): if isinstance(value, QuamBase): value.parent = self + def __delitem__(self, key): + _record_dict_write(self, key, TRANSIENT_MISSING, allow_missing=False) + super().__delitem__(key) + def __eq__(self, other) -> bool: if isinstance(other, dict): return self.data == other @@ -1247,13 +1540,15 @@ class QuamList(UserList, QuamBase): _value_annotation: ClassVar[type] = None def __init__(self, *args, value_annotation: type = None): - self._value_annotation = value_annotation + self.__dict__["_value_annotation"] = value_annotation + self.__dict__["_initialized"] = False # We manually add elements using extend instead of passing to super() # To ensure that any dicts and lists get converted to QuamDict and QuamList super().__init__() if args: self.extend(*args) + self.__dict__["_initialized"] = True # Overloading methods from UserList def __eq__(self, value: object) -> bool: @@ -1296,6 +1591,7 @@ def __getitem__(self, i): return self._get_referenced_value(elem) def __setitem__(self, i, item): + self._record_list_snapshot() converted_item = convert_dict_and_list(item) super().__setitem__(i, converted_item) @@ -1304,9 +1600,12 @@ def __setitem__(self, i, item): def __iadd__(self, other: Iterable): converted_other = [convert_dict_and_list(elem) for elem in other] + if converted_other: + self._record_list_snapshot() return super().__iadd__(converted_other) def append(self, item: Any) -> None: + self._record_list_snapshot() converted_item = convert_dict_and_list(item) if isinstance(converted_item, QuamBase): @@ -1315,6 +1614,7 @@ def append(self, item: Any) -> None: return super().append(converted_item) def insert(self, i: int, item: Any) -> None: + self._record_list_snapshot() converted_item = convert_dict_and_list(item) if isinstance(converted_item, QuamBase): @@ -1324,12 +1624,36 @@ def insert(self, i: int, item: Any) -> None: def extend(self, iterable: Iterator) -> None: converted_iterable = [convert_dict_and_list(elem) for elem in iterable] + if not converted_iterable: + return super().extend(converted_iterable) + + self._record_list_snapshot() for converted_item in converted_iterable: if isinstance(converted_item, QuamBase): converted_item.parent = self return super().extend(converted_iterable) + def remove(self, item: Any) -> None: + self._record_list_snapshot() + return super().remove(item) + + def pop(self, i: int = -1): + self._record_list_snapshot() + return super().pop(i) + + def __delitem__(self, i): + self._record_list_snapshot() + return super().__delitem__(i) + + def clear(self) -> None: + if self.data: + self._record_list_snapshot() + return super().clear() + + def _record_list_snapshot(self) -> None: + _record_list_snapshot(self) + # Quam methods def _val_matches_attr_annotation(self, attr: str, val: Any) -> bool: """Check whether the type of an attribute matches the annotation. diff --git a/quam/core/transient.py b/quam/core/transient.py new file mode 100644 index 00000000..7a621351 --- /dev/null +++ b/quam/core/transient.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Hashable + +MISSING = object() + + +def _path_for(obj: Any, suffix: str | None = None) -> str: + try: + base = obj.get_reference() + except Exception: + base = obj.__class__.__name__ + + if suffix is None: + return base + + if base.endswith("/"): + return f"{base}{suffix}" + return f"{base}/{suffix}" + + +def _clear_parent(value: Any) -> None: + if _is_quam_base(value): + value.parent = None + + +def _restore_parent(parent: Any, value: Any) -> None: + if _is_quam_base(value) and value.parent is None: + value.parent = parent + + +def _is_quam_base(value: Any) -> bool: + try: + from quam.core.quam_classes import QuamBase + except ImportError: + return False + + return isinstance(value, QuamBase) + + +def _added_items(current: list[Any], snapshot: list[Any]) -> list[Any]: + remaining = list(snapshot) + added = [] + + for item in current: + for index, original in enumerate(remaining): + if item is original: + remaining.pop(index) + break + else: + added.append(item) + + return added + + +@dataclass(slots=True) +class _AttrRecord: + obj: Any + attr: str + original: Any + + def describe(self) -> dict[str, Any]: + return { + "path": _path_for(self.obj, self.attr), + "original": self.original, + "transient": getattr(self.obj, self.attr, MISSING), + } + + def revert(self) -> None: + current = getattr(self.obj, self.attr, MISSING) + + if self.original is MISSING: + if current is not MISSING: + _clear_parent(current) + object.__delattr__(self.obj, self.attr) + return + + if current is not MISSING and current is not self.original: + _clear_parent(current) + + object.__setattr__(self.obj, self.attr, self.original) + _restore_parent(self.obj, self.original) + + +@dataclass(slots=True) +class _DictRecord: + obj: Any + key: Any + original: Any + + def describe(self) -> dict[str, Any]: + return { + "path": _path_for(self.obj, str(self.key)), + "original": self.original, + "transient": self.obj.data.get(self.key, MISSING), + } + + def revert(self) -> None: + current = self.obj.data.get(self.key, MISSING) + + if self.original is MISSING: + if current is not MISSING: + _clear_parent(current) + del self.obj.data[self.key] + return + + if current is not MISSING and current is not self.original: + _clear_parent(current) + + self.obj.data[self.key] = self.original + _restore_parent(self.obj, self.original) + + +@dataclass(slots=True) +class _ListRecord: + obj: Any + snapshot: list[Any] + + def describe(self) -> dict[str, Any]: + return { + "path": _path_for(self.obj), + "original": list(self.snapshot), + "transient": list(self.obj.data), + } + + def revert(self) -> None: + for item in _added_items(list(self.obj.data), self.snapshot): + _clear_parent(item) + + self.obj.data[:] = self.snapshot + for item in self.snapshot: + _restore_parent(self.obj, item) + + +class TransientState: + def __init__(self): + self._is_recording = False + self._records: list[tuple[tuple[int, Hashable], Any]] = [] + self._seen: set[tuple[int, Hashable]] = set() + + def record(self, record: Any, lookup_key: Hashable): + token = (id(record.obj), lookup_key) + if not self._is_recording or token in self._seen: + return token + + self._seen.add(token) + self._records.append((token, record)) + return token + + def remove(self, token: tuple[int, Hashable]) -> None: + if token not in self._seen: + return + + self._seen.remove(token) + self._records = [ + (existing_token, record) + for existing_token, record in self._records + if existing_token != token + ] + + def describe(self) -> list[dict[str, Any]]: + return [record.describe() for _, record in self._records] + + def revert(self) -> None: + self._is_recording = False + try: + for _, record in reversed(self._records): + record.revert() + finally: + self._records.clear() + self._seen.clear() + self._is_recording = False diff --git a/tests/quam_base/test_transient_state.py b/tests/quam_base/test_transient_state.py new file mode 100644 index 00000000..07082205 --- /dev/null +++ b/tests/quam_base/test_transient_state.py @@ -0,0 +1,460 @@ +import ast +from dataclasses import field +import json +from pathlib import Path +import warnings + +import pytest + +from quam.core import QuamComponent, QuamRoot, quam_dataclass +from quam.core.transient import ( + MISSING, + TransientState, + _AttrRecord, + _DictRecord, + _ListRecord, +) + + +@quam_dataclass +class Leaf(QuamComponent): + value: int = 0 + + +@quam_dataclass +class Root(QuamRoot): + child: Leaf + mapping: dict = field(default_factory=dict) + items: list = field(default_factory=list) + + +@quam_dataclass +class AttrRoot(QuamRoot): + child: Leaf + + +class HookedObject: + def __init__(self): + object.__setattr__(self, "value", 1) + object.__setattr__(self, "_locked", True) + + def __setattr__(self, name, value): + if getattr(self, "_locked", False): + raise RuntimeError("write hook should be bypassed") + super().__setattr__(name, value) + + def __delattr__(self, name): + if getattr(self, "_locked", False): + raise RuntimeError("delete hook should be bypassed") + super().__delattr__(name) + + +def test_transient_state_records_first_attr_write_and_can_remove(): + root = Root(child=Leaf(value=1)) + state = TransientState() + state._is_recording = True + + token = state.record(_AttrRecord(root.child, "value", 1), "value") + root.child.value = 2 + + duplicate_token = state.record(_AttrRecord(root.child, "value", 2), "value") + root.child.value = 3 + + assert duplicate_token == token + assert state.describe() == [ + {"path": "#/child/value", "original": 1, "transient": 3} + ] + + state.remove(token) + assert state.describe() == [] + + state.record(_AttrRecord(root.child, "value", 3), "value") + root.child.value = 4 + assert state.describe() == [ + {"path": "#/child/value", "original": 3, "transient": 4} + ] + + +def test_attr_record_revert_bypasses_write_hooks_and_removes_missing_attrs(): + obj = HookedObject() + state = TransientState() + state._is_recording = True + + state.record(_AttrRecord(obj, "value", 1), "value") + state.record(_AttrRecord(obj, "extra", MISSING), "extra") + + object.__setattr__(obj, "value", 9) + object.__setattr__(obj, "extra", "temp") + + state.revert() + + assert obj.value == 1 + assert not hasattr(obj, "extra") + assert state._is_recording is False + + +def test_attr_record_revert_restores_original_parent_and_clears_replacement_parent(): + original = Leaf(value=1) + replacement = Leaf(value=2) + root = AttrRoot(child=original) + state = TransientState() + state._is_recording = True + + state.record(_AttrRecord(root, "child", original), "child") + root.child = replacement + + assert original.parent is root + assert replacement.parent is root + + state.revert() + + assert root.child is original + assert original.parent is root + assert replacement.parent is None + + +def test_dict_record_revert_restores_original_values_and_clears_added_parents(): + original = Leaf(value=1) + replacement = Leaf(value=2) + added = Leaf(value=3) + root = Root(child=Leaf(), mapping={"item": original}) + state = TransientState() + state._is_recording = True + + state.record(_DictRecord(root.mapping, "item", original), "item") + root.mapping["item"] = replacement + + state.record(_DictRecord(root.mapping, "added", MISSING), "added") + root.mapping["added"] = added + + assert state.describe() == [ + {"path": "#/mapping/item", "original": original, "transient": replacement}, + {"path": "#/mapping/added", "original": MISSING, "transient": added}, + ] + + state.revert() + + assert root.mapping["item"] is original + assert "added" not in root.mapping.data + assert original.parent is root.mapping + assert replacement.parent is None + assert added.parent is None + + +def test_list_record_revert_restores_snapshot_and_clears_added_parents(): + original = Leaf(value=1) + replacement = Leaf(value=2) + appended = Leaf(value=3) + root = Root(child=Leaf(), items=[original]) + state = TransientState() + state._is_recording = True + + token = state.record(_ListRecord(root.items, root.items.data[:]), "__list__") + root.items[0] = replacement + + duplicate_token = state.record( + _ListRecord(root.items, root.items.data[:]), "__list__" + ) + root.items.append(appended) + + assert duplicate_token == token + assert state.describe() == [ + { + "path": "#/items", + "original": [original], + "transient": [replacement, appended], + } + ] + + state.revert() + + assert root.items.data == [original] + assert original.parent is root.items + assert replacement.parent is None + assert appended.parent is None + + +def test_transient_state_starts_disabled_and_revert_leaves_it_disabled(): + root = Root(child=Leaf(value=1)) + state = TransientState() + + token = state.record(_AttrRecord(root.child, "value", 1), "value") + root.child.value = 2 + + assert token == (id(root.child), "value") + assert state.describe() == [] + assert state._is_recording is False + + state._is_recording = True + state.record(_AttrRecord(root.child, "value", 1), "value") + root.child.value = 3 + + state.revert() + + assert root.child.value == 1 + assert state._is_recording is False + + +def test_transient_module_has_no_top_level_quam_classes_import(): + transient_source = ( + Path(__file__).resolve().parents[2] / "quam" / "core" / "transient.py" + ) + source_text = transient_source.read_text() + module = ast.parse(source_text) + + for node in module.body: + if isinstance(node, ast.ImportFrom): + assert node.module != "quam.core.quam_classes" + + +def test_record_transient_records_component_attribute_until_explicit_revert(): + root = Root(child=Leaf(value=1)) + + with root.record_transient(): + root.child.value = 2 + + assert isinstance(root._transient_state, TransientState) + assert root.child.value == 2 + assert root.get_transient_changes() == [ + {"path": "#/child/value", "original": 1, "transient": 2} + ] + + root.revert_transient() + + assert root.child.value == 1 + assert root.get_transient_changes() == [] + + +def test_record_transient_records_dict_add_modify_delete_and_reverts(): + root = Root( + child=Leaf(), + mapping={"modified": 1, "deleted": 2}, + ) + + with root.record_transient(): + root.mapping["modified"] = 10 + root.mapping["added"] = 20 + del root.mapping["deleted"] + + assert root.mapping == {"modified": 10, "added": 20} + assert root.get_transient_changes() == [ + {"path": "#/mapping/modified", "original": 1, "transient": 10}, + {"path": "#/mapping/added", "original": MISSING, "transient": 20}, + {"path": "#/mapping/deleted", "original": 2, "transient": MISSING}, + ] + + root.revert_transient() + + assert root.mapping == {"modified": 1, "deleted": 2} + + +def test_record_transient_records_list_changes_and_reverts(): + root = Root(child=Leaf(), items=[1, 2]) + + with root.record_transient(): + root.items.append(3) + root.items.insert(0, 0) + root.items.remove(2) + + assert root.items == [0, 1, 3] + assert root.get_transient_changes() == [ + {"path": "#/items", "original": [1, 2], "transient": [0, 1, 3]} + ] + + root.revert_transient() + + assert root.items == [1, 2] + + +def test_overwrite_outside_recording_scope_warns_and_drops_transient_record(): + root = Root(child=Leaf(value=1)) + + with root.record_transient(): + root.child.value = 2 + + assert root.get_transient_changes() == [ + {"path": "#/child/value", "original": 1, "transient": 2} + ] + + with pytest.warns( + UserWarning, + match=( + "transient value is being permanently overwritten.*transient record was" + " removed" + ), + ): + root.child.value = 3 + + assert root.child.value == 3 + assert root.get_transient_changes() == [] + + root.revert_transient() + + assert root.child.value == 3 + + +def test_get_transient_changes_returns_human_readable_path_original_transient(): + root = Root(child=Leaf(value=1), mapping={"status": "idle"}, items=[1]) + + with root.record_transient(): + root.child.value = 5 + root.mapping["status"] = "busy" + root.items.append(2) + + assert root.get_transient_changes() == [ + {"path": "#/child/value", "original": 1, "transient": 5}, + {"path": "#/mapping/status", "original": "idle", "transient": "busy"}, + {"path": "#/items", "original": [1], "transient": [1, 2]}, + ] + + +def test_record_transient_nested_scope_raises(): + root = Root(child=Leaf()) + + with root.record_transient(): + with pytest.raises( + RuntimeError, match="Nested recording scopes are not supported." + ): + with root.record_transient(): + pass + + +def test_detached_component_write_is_not_recorded_by_another_root(): + root = Root(child=Leaf(value=1)) + detached = Leaf(value=7) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + with root.record_transient(): + detached.value = 9 + + assert caught == [] + assert detached.value == 9 + assert root.get_transient_changes() == [] + + +def test_overwriting_ancestor_outside_scope_drops_descendant_transient_record(): + root = Root(child=Leaf(value=1)) + old_child = root.child + replacement = Leaf(value=5) + + with root.record_transient(): + root.child.value = 2 + + assert root.get_transient_changes() == [ + {"path": "#/child/value", "original": 1, "transient": 2} + ] + + with pytest.warns( + UserWarning, + match=( + "transient value is being permanently overwritten.*transient record was" + " removed" + ), + ): + root.child = replacement + + assert root.child is replacement + assert root.get_transient_changes() == [] + + root.revert_transient() + + assert root.child is replacement + assert replacement.value == 5 + assert old_child.value == 2 + + +def test_record_transient_tracks_list_iadd_delete_and_clear_without_noop_snapshot(): + root = Root(child=Leaf(), items=[1, 2]) + + with root.record_transient(): + root.items += [] + assert root.get_transient_changes() == [] + + root.items += [3] + del root.items[1] + root.items.clear() + + assert root.items == [] + assert root.get_transient_changes() == [ + {"path": "#/items", "original": [1, 2], "transient": []} + ] + + root.revert_transient() + + assert root.items == [1, 2] + + +def test_save_warns_when_transient_changes_are_active(tmp_path): + root = Root(child=Leaf(value=1)) + + with root.record_transient(): + root.child.value = 2 + + with pytest.warns( + UserWarning, + match=( + "1 active transient change.*save\\(\\) will revert.*original " + "pre-transient values.*clear transient state" + ), + ): + root.save(tmp_path / "state.json") + + +def test_save_persists_original_values_instead_of_transient_ones(tmp_path): + root = Root(child=Leaf(value=1), mapping={"status": "idle"}) + + with root.record_transient(): + root.child.value = 2 + root.mapping["status"] = "busy" + + root.save(tmp_path / "state.json") + + saved = json.loads((tmp_path / "state.json").read_text()) + + assert saved["child"]["value"] == 1 + assert saved["mapping"]["status"] == "idle" + + +def test_save_clears_transient_state_after_reverting(tmp_path): + root = Root(child=Leaf(value=1)) + + with root.record_transient(): + root.child.value = 2 + + root.save(tmp_path / "state.json") + + assert root.get_transient_changes() == [] + assert root._transient_state._records == [] + + +def test_save_reverts_live_object_state_after_save(tmp_path): + root = Root(child=Leaf(value=1), items=[1, 2]) + + with root.record_transient(): + root.child.value = 2 + root.items.append(3) + + root.save(tmp_path / "state.json") + + assert root.child.value == 1 + assert root.items == [1, 2] + + +def test_save_failure_restores_transient_live_state_and_records(tmp_path): + root = Root(child=Leaf(value=1)) + + with root.record_transient(): + root.child.value = 2 + + assert root.get_transient_changes() == [ + {"path": "#/child/value", "original": 1, "transient": 2} + ] + + with pytest.raises(ValueError, match="Unsupported path suffix"): + root.save(tmp_path / "state.txt") + + assert root.child.value == 2 + assert root.get_transient_changes() == [ + {"path": "#/child/value", "original": 1, "transient": 2} + ]