From f44ff62f9094d63902d1b8d84461d1674c636c81 Mon Sep 17 00:00:00 2001 From: Silia Taider Date: Wed, 22 Apr 2026 11:49:00 +0200 Subject: [PATCH 1/2] [Python][ML] Add device param to RDataLoader's .as_torch() --- .../python/ROOT/_pythonization/_ml_dataloader.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py index d2ea35a73c4c7..1d50a57335f5c 100644 --- a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py +++ b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py @@ -387,7 +387,7 @@ def ConvertBatchToNumpy(self, batch) -> np.ndarray: return return_data - def ConvertBatchToPyTorch(self, batch: Any) -> torch.Tensor: + def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor: """Convert a RTensor into a PyTorch tensor Args: @@ -404,7 +404,7 @@ def ConvertBatchToPyTorch(self, batch: Any) -> torch.Tensor: data.reshape((batch_size * num_columns,)) - return_data = torch.as_tensor(np.asarray(data)).reshape(batch_size, num_columns) + return_data = torch.as_tensor(np.asarray(data), device=device).reshape(batch_size, num_columns) # Splice target column from the data if target is given if self.target_given: @@ -704,12 +704,16 @@ def as_numpy(self) -> FormattedLoader: self._ensure_created() return FormattedLoader(self._internal, self._internal.ConvertBatchToNumpy, self._is_training) - def as_torch(self) -> FormattedLoader: + def as_torch(self, device: str | torch.device | None = None) -> FormattedLoader: """ Return an iterable that yields batches as PyTorch tensors. + + Args: + device: If given, the returned tensors are moved to the specified device. """ self._ensure_created() - return FormattedLoader(self._internal, self._internal.ConvertBatchToPyTorch, self._is_training) + conversion_fn = lambda batch: self._internal.ConvertBatchToPyTorch(batch, device) # noqa: E731 + return FormattedLoader(self._internal, conversion_fn, self._is_training) def as_tensorflow(self) -> tf.data.Dataset: """ From 1cdfa0fa0e6d11662b7ea6d8a1f1b290f06eeb68 Mon Sep 17 00:00:00 2001 From: Silia Taider Date: Wed, 22 Apr 2026 11:52:59 +0200 Subject: [PATCH 2/2] [Python][ML] Update PyTorch tests to use device --- bindings/pyroot/pythonizations/test/ml_dataloader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/bindings/pyroot/pythonizations/test/ml_dataloader.py b/bindings/pyroot/pythonizations/test/ml_dataloader.py index 647cf6cd605f1..2d49c2a11dc13 100644 --- a/bindings/pyroot/pythonizations/test/ml_dataloader.py +++ b/bindings/pyroot/pythonizations/test/ml_dataloader.py @@ -695,7 +695,7 @@ def test12_PyTorch(self): collected_z_train = [] collected_z_val = [] - iter_train = iter(gen_train.as_torch()) + iter_train = iter(gen_train.as_torch(device="cpu")) iter_val = iter(gen_validation.as_torch()) for _ in range(self.n_train_batch): @@ -1792,7 +1792,7 @@ def test12_PyTorch(self): collected_z_val = [] iter_train = iter(gen_train.as_torch()) - iter_val = iter(gen_validation.as_torch()) + iter_val = iter(gen_validation.as_torch(device="cpu")) for _ in range(self.n_train_batch): x, y, z = next(iter_train) @@ -2984,7 +2984,7 @@ def test12_PyTorch(self): collected_z_train = [] collected_z_val = [] - iter_train = iter(gen_train.as_torch()) + iter_train = iter(gen_train.as_torch(device="cpu")) iter_val = iter(gen_validation.as_torch()) for _ in range(self.n_train_batch): @@ -4133,7 +4133,7 @@ def test12_PyTorch(self): collected_z_val = [] iter_train = iter(gen_train.as_torch()) - iter_val = iter(gen_validation.as_torch()) + iter_val = iter(gen_validation.as_torch(device="cpu")) for _ in range(self.n_train_batch): x, y, z = next(iter_train) @@ -5546,7 +5546,7 @@ def test12_PyTorch(self): collected_z_train = [] collected_z_val = [] - iter_train = iter(gen_train.as_torch()) + iter_train = iter(gen_train.as_torch(device="cpu")) iter_val = iter(gen_validation.as_torch()) for _ in range(self.n_train_batch): @@ -7086,7 +7086,7 @@ def test12_PyTorch(self): collected_z_val = [] iter_train = iter(gen_train.as_torch()) - iter_val = iter(gen_validation.as_torch()) + iter_val = iter(gen_validation.as_torch(device="cpu")) for _ in range(self.n_train_batch): x, y, z = next(iter_train)