diff --git a/encoding/models/nested_cv.py b/encoding/models/nested_cv.py index bc2336e..afa0d97 100644 --- a/encoding/models/nested_cv.py +++ b/encoding/models/nested_cv.py @@ -81,7 +81,14 @@ def fit_predict( alphas = np.logspace(-1, 8, 10) # Determine device - use GPU if available and requested - device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu" + + if use_gpu: + if torch.backends.mps.is_available(): + device = "mps:0" + elif torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" logger.info(f"Using device: {device}") logger.info(f"Folding type: {folding_type}") @@ -397,7 +404,7 @@ def _find_best_alphas( # Find the best alpha for each voxel best_alpha_idx = torch.argmax(mean_inner_corrs, dim=0) # Shape: (n_voxels,) best_valphas = torch.tensor( - [alphas[i] for i in best_alpha_idx], device=X_train.device + [alphas[i] for i in best_alpha_idx], device=X_train.device, dtype=torch.float32 ) if logger: logger.info("Found best alphas for each voxel")