diff --git a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py index d2ea35a73c4c7..1d50a57335f5c 100644 --- a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py +++ b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py @@ -387,7 +387,7 @@ def ConvertBatchToNumpy(self, batch) -> np.ndarray: return return_data - def ConvertBatchToPyTorch(self, batch: Any) -> torch.Tensor: + def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor: """Convert a RTensor into a PyTorch tensor Args: @@ -404,7 +404,7 @@ def ConvertBatchToPyTorch(self, batch: Any) -> torch.Tensor: data.reshape((batch_size * num_columns,)) - return_data = torch.as_tensor(np.asarray(data)).reshape(batch_size, num_columns) + return_data = torch.as_tensor(np.asarray(data), device=device).reshape(batch_size, num_columns) # Splice target column from the data if target is given if self.target_given: @@ -704,12 +704,16 @@ def as_numpy(self) -> FormattedLoader: self._ensure_created() return FormattedLoader(self._internal, self._internal.ConvertBatchToNumpy, self._is_training) - def as_torch(self) -> FormattedLoader: + def as_torch(self, device: str | torch.device | None = None) -> FormattedLoader: """ Return an iterable that yields batches as PyTorch tensors. + + Args: + device: If given, the returned tensors are moved to the specified device. """ self._ensure_created() - return FormattedLoader(self._internal, self._internal.ConvertBatchToPyTorch, self._is_training) + conversion_fn = lambda batch: self._internal.ConvertBatchToPyTorch(batch, device) # noqa: E731 + return FormattedLoader(self._internal, conversion_fn, self._is_training) def as_tensorflow(self) -> tf.data.Dataset: """ diff --git a/bindings/pyroot/pythonizations/test/ml_dataloader.py b/bindings/pyroot/pythonizations/test/ml_dataloader.py index 647cf6cd605f1..2d49c2a11dc13 100644 --- a/bindings/pyroot/pythonizations/test/ml_dataloader.py +++ b/bindings/pyroot/pythonizations/test/ml_dataloader.py @@ -695,7 +695,7 @@ def test12_PyTorch(self): collected_z_train = [] collected_z_val = [] - iter_train = iter(gen_train.as_torch()) + iter_train = iter(gen_train.as_torch(device="cpu")) iter_val = iter(gen_validation.as_torch()) for _ in range(self.n_train_batch): @@ -1792,7 +1792,7 @@ def test12_PyTorch(self): collected_z_val = [] iter_train = iter(gen_train.as_torch()) - iter_val = iter(gen_validation.as_torch()) + iter_val = iter(gen_validation.as_torch(device="cpu")) for _ in range(self.n_train_batch): x, y, z = next(iter_train) @@ -2984,7 +2984,7 @@ def test12_PyTorch(self): collected_z_train = [] collected_z_val = [] - iter_train = iter(gen_train.as_torch()) + iter_train = iter(gen_train.as_torch(device="cpu")) iter_val = iter(gen_validation.as_torch()) for _ in range(self.n_train_batch): @@ -4133,7 +4133,7 @@ def test12_PyTorch(self): collected_z_val = [] iter_train = iter(gen_train.as_torch()) - iter_val = iter(gen_validation.as_torch()) + iter_val = iter(gen_validation.as_torch(device="cpu")) for _ in range(self.n_train_batch): x, y, z = next(iter_train) @@ -5546,7 +5546,7 @@ def test12_PyTorch(self): collected_z_train = [] collected_z_val = [] - iter_train = iter(gen_train.as_torch()) + iter_train = iter(gen_train.as_torch(device="cpu")) iter_val = iter(gen_validation.as_torch()) for _ in range(self.n_train_batch): @@ -7086,7 +7086,7 @@ def test12_PyTorch(self): collected_z_val = [] iter_train = iter(gen_train.as_torch()) - iter_val = iter(gen_validation.as_torch()) + iter_val = iter(gen_validation.as_torch(device="cpu")) for _ in range(self.n_train_batch): x, y, z = next(iter_train)