diff --git a/gemma_encoder.py b/gemma_encoder.py index a60e7d6..c796a09 100644 --- a/gemma_encoder.py +++ b/gemma_encoder.py @@ -171,7 +171,8 @@ def __init__(self, embedding_directory=None, tokenizer_data={}): def ltxv_gemma_clip(encoder_path, ltxv_path, processor=None, dtype=None): class _LTXVGemmaTextEncoderModel(LTXVGemmaTextEncoderModel): def __init__(self, device="cpu", dtype=dtype, model_options={}): - dtype = torch.bfloat16 # TODO: make this configurable + if dtype is None: + dtype = torch.bfloat16 gemma_model = Gemma3ForConditionalGeneration.from_pretrained( encoder_path, @@ -264,6 +265,15 @@ def load_model(self, gemma_path: str, ltxv_path: str, max_length: int): logger.warning(f"Could not load processor from {model_root}: {e}") clip_dtype = torch.bfloat16 + try: + # MPS does not fully support bfloat16; use float16 instead + if comfy.model_management.get_torch_device().type == "mps": + clip_dtype = torch.float16 + except Exception as e: + logger.debug( + f"Could not detect device type for dtype selection: {e}", exc_info=True + ) + ltxv_full_path = folder_paths.get_full_path("checkpoints", ltxv_path) clip_target = comfy.supported_models_base.ClipTarget( tokenizer=tokenizer_class, @@ -428,9 +438,10 @@ def _enhance( ) model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id) + devices = [model.device] if model.device.type == "cuda" else [] with ( torch.inference_mode(), - torch.random.fork_rng(devices=[model.device]), + torch.random.fork_rng(devices=devices), torch.autocast(device_type=model.device.type, dtype=model.dtype), ): torch.manual_seed(seed)