diff --git a/tribev2/demo_utils.py b/tribev2/demo_utils.py index fa735a7..29e404f 100644 --- a/tribev2/demo_utils.py +++ b/tribev2/demo_utils.py @@ -198,11 +198,26 @@ def from_pretrained( else: from huggingface_hub import hf_hub_download - repo_id = str(checkpoint_dir) + repo_id = checkpoint_dir.as_posix() config_path = hf_hub_download(repo_id, "config.yaml") ckpt_path = hf_hub_download(repo_id, checkpoint_name) + class _WindowsCompatLoader(yaml.UnsafeLoader): + pass + + def _posixpath_constructor(loader, node): + args = loader.construct_sequence(node, deep=True) + return Path(*args) + + for _tag in ( + "tag:yaml.org,2002:python/object/apply:pathlib.PosixPath", + "tag:yaml.org,2002:python/object/apply:pathlib.WindowsPath", + "tag:yaml.org,2002:python/object/new:pathlib.PosixPath", + "tag:yaml.org,2002:python/object/new:pathlib.WindowsPath", + ): + _WindowsCompatLoader.yaml_constructors[_tag] = _posixpath_constructor + with open(config_path, "r") as f: - config = ConfDict(yaml.load(f, Loader=yaml.UnsafeLoader)) + config = ConfDict(yaml.load(f, Loader=_WindowsCompatLoader)) for modality in ["text", "audio", "video"]: config[f"data.{modality}_feature.infra.folder"] = cache_folder config[f"data.{modality}_feature.infra.cluster"] = cluster