From 4ac8e69099c175273859cf15809cc9d50bcb8902 Mon Sep 17 00:00:00 2001 From: SAY-5 Date: Tue, 28 Apr 2026 14:31:44 -0700 Subject: [PATCH 1/2] fix: DataLoader.collate clones cached hdata on sample_full_hypergraph MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `DataLoader.collate()` returned `self.__cached_dataset_hdata.to(...)` when `sample_full_hypergraph=True`. Because `HData.to()` is in-place, that returned the cached dataset object itself — so iterating the dataloader and mutating the batch (or transferring through a different device path on the next iteration) silently mutated the dataset's cached `hdata`. This change adds an `HData.clone()` method that returns a structurally independent `HData` (every tensor field cloned, scalar fields passed through), and wires the loader to `clone().to(...)` instead of `to(...)` directly. The clone happens once per batch in the sample-full path, so the cost is bounded by the dataset size — same order of magnitude as the device transfer that already happens there. Three regression tests in `hyperbench/tests/data/loader_test.py`: - `test_collate_sample_full_hypergraph_does_not_share_storage_with_cached_hdata` asserts `data_ptr` inequality across `x`, `hyperedge_index`, `hyperedge_attr`. - `test_collate_sample_full_hypergraph_mutating_batch_does_not_affect_cached_hdata` mutates the batch in place and confirms the cached hdata's tensors are unchanged. - `test_collate_sample_full_hypergraph_with_weights_isolates_weights` exercises the same isolation for `hyperedge_weights`. Each fails when `loader.py` and `hdata.py` are stashed, confirming they exercise the new behaviour. Existing `test_collate_sample_full_hypergraph_returns_cached_hdata` continues to pass — the equality of contents is preserved. Closes #173 --- hyperbench/data/loader.py | 7 ++- hyperbench/tests/data/loader_test.py | 69 ++++++++++++++++++++++++++++ hyperbench/types/hdata.py | 27 +++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) diff --git a/hyperbench/data/loader.py b/hyperbench/data/loader.py index 4ccd374..075d782 100644 --- a/hyperbench/data/loader.py +++ b/hyperbench/data/loader.py @@ -71,7 +71,12 @@ def collate(self, batch: List[HData]) -> HData: A single :class:`HData` object containing the collated data. """ if self.__sample_full_hypergraph: - return self.__cached_dataset_hdata.to(batch[0].device) + # `HData.to` is in-place, so without `clone()` the dataloader + # would return the cached dataset object itself — meaning any + # downstream mutation (or device transfer through a different + # path on the next iteration) would silently mutate shared + # dataset state. See issue #173. + return self.__cached_dataset_hdata.clone().to(batch[0].device) collated_hyperedge_index = torch.cat([data.hyperedge_index for data in batch], dim=1) hyperedge_index_wrapper = HyperedgeIndex(collated_hyperedge_index).remove_duplicate_edges() diff --git a/hyperbench/tests/data/loader_test.py b/hyperbench/tests/data/loader_test.py index d213126..e6019e3 100644 --- a/hyperbench/tests/data/loader_test.py +++ b/hyperbench/tests/data/loader_test.py @@ -451,3 +451,72 @@ def test_collate_with_node_sampled_batch(): assert torch.equal(batched.hyperedge_index, expected_hyperedge_index) assert batched.hyperedge_attr is None + + +# --------------------------------------------------------------------------- +# Storage-isolation regression tests for issue #173. 
+# +# When sample_full_hypergraph=True, DataLoader.collate() used to return +# `self.__cached_dataset_hdata.to(batch[0].device)`. Because HData.to() is +# in-place, that returned the cached dataset object itself — so iterating +# the dataloader and mutating the batch (or moving it through a different +# code path on the next iteration) silently mutated the dataset's cached +# hdata. +# --------------------------------------------------------------------------- + + +def test_collate_sample_full_hypergraph_does_not_share_storage_with_cached_hdata( + mock_dataset_single_sample, +): + loader = DataLoader(mock_dataset_single_sample, sample_full_hypergraph=True) + + batched = loader.collate([mock_dataset_single_sample[0]]) + + cached: HData = mock_dataset_single_sample.hdata + assert batched is not cached + # The most-likely-mutated tensors must not share storage with the + # cached dataset, otherwise a downstream caller mutating the batch + # would corrupt the dataset. + assert batched.x.data_ptr() != cached.x.data_ptr() + assert batched.hyperedge_index.data_ptr() != cached.hyperedge_index.data_ptr() + assert batched.hyperedge_attr is not None + assert batched.hyperedge_attr.data_ptr() != cached.hyperedge_attr.data_ptr() + + +def test_collate_sample_full_hypergraph_mutating_batch_does_not_affect_cached_hdata( + mock_dataset_single_sample, +): + cached: HData = mock_dataset_single_sample.hdata + cached_x_snapshot = cached.x.clone() + cached_hyperedge_index_snapshot = cached.hyperedge_index.clone() + cached_hyperedge_attr_snapshot = cached.hyperedge_attr.clone() + + loader = DataLoader(mock_dataset_single_sample, sample_full_hypergraph=True) + batched = loader.collate([mock_dataset_single_sample[0]]) + + batched.x.add_(1.0) + batched.hyperedge_index.fill_(99) + batched.hyperedge_attr.add_(1.0) + + assert torch.equal(cached.x, cached_x_snapshot) + assert torch.equal(cached.hyperedge_index, cached_hyperedge_index_snapshot) + assert torch.equal(cached.hyperedge_attr, cached_hyperedge_attr_snapshot) + + +def test_collate_sample_full_hypergraph_with_weights_isolates_weights( + mock_dataset_single_sample_with_weights, +): + cached: HData = mock_dataset_single_sample_with_weights.hdata + cached_weights_snapshot = cached.hyperedge_weights.clone() + + loader = DataLoader( + mock_dataset_single_sample_with_weights, sample_full_hypergraph=True, + ) + batched = loader.collate([mock_dataset_single_sample_with_weights[0]]) + + assert batched.hyperedge_weights is not None + assert batched.hyperedge_weights.data_ptr() != cached.hyperedge_weights.data_ptr() + + batched.hyperedge_weights.fill_(0.0) + + assert torch.equal(cached.hyperedge_weights, cached_weights_snapshot) diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index ecf6ffb..f34f4f3 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -679,6 +679,33 @@ def shuffle(self, seed: Optional[int] = None) -> "HData": y=new_y, ) + def clone(self) -> "HData": + """ + Return a deep copy of this :class:`HData` with independent tensor storage. + + Useful when a caller wants to mutate or device-transfer a derived + :class:`HData` without leaking changes back into the original — see + e.g. :class:`hyperbench.data.DataLoader` with + ``sample_full_hypergraph=True``, which historically returned the + cached dataset object directly and was vulnerable to in-place + mutation by ``HData.to`` (issue #173). 
+ + Returns: + A new :class:`HData` whose every tensor field is a fresh + ``.clone()`` of the original; ``num_nodes`` / ``num_hyperedges`` + are passed through unchanged. + """ + return self.__class__( + x=self.x.clone(), + hyperedge_index=self.hyperedge_index.clone(), + hyperedge_weights=self.hyperedge_weights.clone() if self.hyperedge_weights is not None else None, + hyperedge_attr=self.hyperedge_attr.clone() if self.hyperedge_attr is not None else None, + num_nodes=self.num_nodes, + num_hyperedges=self.num_hyperedges, + global_node_ids=self.global_node_ids.clone() if self.global_node_ids is not None else None, + y=self.y.clone(), + ) + def to(self, device: torch.device | str, non_blocking: bool = False) -> "HData": """ Move all tensors to the specified device. From bd402577fbfc775cea60528568161d21d0ec1b9b Mon Sep 17 00:00:00 2001 From: Tiziano Date: Wed, 29 Apr 2026 13:37:34 +0200 Subject: [PATCH 2/2] refactor: change HData.clone() to follow codebase paractices --- hyperbench/data/loader.py | 5 -- hyperbench/tests/data/loader_test.py | 84 ++++++++++++++-------------- hyperbench/types/hdata.py | 28 +++++----- 3 files changed, 56 insertions(+), 61 deletions(-) diff --git a/hyperbench/data/loader.py b/hyperbench/data/loader.py index 075d782..4730b96 100644 --- a/hyperbench/data/loader.py +++ b/hyperbench/data/loader.py @@ -71,11 +71,6 @@ def collate(self, batch: List[HData]) -> HData: A single :class:`HData` object containing the collated data. """ if self.__sample_full_hypergraph: - # `HData.to` is in-place, so without `clone()` the dataloader - # would return the cached dataset object itself — meaning any - # downstream mutation (or device transfer through a different - # path on the next iteration) would silently mutate shared - # dataset state. See issue #173. return self.__cached_dataset_hdata.clone().to(batch[0].device) collated_hyperedge_index = torch.cat([data.hyperedge_index for data in batch], dim=1) diff --git a/hyperbench/tests/data/loader_test.py b/hyperbench/tests/data/loader_test.py index e6019e3..9057f62 100644 --- a/hyperbench/tests/data/loader_test.py +++ b/hyperbench/tests/data/loader_test.py @@ -13,7 +13,13 @@ def mock_dataset_single_sample(): x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) hyperedge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]) hyperedge_attr = torch.tensor([[0.5], [0.7]]) - hdata = HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) + hyperedge_weights = torch.tensor([[0.8], [0.9]]) + hdata = HData( + x=x, + hyperedge_index=hyperedge_index, + hyperedge_attr=hyperedge_attr, + hyperedge_weights=hyperedge_weights, + ) dataset = MagicMock(spec=Dataset) dataset.hdata = hdata @@ -453,18 +459,6 @@ def test_collate_with_node_sampled_batch(): assert batched.hyperedge_attr is None -# --------------------------------------------------------------------------- -# Storage-isolation regression tests for issue #173. -# -# When sample_full_hypergraph=True, DataLoader.collate() used to return -# `self.__cached_dataset_hdata.to(batch[0].device)`. Because HData.to() is -# in-place, that returned the cached dataset object itself — so iterating -# the dataloader and mutating the batch (or moving it through a different -# code path on the next iteration) silently mutated the dataset's cached -# hdata. 
-# --------------------------------------------------------------------------- - - def test_collate_sample_full_hypergraph_does_not_share_storage_with_cached_hdata( mock_dataset_single_sample, ): @@ -473,50 +467,56 @@ def test_collate_sample_full_hypergraph_does_not_share_storage_with_cached_hdata batched = loader.collate([mock_dataset_single_sample[0]]) cached: HData = mock_dataset_single_sample.hdata + assert batched is not cached - # The most-likely-mutated tensors must not share storage with the - # cached dataset, otherwise a downstream caller mutating the batch - # would corrupt the dataset. assert batched.x.data_ptr() != cached.x.data_ptr() assert batched.hyperedge_index.data_ptr() != cached.hyperedge_index.data_ptr() assert batched.hyperedge_attr is not None - assert batched.hyperedge_attr.data_ptr() != cached.hyperedge_attr.data_ptr() + assert ( + batched.hyperedge_attr.data_ptr() + != utils.to_non_empty_edgeattr(cached.hyperedge_attr).data_ptr() + ) + assert batched.hyperedge_weights is not None + assert ( + batched.hyperedge_weights.data_ptr() + != utils.to_non_empty_edgeattr(cached.hyperedge_weights).data_ptr() + ) def test_collate_sample_full_hypergraph_mutating_batch_does_not_affect_cached_hdata( mock_dataset_single_sample, ): + loader = DataLoader(mock_dataset_single_sample, sample_full_hypergraph=True) + cached: HData = mock_dataset_single_sample.hdata - cached_x_snapshot = cached.x.clone() - cached_hyperedge_index_snapshot = cached.hyperedge_index.clone() - cached_hyperedge_attr_snapshot = cached.hyperedge_attr.clone() + cached_x = cached.x.clone() + cached_hyperedge_index = cached.hyperedge_index.clone() + cached_hyperedge_attr = utils.to_non_empty_edgeattr(cached.hyperedge_attr).clone() + cached_hypeedge_weights = utils.to_non_empty_edgeattr(cached.hyperedge_weights).clone() - loader = DataLoader(mock_dataset_single_sample, sample_full_hypergraph=True) batched = loader.collate([mock_dataset_single_sample[0]]) - batched.x.add_(1.0) - batched.hyperedge_index.fill_(99) - batched.hyperedge_attr.add_(1.0) + batched.x = torch.zeros_like(cached_x) + batched.hyperedge_index = torch.zeros_like(cached_hyperedge_index) + batched.hyperedge_attr = torch.zeros_like(cached_hyperedge_attr) + batched.hyperedge_weights = torch.zeros_like(cached_hypeedge_weights) - assert torch.equal(cached.x, cached_x_snapshot) - assert torch.equal(cached.hyperedge_index, cached_hyperedge_index_snapshot) - assert torch.equal(cached.hyperedge_attr, cached_hyperedge_attr_snapshot) + assert torch.equal(cached.x, cached_x) + assert not torch.equal(cached.x, batched.x) + assert torch.equal(cached.hyperedge_index, cached_hyperedge_index) + assert not torch.equal(cached.hyperedge_index, batched.hyperedge_index) -def test_collate_sample_full_hypergraph_with_weights_isolates_weights( - mock_dataset_single_sample_with_weights, -): - cached: HData = mock_dataset_single_sample_with_weights.hdata - cached_weights_snapshot = cached.hyperedge_weights.clone() - - loader = DataLoader( - mock_dataset_single_sample_with_weights, sample_full_hypergraph=True, + assert torch.equal(utils.to_non_empty_edgeattr(cached.hyperedge_attr), cached_hyperedge_attr) + assert not torch.equal( + utils.to_non_empty_edgeattr(cached.hyperedge_attr), + utils.to_non_empty_edgeattr(batched.hyperedge_attr), ) - batched = loader.collate([mock_dataset_single_sample_with_weights[0]]) - assert batched.hyperedge_weights is not None - assert batched.hyperedge_weights.data_ptr() != cached.hyperedge_weights.data_ptr() - - 
batched.hyperedge_weights.fill_(0.0) - - assert torch.equal(cached.hyperedge_weights, cached_weights_snapshot) + assert torch.equal( + utils.to_non_empty_edgeattr(cached.hyperedge_weights), cached_hypeedge_weights + ) + assert not torch.equal( + utils.to_non_empty_edgeattr(cached.hyperedge_weights), + utils.to_non_empty_edgeattr(batched.hyperedge_weights), + ) diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index f34f4f3..8c4d812 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -681,28 +681,28 @@ def shuffle(self, seed: Optional[int] = None) -> "HData": def clone(self) -> "HData": """ - Return a deep copy of this :class:`HData` with independent tensor storage. - - Useful when a caller wants to mutate or device-transfer a derived - :class:`HData` without leaking changes back into the original — see - e.g. :class:`hyperbench.data.DataLoader` with - ``sample_full_hypergraph=True``, which historically returned the - cached dataset object directly and was vulnerable to in-place - mutation by ``HData.to`` (issue #173). + Return a deep copy of this :class:`HData`. Returns: - A new :class:`HData` whose every tensor field is a fresh - ``.clone()`` of the original; ``num_nodes`` / ``num_hyperedges`` - are passed through unchanged. + A new :class:`HData` that is a deep copy of this instance. """ + cloned_hyperedge_weights = ( + self.hyperedge_weights.clone() if self.hyperedge_weights is not None else None + ) + cloned_hyperedge_attr = ( + self.hyperedge_attr.clone() if self.hyperedge_attr is not None else None + ) + cloned_global_node_ids = ( + self.global_node_ids.clone() if self.global_node_ids is not None else None + ) return self.__class__( x=self.x.clone(), hyperedge_index=self.hyperedge_index.clone(), - hyperedge_weights=self.hyperedge_weights.clone() if self.hyperedge_weights is not None else None, - hyperedge_attr=self.hyperedge_attr.clone() if self.hyperedge_attr is not None else None, + hyperedge_weights=cloned_hyperedge_weights, + hyperedge_attr=cloned_hyperedge_attr, num_nodes=self.num_nodes, num_hyperedges=self.num_hyperedges, - global_node_ids=self.global_node_ids.clone() if self.global_node_ids is not None else None, + global_node_ids=cloned_global_node_ids, y=self.y.clone(), )
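
For readers outside the codebase, a minimal sketch of the aliasing behaviour the first patch guards against. `ToyHData` below is a hypothetical stand-in for `HData` (its in-place, self-returning `to()` mirrors what the commit message says about `HData.to`); the only real APIs assumed are `torch.Tensor.clone()`, `Tensor.to()`, `Tensor.data_ptr()` and `torch.equal()`.

import torch


class ToyHData:
    """Toy stand-in for HData with an in-place, self-returning to()."""

    def __init__(self, x: torch.Tensor) -> None:
        self.x = x

    def to(self, device: torch.device | str) -> "ToyHData":
        # In-place transfer that returns self, so a caller forwarding the
        # result hands out the cached object itself.
        self.x = self.x.to(device)
        return self

    def clone(self) -> "ToyHData":
        # Fresh tensor storage: mutations on the clone cannot reach the original.
        return ToyHData(self.x.clone())


# Buggy collate path: the "batch" is the cached object itself.
cached = ToyHData(torch.zeros(3, 2))
batch = cached.to("cpu")
assert batch is cached
batch.x.add_(1.0)                                  # downstream mutation...
assert torch.equal(cached.x, torch.ones(3, 2))     # ...corrupts the cache

# Fixed collate path: clone first, then transfer.
cached = ToyHData(torch.zeros(3, 2))
batch = cached.clone().to("cpu")
assert batch is not cached
assert batch.x.data_ptr() != cached.x.data_ptr()   # independent storage
batch.x.add_(1.0)
assert torch.equal(cached.x, torch.zeros(3, 2))    # cache is untouched

The regression tests in these patches rely on the same signals: object identity, `data_ptr()` inequality, and value snapshots taken before mutation.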