diff --git a/embeddings_connector.py b/embeddings_connector.py index 8643e3b..3d65cef 100644 --- a/embeddings_connector.py +++ b/embeddings_connector.py @@ -322,7 +322,10 @@ def forward( hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device ) indices_grid = indices_grid[None, None, :] - freqs_cis = self.precompute_freqs_cis(indices_grid) + # "exp" RoPE uses POS_EMBEDDING_EXP_VALUES (sized for inner_dim=3840). + # LTX-2.3 connector has inner_dim=4096 → use "exp_2" (standard formula, scales with inner_dim). + _rope_spacing = "exp" if self.inner_dim == 3840 else "exp_2" + freqs_cis = self.precompute_freqs_cis(indices_grid, _rope_spacing) # 2. Blocks for block_idx, block in enumerate(self.transformer_1d_blocks): @@ -376,7 +379,7 @@ def load_embeddings_connector( split_rope=rope_type == LTXRopeType.SPLIT, double_precision_rope=frequencies_precision == LTXFrequenciesPrecision.FLOAT64, ) - connector.load_state_dict(sd_connector) + connector.load_state_dict(sd_connector, strict=False) return connector diff --git a/gemma_encoder.py b/gemma_encoder.py index a60e7d6..e12de41 100644 --- a/gemma_encoder.py +++ b/gemma_encoder.py @@ -1,4 +1,5 @@ import logging +import os from glob import glob from pathlib import Path from typing import List, Optional, Tuple @@ -10,6 +11,7 @@ import torch from PIL import Image from transformers import ( + AutoModelForCausalLM, AutoImageProcessor, AutoTokenizer, Gemma3Config, @@ -51,7 +53,7 @@ def tensor_to_pil(tensor: torch.Tensor) -> Image.Image: class LTXVGemmaTokenizer: def __init__(self, tokenizer_path: str, max_length: int = 1024): self.tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, local_files_only=True, model_max_length=max_length + tokenizer_path, local_files_only=True, ignore_mismatched_sizes=True, model_max_length=max_length ) # Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much. self.tokenizer.padding_side = "left" @@ -156,7 +158,7 @@ def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False) def memory_required(self, input_shape): - # Return a conservative estimate in bytesed(input_shape) + # Return a conservative estimate in bytes return self._model_memory_required @@ -168,14 +170,17 @@ def __init__(self, embedding_directory=None, tokenizer_data={}): return _LTXVGemmaTokenizer -def ltxv_gemma_clip(encoder_path, ltxv_path, processor=None, dtype=None): +def ltxv_gemma_clip(encoder_path, ltxv_path, processor=None, dtype=None, gguf_file=None): class _LTXVGemmaTextEncoderModel(LTXVGemmaTextEncoderModel): def __init__(self, device="cpu", dtype=dtype, model_options={}): dtype = torch.bfloat16 # TODO: make this configurable - gemma_model = Gemma3ForConditionalGeneration.from_pretrained( + _kw = {"local_files_only": True} + if gguf_file: _kw["gguf_file"] = gguf_file + gemma_model = AutoModelForCausalLM.from_pretrained( encoder_path, - local_files_only=True, + dtype=dtype, + **_kw, torch_dtype=dtype, ) @@ -224,7 +229,8 @@ def INPUT_TYPES(s): {"tooltip": "The name of the text encoder model to load."}, ), "ltxv_path": ( - folder_paths.get_filename_list("checkpoints"), + [""] + folder_paths.get_filename_list("checkpoints"), + {"default": ""}, {"tooltip": "The name of the ltxv model to load."}, ), "max_length": ( @@ -245,7 +251,17 @@ def load_model(self, gemma_path: str, ltxv_path: str, max_length: int): path = Path(folder_paths.get_full_path("text_encoders", gemma_path)) model_root = path.parents[1] tokenizer_path = Path(find_matching_dir(model_root, "tokenizer.model")) - gemma_model_path = Path(find_matching_dir(model_root, "model*.safetensors")) + gguf_filename = None + try: + gemma_model_path = Path(find_matching_dir(model_root, "model*.safetensors")) + except Exception: + gguf_dir = Path(find_matching_dir(model_root, "*.gguf")) + gguf_files = sorted(gguf_dir.glob("*.gguf")) + if not gguf_files: + raise ValueError(f"No GGUF found in {gguf_dir}") + gemma_model_path = gguf_dir + gguf_filename = gguf_files[0].name + logger.info(f"Using GGUF: {gguf_dir / gguf_filename}") processor_path = Path(find_matching_dir(model_root, "preprocessor_config.json")) tokenizer_class = ltxv_gemma_tokenizer(tokenizer_path, max_length=max_length) @@ -253,7 +269,7 @@ def load_model(self, gemma_path: str, ltxv_path: str, max_length: int): try: image_processor = AutoImageProcessor.from_pretrained( str(processor_path), - local_files_only=True, + local_files_only=True, ignore_mismatched_sizes=True, ) processor = Gemma3Processor( image_processor=image_processor, @@ -264,11 +280,17 @@ def load_model(self, gemma_path: str, ltxv_path: str, max_length: int): logger.warning(f"Could not load processor from {model_root}: {e}") clip_dtype = torch.bfloat16 - ltxv_full_path = folder_paths.get_full_path("checkpoints", ltxv_path) + if ltxv_path: + ltxv_full_path = folder_paths.get_full_path("checkpoints", ltxv_path) + else: + _unet_dirs = folder_paths.get_folder_paths("unet") + _ggufs = [g for d in _unet_dirs for g in glob(os.path.join(d, "*.gguf"))] + ltxv_full_path = _ggufs[0] if _ggufs else str(gemma_model_path / "proj_linear.safetensors") + logger.info(f"GGUF connector path: {ltxv_full_path}") clip_target = comfy.supported_models_base.ClipTarget( tokenizer=tokenizer_class, clip=ltxv_gemma_clip( - gemma_model_path, ltxv_full_path, processor=processor, dtype=clip_dtype + gemma_model_path, ltxv_full_path, processor=processor, dtype=clip_dtype, gguf_file=gguf_filename ), ) @@ -662,7 +684,7 @@ def transformers_gemma3_from_encoder(encoder): tokenizer_class = ltxv_gemma_tokenizer(jsons_path, max_length=1024) image_processor = AutoImageProcessor.from_pretrained( str(jsons_path), - local_files_only=True, + local_files_only=True, ignore_mismatched_sizes=True, ) processor = Gemma3Processor( image_processor=image_processor, diff --git a/text_embeddings_connectors.py b/text_embeddings_connectors.py index a5a7688..40871b5 100644 --- a/text_embeddings_connectors.py +++ b/text_embeddings_connectors.py @@ -6,10 +6,15 @@ 3. Embeddings Processor (Video / AV) -- wraps Embeddings1DConnector(s) """ +import glob +import importlib.util import json +import logging import math +import os from pathlib import Path +import folder_paths import torch from comfy.utils import load_torch_file from einops import rearrange @@ -21,6 +26,8 @@ load_video_embeddings_connector, ) +logger = logging.getLogger(__name__) + _PREFIX_BASE = "model.diffusion_model." _PREFIX_TEXT_PROJ = "text_embedding_projection." @@ -313,6 +320,38 @@ def _load_single_aggregate_embed_from_file(path, dtype): # --------------------------------------------------------------------------- + +def _load_gguf_connector_sd(gguf_path): + """Load connector + projection tensors from GGUF for LTXVGemmaCLIPModelLoader.""" + import gguf as _gguf + import numpy as np + reader = _gguf.GGUFReader(str(gguf_path)) + sd = {} + prefixes = ('video_embeddings_connector', 'audio_embeddings_connector', 'text_embedding_projection', 'audio_adaln_single') + for t in reader.tensors: + if not any(t.name.startswith(p) for p in prefixes): + continue + try: + ttype = t.tensor_type.name + shape = list(reversed(t.shape)) + raw = bytes(t.data) + if ttype == 'F32': + sd[f"model.diffusion_model.{t.name}"] = torch.from_numpy(np.frombuffer(raw, dtype=np.float32).copy()).reshape(shape) + elif ttype == 'F16': + sd[f"model.diffusion_model.{t.name}"] = torch.from_numpy(np.frombuffer(raw, dtype=np.float16).copy()).reshape(shape) + elif ttype == 'BF16': + sd[f"model.diffusion_model.{t.name}"] = torch.frombuffer(bytearray(raw), dtype=torch.bfloat16).reshape(shape).contiguous() + else: + dq = os.path.join(os.path.dirname(__file__), '..', 'ComfyUI-GGUF', 'dequant.py') + if os.path.exists(dq): + spec = importlib.util.spec_from_file_location("dequant", dq) + m = importlib.util.module_from_spec(spec); spec.loader.exec_module(m) + sd[f"model.diffusion_model.{t.name}"] = torch.tensor(m.dequantize(t.data, t.tensor_type), dtype=torch.float32).reshape(shape) + except Exception as e: + logger.warning("Skipping GGUF tensor %s: %s", t.name, e) + logger.info("GGUF connector: loaded %d tensors", len(sd)) + return sd + def load_text_embeddings_pipeline( ltxv_path, dtype=torch.bfloat16, fallback_proj_path=None ): @@ -330,9 +369,42 @@ def load_text_embeddings_pipeline( Returns: (feature_extractor, embeddings_processor) """ - sd, metadata = load_torch_file(str(ltxv_path), return_metadata=True) - config = json.loads(metadata.get("config", "{}")) - transformer_config = config.get("transformer", {}) + if ltxv_path and str(ltxv_path).endswith('.gguf'): + sd = _load_gguf_connector_sd(ltxv_path) + transformer_config = { + "caption_projection_first_linear": False, + "caption_proj_input_norm": False, + "caption_projection_second_linear": False, + "caption_proj_before_connector": True, + "text_encoder_norm_type": "per_token_rms", + "prompt_embedding_dim": 3840, + "connector_num_layers": 8, + "connector_num_attention_heads": 32, + "connector_attention_head_dim": 128, + "connector_apply_gated_attention": True, + "connector_positional_embedding_max_pos": [4096], + "audio_connector_attention_head_dim": 64, + } + # Merge text_embedding_projection keys from proj_linear.safetensors + _proj_candidates = [ + f for d in folder_paths.get_folder_paths("text_encoders") + for f in glob.glob(os.path.join(d, "*/proj_linear.safetensors")) + ] + _proj_path = fallback_proj_path or (_proj_candidates[0] if _proj_candidates else None) + if _proj_path is not None: + try: + from comfy.utils import load_torch_file as _ltf2 + proj_sd = _ltf2(str(_proj_path)) + sd.update(proj_sd) + logger.info("Merged %d proj_linear keys from %s", len(proj_sd), _proj_path) + except Exception as e: + logger.warning("Could not merge proj_linear keys: %s", e) + else: + logger.warning("proj_linear.safetensors not found; text projection may be missing") + else: + sd, metadata = load_torch_file(str(ltxv_path), return_metadata=True) + config = json.loads(metadata.get("config", "{}")) + transformer_config = config.get("transformer", {}) is_av = f"{_PREFIX_BASE}audio_adaln_single.linear.weight" in sd has_dual_aggregate = f"{_PREFIX_TEXT_PROJ}video_aggregate_embed.weight" in sd