Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions gemma_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
def ltxv_gemma_clip(encoder_path, ltxv_path, processor=None, dtype=None):
class _LTXVGemmaTextEncoderModel(LTXVGemmaTextEncoderModel):
def __init__(self, device="cpu", dtype=dtype, model_options={}):
dtype = torch.bfloat16 # TODO: make this configurable
if dtype is None:
dtype = torch.bfloat16

gemma_model = Gemma3ForConditionalGeneration.from_pretrained(
encoder_path,
Expand Down Expand Up @@ -264,6 +265,15 @@ def load_model(self, gemma_path: str, ltxv_path: str, max_length: int):
logger.warning(f"Could not load processor from {model_root}: {e}")

clip_dtype = torch.bfloat16
try:
# MPS does not fully support bfloat16, so fall back to float16 there
if comfy.model_management.get_torch_device().type == "mps":
clip_dtype = torch.float16
except Exception as e:
logger.debug(
f"Could not detect device type for dtype selection: {e}", exc_info=True
)

ltxv_full_path = folder_paths.get_full_path("checkpoints", ltxv_path)
clip_target = comfy.supported_models_base.ClipTarget(
tokenizer=tokenizer_class,
Expand Down Expand Up @@ -428,9 +438,10 @@ def _enhance(
)
model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id)

devices = [model.device] if model.device.type == "cuda" else []
with (
torch.inference_mode(),
torch.random.fork_rng(devices=[model.device]),
torch.random.fork_rng(devices=devices),
torch.autocast(device_type=model.device.type, dtype=model.dtype),
):
torch.manual_seed(seed)
Expand Down