diff --git a/chatterbox_node.py b/chatterbox_node.py index 907242a..0473331 100644 --- a/chatterbox_node.py +++ b/chatterbox_node.py @@ -1,4 +1,5 @@ import os +import contextlib import torch import torchaudio import numpy as np @@ -156,7 +157,8 @@ def load_turbo_model(device: str) -> ChatterboxTurboTTS: ] download_chatterbox_models("ResembleAI/chatterbox-turbo", turbo_files, local_dir) - return ChatterboxTurboTTS.from_local(str(local_dir), device) + with default_map_location(): + return ChatterboxTurboTTS.from_local(str(local_dir), device) def load_tts_model(device: str) -> ChatterboxTTS: @@ -177,7 +179,8 @@ def load_tts_model(device: str) -> ChatterboxTTS: ] download_chatterbox_models("ResembleAI/chatterbox", tts_files, local_dir) - return ChatterboxTTS.from_local(str(local_dir), device) + with default_map_location(): + return ChatterboxTTS.from_local(str(local_dir), device) def load_multilingual_model(device: str) -> ChatterboxMultilingualTTS: @@ -199,7 +202,8 @@ def load_multilingual_model(device: str) -> ChatterboxMultilingualTTS: ] download_chatterbox_models("ResembleAI/chatterbox", mtl_files, local_dir) - return ChatterboxMultilingualTTS.from_local(str(local_dir), device) + with default_map_location(): + return ChatterboxMultilingualTTS.from_local(str(local_dir), device) def load_vc_model(device: str) -> ChatterboxVC: @@ -218,9 +222,16 @@ def load_vc_model(device: str) -> ChatterboxVC: ] download_chatterbox_models("ResembleAI/chatterbox", vc_files, local_dir) - return ChatterboxVC.from_local(str(local_dir), device) - -# Monkey patch torch.load to use MPS or CPU if map_location is not specified + with default_map_location(): + return ChatterboxVC.from_local(str(local_dir), device) + +# torch.load wrapper: default map_location to the active device (MPS / CUDA / +# CPU) when the caller did not specify one. Installed via a context manager +# rather than replacing torch.load process-wide — a global replacement +# clobbers (and is clobbered by) other custom node packs that also wrap +# torch.load, and forces map_location onto ComfyUI core and every other pack +# that never opted in. Scoping keeps the device-defaulting only for this +# pack's own Chatterbox model loads. original_torch_load = torch.load def patched_torch_load(*args, **kwargs): if 'map_location' not in kwargs: @@ -234,7 +245,16 @@ def patched_torch_load(*args, **kwargs): kwargs['map_location'] = torch.device(device) return original_torch_load(*args, **kwargs) -torch.load = patched_torch_load + +@contextlib.contextmanager +def default_map_location(): + """Temporarily install patched_torch_load for this pack's model loads.""" + previous = torch.load + torch.load = patched_torch_load + try: + yield + finally: + torch.load = previous class AudioNodeBase: