From fe7b835f760af39c645f834b472c0ae401100c5f Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Thu, 21 May 2026 16:28:48 -0700 Subject: [PATCH] fix: scope the torch.load map_location wrapper instead of replacing it globally MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `chatterbox_node.py` did `torch.load = patched_torch_load` at import time, which replaces torch.load for the entire Python process — ComfyUI core and every other custom node pack included. Two problems in a shared environment (and on multi-tenant cloud runtimes where one process serves many users' jobs back to back): 1. Cross-pack clobbering. Other packs also wrap `torch.load`. Whichever imports last wins, so process-wide torch.load behavior depends on custom-node import order, which is not deterministic. 2. Behavior imposed on unrelated callers. After import, every `torch.load` in the process gets `map_location` forced onto it — including ComfyUI core's checkpoint loading and other packs that explicitly wanted default device placement. Fix: keep `patched_torch_load` exactly as-is, but install it only for the duration of this pack's own Chatterbox model loads via a `default_map_location()` context manager, restoring the previous torch.load in `finally`. All four `Chatterbox*.from_local(...)` call sites are wrapped, so device defaulting still works for every Chatterbox load; torch.load is left untouched for everyone else. No functional change to how this pack loads models; only the blast radius of the patch is reduced from process-global to call-scoped. --- chatterbox_node.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/chatterbox_node.py b/chatterbox_node.py index 907242a..0473331 100644 --- a/chatterbox_node.py +++ b/chatterbox_node.py @@ -1,4 +1,5 @@ import os +import contextlib import torch import torchaudio import numpy as np @@ -156,7 +157,8 @@ def load_turbo_model(device: str) -> ChatterboxTurboTTS: ] download_chatterbox_models("ResembleAI/chatterbox-turbo", turbo_files, local_dir) - return ChatterboxTurboTTS.from_local(str(local_dir), device) + with default_map_location(): + return ChatterboxTurboTTS.from_local(str(local_dir), device) def load_tts_model(device: str) -> ChatterboxTTS: @@ -177,7 +179,8 @@ def load_tts_model(device: str) -> ChatterboxTTS: ] download_chatterbox_models("ResembleAI/chatterbox", tts_files, local_dir) - return ChatterboxTTS.from_local(str(local_dir), device) + with default_map_location(): + return ChatterboxTTS.from_local(str(local_dir), device) def load_multilingual_model(device: str) -> ChatterboxMultilingualTTS: @@ -199,7 +202,8 @@ def load_multilingual_model(device: str) -> ChatterboxMultilingualTTS: ] download_chatterbox_models("ResembleAI/chatterbox", mtl_files, local_dir) - return ChatterboxMultilingualTTS.from_local(str(local_dir), device) + with default_map_location(): + return ChatterboxMultilingualTTS.from_local(str(local_dir), device) def load_vc_model(device: str) -> ChatterboxVC: @@ -218,9 +222,16 @@ def load_vc_model(device: str) -> ChatterboxVC: ] download_chatterbox_models("ResembleAI/chatterbox", vc_files, local_dir) - return ChatterboxVC.from_local(str(local_dir), device) - -# Monkey patch torch.load to use MPS or CPU if map_location is not specified + with default_map_location(): + return ChatterboxVC.from_local(str(local_dir), device) + +# torch.load wrapper: default map_location to the active device (MPS / CUDA / +# CPU) when the caller did not specify one. Installed via a context manager +# rather than replacing torch.load process-wide — a global replacement +# clobbers (and is clobbered by) other custom node packs that also wrap +# torch.load, and forces map_location onto ComfyUI core and every other pack +# that never opted in. Scoping keeps the device-defaulting only for this +# pack's own Chatterbox model loads. original_torch_load = torch.load def patched_torch_load(*args, **kwargs): if 'map_location' not in kwargs: @@ -234,7 +245,16 @@ def patched_torch_load(*args, **kwargs): kwargs['map_location'] = torch.device(device) return original_torch_load(*args, **kwargs) -torch.load = patched_torch_load + +@contextlib.contextmanager +def default_map_location(): + """Temporarily install patched_torch_load for this pack's model loads.""" + previous = torch.load + torch.load = patched_torch_load + try: + yield + finally: + torch.load = previous class AudioNodeBase: