diff --git a/tribev2/demo_utils.py b/tribev2/demo_utils.py index fa735a7..792d82f 100644 --- a/tribev2/demo_utils.py +++ b/tribev2/demo_utils.py @@ -177,7 +177,9 @@ def from_pretrained( Cluster backend forwarded to feature-extractor infra (``"auto"`` by default). device: - Torch device string. ``"auto"`` selects CUDA when available. + Torch device. ``"auto"`` selects CUDA > MPS > CPU. + Note: feature extractors (text, audio, image, video) are forced onto + CPU whenever the resolved device is ``"cpu"`` or ``"mps"``. config_update: Optional dictionary of config overrides applied after the YAML config is loaded. @@ -190,7 +192,12 @@ def from_pretrained( if cache_folder is not None: Path(cache_folder).mkdir(parents=True, exist_ok=True) if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir.exists(): config_path = checkpoint_dir / "config.yaml" @@ -203,6 +210,12 @@ def from_pretrained( ckpt_path = hf_hub_download(repo_id, checkpoint_name) with open(config_path, "r") as f: config = ConfDict(yaml.load(f, Loader=yaml.UnsafeLoader)) + if device in ("cpu", "mps"): # mps not supported by neuralset extractors + # Override all extractor devices to cpu whenever the model device is not cuda + for modality in ["text", "audio"]: + config[f"data.{modality}_feature.device"] = "cpu" + config["data.image_feature.image.device"] = "cpu" + config["data.video_feature.image.device"] = "cpu" for modality in ["text", "audio", "video"]: config[f"data.{modality}_feature.infra.folder"] = cache_folder config[f"data.{modality}_feature.infra.cluster"] = cluster