From 7728d42fb77f3660d2c0f33c444a10e5296c5ae4 Mon Sep 17 00:00:00 2001 From: Zac <2215540+zboyles@users.noreply.github.com> Date: Sun, 7 Jun 2026 07:03:17 -0400 Subject: [PATCH 1/2] =?UTF-8?q?fix(gemma):=20add=20MPS=20support=20?= =?UTF-8?q?=E2=80=94=20fork=5Frng=20guard,=20dtype=20handling=20(float16?= =?UTF-8?q?=20on=20MPS)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gemma_encoder.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/gemma_encoder.py b/gemma_encoder.py index a60e7d6..4241984 100644 --- a/gemma_encoder.py +++ b/gemma_encoder.py @@ -171,7 +171,8 @@ def __init__(self, embedding_directory=None, tokenizer_data={}): def ltxv_gemma_clip(encoder_path, ltxv_path, processor=None, dtype=None): class _LTXVGemmaTextEncoderModel(LTXVGemmaTextEncoderModel): def __init__(self, device="cpu", dtype=dtype, model_options={}): - dtype = torch.bfloat16 # TODO: make this configurable + if dtype is None: + dtype = torch.bfloat16 gemma_model = Gemma3ForConditionalGeneration.from_pretrained( encoder_path, @@ -264,6 +265,13 @@ def load_model(self, gemma_path: str, ltxv_path: str, max_length: int): logger.warning(f"Could not load processor from {model_root}: {e}") clip_dtype = torch.bfloat16 + try: + # MPS does not have full support bfloat16, use float16 instead + if comfy.model_management.get_torch_device().type == "mps": + clip_dtype = torch.float16 + except Exception as e: + logger.debug(f"Could not detect device type for dtype selection: {e}") + ltxv_full_path = folder_paths.get_full_path("checkpoints", ltxv_path) clip_target = comfy.supported_models_base.ClipTarget( tokenizer=tokenizer_class, @@ -428,9 +436,10 @@ def _enhance( ) model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id) + devices = [model.device] if model.device.type == "cuda" else [] with ( torch.inference_mode(), - torch.random.fork_rng(devices=[model.device]), + torch.random.fork_rng(devices=devices), torch.autocast(device_type=model.device.type, dtype=model.dtype), ): torch.manual_seed(seed) From 90828c47a7f8b621e47b5bbb7fe890c1bb5f0b5e Mon Sep 17 00:00:00 2001 From: Zac <2215540+zboyles@users.noreply.github.com> Date: Sun, 7 Jun 2026 07:35:46 -0400 Subject: [PATCH 2/2] fix(gemma): log traceback when device detection fails (Copilot review) --- gemma_encoder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gemma_encoder.py b/gemma_encoder.py index 4241984..c796a09 100644 --- a/gemma_encoder.py +++ b/gemma_encoder.py @@ -270,7 +270,9 @@ def load_model(self, gemma_path: str, ltxv_path: str, max_length: int): if comfy.model_management.get_torch_device().type == "mps": clip_dtype = torch.float16 except Exception as e: - logger.debug(f"Could not detect device type for dtype selection: {e}") + logger.debug( + f"Could not detect device type for dtype selection: {e}", exc_info=True + ) ltxv_full_path = folder_paths.get_full_path("checkpoints", ltxv_path) clip_target = comfy.supported_models_base.ClipTarget(