From caabcdc6c714cd52d92e33425f3c3f991ce2f4f2 Mon Sep 17 00:00:00 2001 From: Yajur Khanna Date: Sat, 28 Mar 2026 23:08:42 -0400 Subject: [PATCH] Fix device override in from_pretrained, add MPS support - Move extractor device overrides to after config YAML is loaded; previously they were applied before yaml.load() reassigned config, silently discarding all overrides and causing AssertionError on CPU-only systems ("Torch not compiled with CUDA enabled"). - Add MPS auto-detection via torch.backends.mps.is_available() so Apple Silicon users get GPU acceleration on the FmriEncoder. - Force neuralset feature extractors to CPU when device is mps, since their device field is Literal["auto","cpu","cuda","accelerate"]. - Update docstring to document the CPU extractor fallback behavior. --- tribev2/demo_utils.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tribev2/demo_utils.py b/tribev2/demo_utils.py index fa735a7..792d82f 100644 --- a/tribev2/demo_utils.py +++ b/tribev2/demo_utils.py @@ -177,7 +177,9 @@ def from_pretrained( Cluster backend forwarded to feature-extractor infra (``"auto"`` by default). device: - Torch device string. ``"auto"`` selects CUDA when available. + Torch device. ``"auto"`` selects CUDA > MPS > CPU. + Note: feature extractors (text, audio, video) always run on CPU + regardless of the specified device. config_update: Optional dictionary of config overrides applied after the YAML config is loaded. @@ -190,7 +192,12 @@ def from_pretrained( if cache_folder is not None: Path(cache_folder).mkdir(parents=True, exist_ok=True) if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir.exists(): config_path = checkpoint_dir / "config.yaml" @@ -203,6 +210,12 @@ def from_pretrained( ckpt_path = hf_hub_download(repo_id, checkpoint_name) with open(config_path, "r") as f: config = ConfDict(yaml.load(f, Loader=yaml.UnsafeLoader)) + if device in ("cpu", "mps"): # mps not supported by neuralset extractors + # Override all extractor devices to cpu when cuda is unavailable + for modality in ["text", "audio"]: + config[f"data.{modality}_feature.device"] = "cpu" + config["data.image_feature.image.device"] = "cpu" + config["data.video_feature.image.device"] = "cpu" for modality in ["text", "audio", "video"]: config[f"data.{modality}_feature.infra.folder"] = cache_folder config[f"data.{modality}_feature.infra.cluster"] = cluster