diff --git a/hyperbench/data/loader.py b/hyperbench/data/loader.py index 4ccd374..4730b96 100644 --- a/hyperbench/data/loader.py +++ b/hyperbench/data/loader.py @@ -71,7 +71,7 @@ def collate(self, batch: List[HData]) -> HData: A single :class:`HData` object containing the collated data. """ if self.__sample_full_hypergraph: - return self.__cached_dataset_hdata.to(batch[0].device) + return self.__cached_dataset_hdata.clone().to(batch[0].device) collated_hyperedge_index = torch.cat([data.hyperedge_index for data in batch], dim=1) hyperedge_index_wrapper = HyperedgeIndex(collated_hyperedge_index).remove_duplicate_edges() diff --git a/hyperbench/tests/data/loader_test.py b/hyperbench/tests/data/loader_test.py index d213126..9057f62 100644 --- a/hyperbench/tests/data/loader_test.py +++ b/hyperbench/tests/data/loader_test.py @@ -13,7 +13,13 @@ def mock_dataset_single_sample(): x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) hyperedge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]) hyperedge_attr = torch.tensor([[0.5], [0.7]]) - hdata = HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) + hyperedge_weights = torch.tensor([[0.8], [0.9]]) + hdata = HData( + x=x, + hyperedge_index=hyperedge_index, + hyperedge_attr=hyperedge_attr, + hyperedge_weights=hyperedge_weights, + ) dataset = MagicMock(spec=Dataset) dataset.hdata = hdata @@ -451,3 +457,66 @@ def test_collate_with_node_sampled_batch(): assert torch.equal(batched.hyperedge_index, expected_hyperedge_index) assert batched.hyperedge_attr is None + + +def test_collate_sample_full_hypergraph_does_not_share_storage_with_cached_hdata( + mock_dataset_single_sample, +): + loader = DataLoader(mock_dataset_single_sample, sample_full_hypergraph=True) + + batched = loader.collate([mock_dataset_single_sample[0]]) + + cached: HData = mock_dataset_single_sample.hdata + + assert batched is not cached + assert batched.x.data_ptr() != cached.x.data_ptr() + assert batched.hyperedge_index.data_ptr() != 
cached.hyperedge_index.data_ptr() + assert batched.hyperedge_attr is not None + assert ( + batched.hyperedge_attr.data_ptr() + != utils.to_non_empty_edgeattr(cached.hyperedge_attr).data_ptr() + ) + assert batched.hyperedge_weights is not None + assert ( + batched.hyperedge_weights.data_ptr() + != utils.to_non_empty_edgeattr(cached.hyperedge_weights).data_ptr() + ) + + +def test_collate_sample_full_hypergraph_mutating_batch_does_not_affect_cached_hdata( + mock_dataset_single_sample, +): + loader = DataLoader(mock_dataset_single_sample, sample_full_hypergraph=True) + + cached: HData = mock_dataset_single_sample.hdata + cached_x = cached.x.clone() + cached_hyperedge_index = cached.hyperedge_index.clone() + cached_hyperedge_attr = utils.to_non_empty_edgeattr(cached.hyperedge_attr).clone() + cached_hyperedge_weights = utils.to_non_empty_edgeattr(cached.hyperedge_weights).clone() + + batched = loader.collate([mock_dataset_single_sample[0]]) + + batched.x.zero_()  # in-place: rebinding the attribute would not exercise shared storage + batched.hyperedge_index.zero_() + batched.hyperedge_attr.zero_() + batched.hyperedge_weights.zero_() + + assert torch.equal(cached.x, cached_x) + assert not torch.equal(cached.x, batched.x) + + assert torch.equal(cached.hyperedge_index, cached_hyperedge_index) + assert not torch.equal(cached.hyperedge_index, batched.hyperedge_index) + + assert torch.equal(utils.to_non_empty_edgeattr(cached.hyperedge_attr), cached_hyperedge_attr) + assert not torch.equal( + utils.to_non_empty_edgeattr(cached.hyperedge_attr), + utils.to_non_empty_edgeattr(batched.hyperedge_attr), + ) + + assert torch.equal( + utils.to_non_empty_edgeattr(cached.hyperedge_weights), cached_hyperedge_weights + ) + assert not torch.equal( + utils.to_non_empty_edgeattr(cached.hyperedge_weights), + utils.to_non_empty_edgeattr(batched.hyperedge_weights), + ) diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index 
ecf6ffb..8c4d812 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -679,6 +679,33 @@ def shuffle(self, seed: Optional[int] = None) -> "HData": y=new_y, ) + def clone(self) -> "HData": + """ + Return a deep copy of this :class:`HData`. + + Returns: + A new :class:`HData` that is a deep copy of this instance. + """ + cloned_hyperedge_weights = ( + self.hyperedge_weights.clone() if self.hyperedge_weights is not None else None + ) + cloned_hyperedge_attr = ( + self.hyperedge_attr.clone() if self.hyperedge_attr is not None else None + ) + cloned_global_node_ids = ( + self.global_node_ids.clone() if self.global_node_ids is not None else None + ) + return self.__class__( + x=self.x.clone(), + hyperedge_index=self.hyperedge_index.clone(), + hyperedge_weights=cloned_hyperedge_weights, + hyperedge_attr=cloned_hyperedge_attr, + num_nodes=self.num_nodes, + num_hyperedges=self.num_hyperedges, + global_node_ids=cloned_global_node_ids, + y=self.y.clone() if self.y is not None else None, + ) + def to(self, device: torch.device | str, non_blocking: bool = False) -> "HData": """ Move all tensors to the specified device.