Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions chatterbox_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import contextlib
import torch
import torchaudio
import numpy as np
Expand Down Expand Up @@ -156,7 +157,8 @@ def load_turbo_model(device: str) -> ChatterboxTurboTTS:
]

download_chatterbox_models("ResembleAI/chatterbox-turbo", turbo_files, local_dir)
return ChatterboxTurboTTS.from_local(str(local_dir), device)
with default_map_location():
return ChatterboxTurboTTS.from_local(str(local_dir), device)


def load_tts_model(device: str) -> ChatterboxTTS:
Expand All @@ -177,7 +179,8 @@ def load_tts_model(device: str) -> ChatterboxTTS:
]

download_chatterbox_models("ResembleAI/chatterbox", tts_files, local_dir)
return ChatterboxTTS.from_local(str(local_dir), device)
with default_map_location():
return ChatterboxTTS.from_local(str(local_dir), device)


def load_multilingual_model(device: str) -> ChatterboxMultilingualTTS:
Expand All @@ -199,7 +202,8 @@ def load_multilingual_model(device: str) -> ChatterboxMultilingualTTS:
]

download_chatterbox_models("ResembleAI/chatterbox", mtl_files, local_dir)
return ChatterboxMultilingualTTS.from_local(str(local_dir), device)
with default_map_location():
return ChatterboxMultilingualTTS.from_local(str(local_dir), device)


def load_vc_model(device: str) -> ChatterboxVC:
Expand All @@ -218,9 +222,16 @@ def load_vc_model(device: str) -> ChatterboxVC:
]

download_chatterbox_models("ResembleAI/chatterbox", vc_files, local_dir)
return ChatterboxVC.from_local(str(local_dir), device)

# Monkey patch torch.load to use MPS or CPU if map_location is not specified
with default_map_location():
return ChatterboxVC.from_local(str(local_dir), device)

# torch.load wrapper: default map_location to the active device (MPS / CUDA /
# CPU) when the caller did not specify one. Installed via a context manager
# rather than replacing torch.load process-wide — a global replacement
# clobbers (and is clobbered by) other custom node packs that also wrap
# torch.load, and forces map_location onto ComfyUI core and every other pack
# that never opted in. Scoping keeps the device-defaulting only for this
# pack's own Chatterbox model loads.
original_torch_load = torch.load
def patched_torch_load(*args, **kwargs):
if 'map_location' not in kwargs:
Expand All @@ -234,7 +245,16 @@ def patched_torch_load(*args, **kwargs):
kwargs['map_location'] = torch.device(device)
return original_torch_load(*args, **kwargs)

torch.load = patched_torch_load

@contextlib.contextmanager
def default_map_location():
"""Temporarily install patched_torch_load for this pack's model loads."""
previous = torch.load
torch.load = patched_torch_load
try:
yield
finally:
torch.load = previous


class AudioNodeBase:
Expand Down