diff --git a/max/examples/diffusion/simple_offline_generation.py b/max/examples/diffusion/simple_offline_generation.py index 4f983c43799..f2f513741be 100644 --- a/max/examples/diffusion/simple_offline_generation.py +++ b/max/examples/diffusion/simple_offline_generation.py @@ -83,6 +83,17 @@ "Flux2Pipeline_ModuleV3", "Flux2KleinPipeline_ModuleV3", } +QWEN_IMAGE_ARCH_NAMES = { + "QwenImagePipeline", + "QwenImageEditPipeline", + "QwenImageEditPlusPipeline", +} +QWEN_IMAGE_EDIT_ARCH_NAMES = { + "QwenImageEditPipeline", + "QwenImageEditPlusPipeline", +} +QWEN_DEFAULT_GUIDANCE_SCALE = 1.0 +QWEN_DEFAULT_TRUE_CFG_SCALE = 4.0 def parse_args(argv: list[str] | None = None) -> argparse.Namespace: @@ -158,8 +169,21 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser.add_argument( "--guidance-scale", type=float, - default=3.5, - help="Guidance scale for classifier-free guidance. Set to 1.0 to disable CFG.", + default=None, + help=( + "Guidance scale for classifier-free guidance. " + "If omitted, defaults to 1.0 for QwenImage family and 3.5 otherwise." + ), + ) + parser.add_argument( + "--true-cfg-scale", + type=float, + default=None, + help=( + "True classifier-free guidance scale. " + "If omitted, defaults to 4.0 for QwenImage family when negative prompt is provided, " + "and 1.0 otherwise." + ), ) parser.add_argument( "--seed", @@ -188,8 +212,9 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser.add_argument( "--input-image", type=str, + action="append", default=None, - help="Input image for image-to-image generation.", + help="Input image for image-to-image generation. Can be specified multiple times.", ) parser.add_argument( "--profile-timings", @@ -263,19 +288,10 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: assert args.num_inference_steps > 0, ( "num-inference-steps must be a positive integer." ) - assert args.guidance_scale > 0.0, "guidance-scale must be positive." 
- if args.residual_threshold is not None: - assert args.residual_threshold >= 0.0, ( - "residual-threshold must be non-negative." - ) - if args.taylorseer_cache_interval is not None: - assert args.taylorseer_cache_interval >= 1, ( - "taylorseer-cache-interval must be >= 1." - ) - if args.taylorseer_warmup_steps is not None: - assert args.taylorseer_warmup_steps >= 1, ( - "taylorseer-warmup-steps must be >= 1." - ) + if args.guidance_scale is not None: + assert args.guidance_scale > 0.0, "guidance-scale must be positive." + if args.true_cfg_scale is not None: + assert args.true_cfg_scale > 0.0, "true-cfg-scale must be positive." return args @@ -381,6 +397,8 @@ async def generate_image(args: argparse.Namespace) -> None: ) if arch.name in _FLUX2_ARCH_NAMES: max_length = 512 + elif arch.name in QWEN_IMAGE_ARCH_NAMES: + max_length = 512 print(f"Using max length: {max_length} for tokenizer") if ( @@ -433,27 +451,50 @@ async def generate_image(args: argparse.Namespace) -> None: print(f"Generating image for prompt: '{args.prompt}'") # Step 4: Create an OpenResponsesRequest - # Load input image if provided and convert to data URI - input_image_data_uri = load_image_as_data_uri(args.input_image) + # Load input images if provided and convert to data URIs + input_image_data_uris: list[str] = [] + if args.input_image: + for img_path in args.input_image: + uri = load_image_as_data_uri(img_path) + if uri is not None: + input_image_data_uris.append(uri) + + is_qwen_image_edit_family = arch.name in QWEN_IMAGE_EDIT_ARCH_NAMES + guidance_scale = args.guidance_scale + if guidance_scale is None: + guidance_scale = ( + QWEN_DEFAULT_GUIDANCE_SCALE if is_qwen_image_edit_family else 3.5 + ) - # Create request with structured message if image is provided - if input_image_data_uri: + true_cfg_scale = args.true_cfg_scale + if true_cfg_scale is None: + if is_qwen_image_edit_family and args.negative_prompt is not None: + true_cfg_scale = QWEN_DEFAULT_TRUE_CFG_SCALE + else: + true_cfg_scale = 
1.0 + + # Create request with structured message if images are provided + if input_image_data_uris: # Image-to-image: Use structured message with InputImageContent + InputTextContent + image_content_items: list[InputImageContent | InputTextContent] = [ + InputImageContent( + type="input_image", + image_url=uri, + ) + for uri in input_image_data_uris + ] + image_content_items.append( + InputTextContent( + type="input_text", + text=args.prompt, + ) + ) body = OpenResponsesRequestBody( model=args.model, input=[ UserMessage( role="user", - content=[ - InputImageContent( - type="input_image", - image_url=input_image_data_uri, - ), - InputTextContent( - type="input_text", - text=args.prompt, - ), - ], + content=image_content_items, ) ], seed=args.seed, @@ -463,7 +504,8 @@ async def generate_image(args: argparse.Namespace) -> None: height=args.height, width=args.width, steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, ) ), ) @@ -479,7 +521,8 @@ async def generate_image(args: argparse.Namespace) -> None: height=args.height, width=args.width, steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, ) ), ) @@ -487,7 +530,8 @@ async def generate_image(args: argparse.Namespace) -> None: request = OpenResponsesRequest(request_id=RequestID(), body=body) print( - f"Parameters: steps={args.num_inference_steps}, guidance={args.guidance_scale}" + "Parameters: " + f"steps={args.num_inference_steps}, guidance={guidance_scale}, true_cfg={true_cfg_scale}" ) # Step 5: Create a PixelContext object from the request @@ -545,16 +589,19 @@ async def generate_image(args: argparse.Namespace) -> None: height=args.height, width=args.width, steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, ) ), ) request_warmup = OpenResponsesRequest( 
request_id=RequestID(), body=body_warmup ) - input_image = Image.open(args.input_image) if args.input_image else None + warmup_image = ( + Image.open(args.input_image[0]) if args.input_image else None + ) context_warmup = await tokenizer.new_context( - request_warmup, input_image=input_image + request_warmup, input_image=warmup_image ) inputs_warmup = PixelGenerationInputs[PixelContext]( batch={context_warmup.request_id: context_warmup} diff --git a/max/python/max/pipelines/architectures/__init__.py b/max/python/max/pipelines/architectures/__init__.py index fe7e2771c56..61c0a49b417 100644 --- a/max/python/max/pipelines/architectures/__init__.py +++ b/max/python/max/pipelines/architectures/__init__.py @@ -79,6 +79,8 @@ def register_all_models() -> None: from .qwen3_embedding import qwen3_embedding_arch from .qwen3_embedding_modulev3 import qwen3_embedding_modulev3_arch from .qwen3vl_moe import qwen3vl_arch, qwen3vl_moe_arch + from .qwen_image import qwen_image_arch + from .qwen_image_edit import qwen_image_edit_arch, qwen_image_edit_plus_arch from .unified_eagle_llama3 import unified_eagle_llama3_arch from .unified_mtp_deepseekV3 import unified_mtp_deepseekV3_arch @@ -131,6 +133,9 @@ def register_all_models() -> None: qwen3_embedding_modulev3_arch, qwen3vl_arch, qwen3vl_moe_arch, + qwen_image_arch, + qwen_image_edit_arch, + qwen_image_edit_plus_arch, unified_eagle_llama3_arch, unified_mtp_deepseekV3_arch, ] diff --git a/max/python/max/pipelines/architectures/qwen2_5vl/encoder/__init__.py b/max/python/max/pipelines/architectures/qwen2_5vl/encoder/__init__.py new file mode 100644 index 00000000000..d18e229b147 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen2_5vl/encoder/__init__.py @@ -0,0 +1,17 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. 
+# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .model import Qwen25VLEncoderModel +from .multimodal_encoder import Qwen25VLMultimodalEncoderModel + +__all__ = ["Qwen25VLEncoderModel", "Qwen25VLMultimodalEncoderModel"] diff --git a/max/python/max/pipelines/architectures/qwen2_5vl/encoder/multimodal_encoder.py b/max/python/max/pipelines/architectures/qwen2_5vl/encoder/multimodal_encoder.py new file mode 100644 index 00000000000..b84e31f1211 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen2_5vl/encoder/multimodal_encoder.py @@ -0,0 +1,483 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Qwen2.5-VL multimodal text encoder helpers. + +This module owns prompt/image encoding that combines the shared module-v2 text +encoder with the Qwen2.5-VL vision encoder. Pipelines such as QwenImageEdit +should import this helper instead of defining an architecture-local copy. 
+""" + +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np +import numpy.typing as npt +from max.driver import CPU, Buffer, Device +from max.dtype import DType +from max.engine import InferenceSession, Model +from max.graph import DeviceRef, Graph, TensorType, ops +from max.graph.weights import WeightData, Weights +from max.interfaces import TokenBuffer +from max.nn.comm import Signals +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.bfloat16_utils import float32_to_bfloat16_as_uint16 +from max.pipelines.lib.config.config_enums import supported_encoding_dtype +from PIL import Image + +from ..model_config import VisionConfig +from ..nn.data_processing import ( + get_seqlens, + get_window_index, + mrope_pos_ids_3d, +) +from ..nn.qwen_vl_utils import fetch_image +from ..nn.visual_transformer import VisionTransformer +from ..tokenizer import Qwen2_5VLImageProcessor +from ..weight_adapters import QWEN2_5_VL_MODEL_MAPPING +from .model import Qwen25VLEncoderModel + +logger = logging.getLogger(__name__) + +PROMPT_TEMPLATE_DROP_IDX = 64 + + +class Qwen25VLMultimodalEncoderModel: + """Multimodal prompt encoder built on the shared Qwen2.5-VL components.""" + + def __init__( + self, + text_encoder: Qwen25VLEncoderModel, + config: dict[str, Any], + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + session: InferenceSession, + tokenizer: Any, + ) -> None: + self.text_encoder = text_encoder + self.devices = devices + self.session = session + self.tokenizer = tokenizer + self.lang_config = text_encoder.config + self._cached_vision_inputs: dict[ + tuple[int, ...], + tuple[Buffer, Buffer, Buffer, Buffer, Buffer, Buffer, Buffer], + ] = {} + self._cached_scatter_indices: dict[tuple[int, ...], Buffer] = {} + self._cached_token_buffers: dict[tuple[int, ...], Buffer] = {} + + self._image_token_id = self.tokenizer.convert_tokens_to_ids( + "<|image_pad|>" + ) + vision_cfg = 
config.get("vision_config", {}) + enc_dtype = supported_encoding_dtype(encoding) + device_ref = DeviceRef.from_device(devices[0]) + self.vision_config = VisionConfig( + dtype=enc_dtype, + llm_dtype=enc_dtype, + devices=[device_ref], + patch_size=vision_cfg.get("patch_size", 14), + temporal_patch_size=vision_cfg.get("temporal_patch_size", 2), + in_channels=vision_cfg.get("in_channels", 3), + hidden_size=vision_cfg.get("hidden_size", 1280), + num_attention_heads=vision_cfg.get("num_heads", 16), + depth=vision_cfg.get("depth", 32), + intermediate_size=vision_cfg.get("intermediate_size", 5120), + out_hidden_size=vision_cfg.get( + "out_hidden_size", self.lang_config.hidden_size + ), + fullatt_block_indexes=vision_cfg.get( + "fullatt_block_indexes", + [7, 15, 23, 31], + ), + rms_norm_eps=vision_cfg.get("rms_norm_eps", 1e-6), + window_size=vision_cfg.get("window_size", 112), + spatial_merge_size=vision_cfg.get("spatial_merge_size", 2), + ) + self.image_processor = Qwen2_5VLImageProcessor( + patch_size=self.vision_config.patch_size, + temporal_patch_size=self.vision_config.temporal_patch_size, + merge_size=self.vision_config.spatial_merge_size, + ) + + self._compile_vision_encoder(weights) + self._compile_hidden_state_trimmer() + self._compile_vision_merger() + self._compile_hidden_state_tiler() + + def _compile_vision_encoder(self, weights: Weights) -> None: + device_ref = DeviceRef.from_device(self.devices[0]) + vc = self.vision_config + patch_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size**2 + + vision_state: dict[str, Any] = {} + for key, value in weights.items(): + wd = value.data() + if wd.dtype.is_float() and not wd.dtype.is_float8(): + is_scale = key.endswith(".weight_scale") or key.endswith( + ".input_scale" + ) + if not is_scale: + wd = wd.astype(DType.bfloat16) + + if "patch_embed.proj." 
in key: + buf = Buffer.from_dlpack(wd.data) + oc, ic, kh, kw, kd = buf.shape + buf = buf.view(dtype=buf.dtype, shape=(oc, ic * kh * kw * kd)) + wd = WeightData( + data=buf, + name=wd.name, + dtype=wd.dtype, + shape=wd.shape.__class__(buf.shape), + quantization_encoding=wd.quantization_encoding, + ) + + mapped = key + for before, after in QWEN2_5_VL_MODEL_MAPPING.items(): + mapped = mapped.replace(before, after) + + if mapped.startswith("vision_encoder."): + vision_state[mapped[len("vision_encoder.") :]] = wd + elif mapped.startswith("merger."): + vision_state[mapped] = wd + + vision_transformer = VisionTransformer(vc) + vision_transformer.load_state_dict( + vision_state, weight_alignment=1, strict=True + ) + + signals = Signals(devices=[device_ref]) + input_types = [ + TensorType( + vc.dtype, + shape=["vision_seq_len", patch_dim], + device=device_ref, + ), + TensorType( + DType.int64, shape=["vision_seq_len", 2], device=device_ref + ), + TensorType( + DType.int64, shape=["window_seq_len"], device=device_ref + ), + TensorType( + DType.uint32, shape=["n_full_seqlens"], device=device_ref + ), + TensorType( + DType.uint32, shape=["n_win_seqlens"], device=device_ref + ), + TensorType(DType.uint32, shape=[1], device=DeviceRef.CPU()), + TensorType(DType.uint32, shape=[1], device=DeviceRef.CPU()), + TensorType(DType.int32, shape=[], device=DeviceRef.CPU()), + *signals.input_types(), + ] + + with Graph("qwen_edit_vision", input_types=input_types) as vision_graph: + ins = vision_graph.inputs + signal_buffers = [inp.buffer for inp in ins[8:]] + outputs = vision_transformer( + pixel_values=[ins[0].tensor], + rot_pos_ids=[ins[1].tensor], + window_index=[ins[2].tensor], + cu_seqlens=[ins[3].tensor], + cu_window_seqlens=[ins[4].tensor], + max_seqlen=[ins[5].tensor], + max_window_seqlen=[ins[6].tensor], + max_grid_size=[ins[7].tensor], + signal_buffers=signal_buffers, + ) + vision_graph.output(outputs[0]) + + self._vision_model: Model = self.session.load( + vision_graph, 
weights_registry=vision_transformer.state_dict() + ) + self._vision_signals = signals + + def _compile_hidden_state_trimmer(self) -> None: + device_ref = DeviceRef.from_device(self.devices[0]) + hidden_size = self.lang_config.hidden_size + + with Graph( + "qwen_edit_trim_hidden_states", + input_types=[ + TensorType( + self.lang_config.dtype, + shape=["total_seq_len", hidden_size], + device=device_ref, + ) + ], + ) as graph: + hidden_states = graph.inputs[0].tensor + trimmed = ops.slice_tensor( + hidden_states, + [slice(PROMPT_TEMPLATE_DROP_IDX, None), slice(None)], + ) + graph.output(ops.unsqueeze(trimmed, 0)) + + self._hidden_state_trimmer: Model = self.session.load(graph) + + def _compile_vision_merger(self) -> None: + device_ref = DeviceRef.from_device(self.devices[0]) + hidden_size = self.lang_config.hidden_size + + with Graph( + "qwen_edit_merge_vision_embeddings", + input_types=[ + TensorType( + self.lang_config.dtype, + shape=["total_seq_len", hidden_size], + device=device_ref, + ), + TensorType( + self.lang_config.dtype, + shape=["num_image_tokens", hidden_size], + device=device_ref, + ), + TensorType( + DType.int64, + shape=["num_image_tokens", hidden_size], + device=device_ref, + ), + ], + ) as graph: + hidden_states = graph.inputs[0].tensor + vision_embeds = graph.inputs[1].tensor + image_token_indices = graph.inputs[2].tensor + graph.output( + ops.scatter( + input=hidden_states, + updates=vision_embeds, + indices=image_token_indices, + axis=0, + ) + ) + + self._vision_merger: Model = self.session.load(graph) + + def _compile_hidden_state_tiler(self) -> None: + device_ref = DeviceRef.from_device(self.devices[0]) + hidden_size = self.lang_config.hidden_size + + with Graph( + "qwen_edit_tile_hidden_states", + input_types=[ + TensorType( + self.lang_config.dtype, + shape=[1, "trimmed_seq_len", hidden_size], + device=device_ref, + ) + ], + ) as graph: + hidden_states = graph.inputs[0].tensor + graph.output(ops.tile(hidden_states, (2, 1, 1))) + + 
self._repeat_two_hidden_states: Model = self.session.load(graph) + + def _prepare_images( + self, images: list[npt.NDArray[np.uint8]] + ) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.uint16]]: + processed_images = [ + fetch_image({"image": Image.fromarray(image).convert("RGB")}) + for image in images + ] + processed = self.image_processor( + images=processed_images, return_tensors="np" + ) + processed_dict = ( + processed[0] if isinstance(processed, tuple) else processed + ) + image_grid_thw = np.asarray( + processed_dict["image_grid_thw"], dtype=np.int64 + ) + pixel_values = np.asarray( + processed_dict.get( + "pixel_values", + processed_dict.get("concatenated_pixel_values"), + ) + ) + if pixel_values.dtype == np.uint16: + pixel_values_u16 = pixel_values + else: + pixel_values_u16 = float32_to_bfloat16_as_uint16( + np.ascontiguousarray(pixel_values.astype(np.float32)) + ) + return image_grid_thw, pixel_values_u16 + + def _run_vision_encoder( + self, + image_grid_thw: npt.NDArray[np.int64], + pixel_values_u16: npt.NDArray[np.uint16], + ) -> Buffer: + vc = self.vision_config + device = self.devices[0] + + rot_pos_ids = mrope_pos_ids_3d(image_grid_thw, vc.spatial_merge_size) + window_idx, cu_win_seqlens = get_window_index( + image_grid_thw, + window_size=vc.window_size, + spatial_merge_size=vc.spatial_merge_size, + patch_size=vc.patch_size, + spatial_merge_unit=vc.spatial_merge_size**2, + ) + cu_seqlens, cu_window_seqlens, max_seqlen, max_window_seqlen = ( + get_seqlens(image_grid_thw, cu_win_seqlens) + ) + max_grid_size = int(image_grid_thw[:, 1:].max()) + grid_key = tuple(int(x) for x in image_grid_thw.reshape(-1)) + + if grid_key not in self._cached_vision_inputs: + self._cached_vision_inputs[grid_key] = ( + Buffer.from_numpy( + np.ascontiguousarray(rot_pos_ids.astype(np.int64)) + ).to(device), + Buffer.from_numpy( + np.ascontiguousarray(window_idx.astype(np.int64)) + ).to(device), + Buffer.from_numpy( + np.ascontiguousarray(cu_seqlens.astype(np.uint32)) + 
).to(device), + Buffer.from_numpy( + np.ascontiguousarray(cu_window_seqlens.astype(np.uint32)) + ).to(device), + Buffer.from_numpy(np.array([max_seqlen], dtype=np.uint32)), + Buffer.from_numpy( + np.array([max_window_seqlen], dtype=np.uint32) + ), + Buffer.from_numpy(np.array(max_grid_size, dtype=np.int32)), + ) + ( + rot_pos_ids_buf, + window_idx_buf, + cu_seqlens_buf, + cu_window_seqlens_buf, + max_seqlen_buf, + max_window_seqlen_buf, + max_grid_size_buf, + ) = self._cached_vision_inputs[grid_key] + + if vc.dtype == DType.bfloat16: + pv_buf = Buffer.from_numpy( + np.ascontiguousarray(pixel_values_u16) + ).to(device) + pv_buf = pv_buf.view( + dtype=DType.bfloat16, shape=pixel_values_u16.shape + ) + else: + pixel_values = (pixel_values_u16.astype(np.uint32) << 16).view( + np.float32 + ) + if vc.dtype == DType.float16: + pixel_values = pixel_values.astype(np.float16) + pv_buf = Buffer.from_numpy(np.ascontiguousarray(pixel_values)).to( + device + ) + + result = self._vision_model.execute( + pv_buf, + rot_pos_ids_buf, + window_idx_buf, + cu_seqlens_buf, + cu_window_seqlens_buf, + max_seqlen_buf, + max_window_seqlen_buf, + max_grid_size_buf, + *self._vision_signals.buffers(), + ) + return result[0] + + def encode( + self, + tokens: TokenBuffer, + images: list[npt.NDArray[np.uint8]] | None = None, + num_images_per_prompt: int = 1, + ) -> Buffer: + device = self.devices[0] + + pixel_values_u16: npt.NDArray[np.uint16] | None = None + image_grid_thw: npt.NDArray[np.int64] | None = None + if images: + image_grid_thw, pixel_values_u16 = self._prepare_images(images) + + input_ids = ( + np.asarray(tokens.array).flatten().astype(np.int64, copy=False) + ) + token_key = tuple(int(token) for token in input_ids.tolist()) + if token_key not in self._cached_token_buffers: + self._cached_token_buffers[token_key] = Buffer.from_numpy( + np.ascontiguousarray(input_ids) + ).to(device) + token_buf = self._cached_token_buffers[token_key] + embed_result = 
self.text_encoder._embed_model.execute(token_buf) + lc = self.lang_config + merged_buf = embed_result[0] + + if images: + if image_grid_thw is None or pixel_values_u16 is None: + raise ValueError("vision inputs are required when images exist") + vision_emb = self._run_vision_encoder( + image_grid_thw, pixel_values_u16 + ) + pad_positions = np.where(input_ids == self._image_token_id)[0] + if len(pad_positions) == vision_emb.shape[0]: + scatter_key = tuple(int(x) for x in pad_positions.tolist()) + if scatter_key not in self._cached_scatter_indices: + scatter_indices = np.tile( + pad_positions[:, np.newaxis], + (1, vision_emb.shape[1]), + ).astype(np.int64, copy=False) + self._cached_scatter_indices[scatter_key] = ( + Buffer.from_numpy( + np.ascontiguousarray(scatter_indices) + ).to(device) + ) + merged_buf = self._vision_merger.execute( + merged_buf, + vision_emb, + self._cached_scatter_indices[scatter_key], + )[0] + else: + logger.warning( + "Vision token mismatch: %d pads vs %d embeddings. 
Skipping merge.", + len(pad_positions), + vision_emb.shape[0], + ) + + hs_buf = self.text_encoder._transform_model.execute(merged_buf)[0] + trimmed_buf = self._hidden_state_trimmer.execute(hs_buf)[0] + if num_images_per_prompt == 1: + return trimmed_buf + if num_images_per_prompt == 2: + return self._repeat_two_hidden_states.execute(trimmed_buf)[0] + + hs_cpu = hs_buf.to(CPU()) + if lc.dtype == DType.bfloat16: + hs_u16 = np.from_dlpack( + hs_cpu.view(dtype=DType.uint16, shape=hs_cpu.shape) + ) + hs_np = (hs_u16.astype(np.uint32) << 16).view(np.float32) + else: + hs_np = np.from_dlpack(hs_cpu).astype(np.float32) + + hs_np = hs_np[PROMPT_TEMPLATE_DROP_IDX:] + hs_np = hs_np[np.newaxis, :, :] + hs_np = np.repeat(hs_np, num_images_per_prompt, axis=0) + + if lc.dtype == DType.bfloat16: + result_u16 = float32_to_bfloat16_as_uint16( + np.ascontiguousarray(hs_np) + ) + buf = Buffer.from_numpy(result_u16).to(device) + return buf.view(dtype=DType.bfloat16, shape=hs_np.shape) + return Buffer.from_numpy(np.ascontiguousarray(hs_np)).to(device) diff --git a/max/python/max/pipelines/architectures/qwen_image_edit/__init__.py b/max/python/max/pipelines/architectures/qwen_image_edit/__init__.py new file mode 100644 index 00000000000..e63fad53391 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image_edit/__init__.py @@ -0,0 +1,21 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ===----------------------------------------------------------------------=== # + +from .arch import qwen_image_edit_arch, qwen_image_edit_plus_arch +from .pipeline_qwen_image_edit import QwenImageEditPipeline + +__all__ = [ + "QwenImageEditPipeline", + "qwen_image_edit_arch", + "qwen_image_edit_plus_arch", +] diff --git a/max/python/max/pipelines/architectures/qwen_image_edit/arch.py b/max/python/max/pipelines/architectures/qwen_image_edit/arch.py new file mode 100644 index 00000000000..1e18a54ee09 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image_edit/arch.py @@ -0,0 +1,50 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ===----------------------------------------------------------------------=== # + +from max.graph.weights import WeightsFormat +from max.interfaces import PipelineTask +from max.pipelines.core import PixelContext +from max.pipelines.lib import PixelGenerationTokenizer, SupportedArchitecture + +from ..qwen_image.arch import QwenImageArchConfig +from .pipeline_qwen_image_edit import QwenImageEditPipeline + +qwen_image_edit_arch = SupportedArchitecture( + name="QwenImageEditPipeline", + task=PipelineTask.PIXEL_GENERATION, + default_encoding="bfloat16", + supported_encodings={"bfloat16": []}, + example_repo_ids=[ + "Qwen/Qwen-Image-Edit-2511", + ], + pipeline_model=QwenImageEditPipeline, # type: ignore[arg-type] + context_type=PixelContext, + default_weights_format=WeightsFormat.safetensors, + tokenizer=PixelGenerationTokenizer, + config=QwenImageArchConfig, +) + +qwen_image_edit_plus_arch = SupportedArchitecture( + name="QwenImageEditPlusPipeline", + task=PipelineTask.PIXEL_GENERATION, + default_encoding="bfloat16", + supported_encodings={"bfloat16": []}, + example_repo_ids=[ + "Qwen/Qwen-Image-Edit-2511", + ], + pipeline_model=QwenImageEditPipeline, # type: ignore[arg-type] + context_type=PixelContext, + default_weights_format=WeightsFormat.safetensors, + tokenizer=PixelGenerationTokenizer, + config=QwenImageArchConfig, +) diff --git a/max/python/max/pipelines/architectures/qwen_image_edit/model.py b/max/python/max/pipelines/architectures/qwen_image_edit/model.py new file mode 100644 index 00000000000..318bac65367 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image_edit/model.py @@ -0,0 +1,89 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. 
+# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""QwenImage Edit transformer model. + +The edit path uses the same MAX-native module_v2 transformer graph as +text-to-image, but derives condition-token masking dynamically from the image +token IDs. That keeps the graph shape-dynamic and avoids recompiles when edit +requests change image resolution or denoising step count across runs. +""" + +from collections.abc import Callable +from typing import Any + +from max.driver import Buffer, Device +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.component_model import ComponentModel + +from ..qwen_image.model_config import QwenImageConfig +from ..qwen_image.qwen_image import QwenImageTransformer2DModel + + +class QwenImageEditTransformerModel(ComponentModel): + """Edit-specific transformer compiled once with dynamic token masking.""" + + def __init__( + self, + config: dict[str, Any], + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + session: InferenceSession, + ) -> None: + super().__init__(config, encoding, devices, weights) + self.session = session + self.config = QwenImageConfig.generate(config, encoding, devices) + self.load_model() + + def load_model(self) -> Callable[..., Any]: + state_dict = {key: value.data() for key, value in self.weights.items()} + + nn_model = QwenImageTransformer2DModel(self.config) + 
nn_model.load_state_dict(state_dict, weight_alignment=1, strict=True) + self.state_dict = nn_model.state_dict() + + with Graph( + "qwen_image_edit_transformer", + input_types=nn_model.input_types(), + ) as graph: + outputs = nn_model(*(value.tensor for value in graph.inputs)) + if isinstance(outputs, tuple): + graph.output(*outputs) + else: + graph.output(outputs) + + self.model: Model = self.session.load( + graph, + weights_registry=self.state_dict, + ) + return self.model.execute + + def __call__( + self, + hidden_states: Buffer, + encoder_hidden_states: Buffer, + timestep: Buffer, + img_ids: Buffer, + txt_ids: Buffer, + ) -> Any: + return self.model.execute( + hidden_states, + encoder_hidden_states, + timestep, + img_ids, + txt_ids, + ) diff --git a/max/python/max/pipelines/architectures/qwen_image_edit/pipeline_qwen_image_edit.py b/max/python/max/pipelines/architectures/qwen_image_edit/pipeline_qwen_image_edit.py new file mode 100644 index 00000000000..70568ea7a2d --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image_edit/pipeline_qwen_image_edit.py @@ -0,0 +1,1166 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""QwenImage edit diffusion pipeline. 
+ +Key differences from QwenImagePipeline: +- Multimodal prompt encoding when edit images are present +- VAE image-conditioning path that concatenates condition latents to noise +- True CFG with two forward passes (positive + negative prompts) +""" + +from dataclasses import dataclass, field +from typing import Any, Literal + +import numpy as np +import numpy.typing as npt +from max.driver import CPU, Buffer, Device +from max.dtype import DType +from max.graph import TensorType, TensorValue, ops +from max.interfaces import TokenBuffer +from max.pipelines.core import PixelContext +from max.pipelines.lib.bfloat16_utils import float32_to_bfloat16_as_uint16 +from max.pipelines.lib.interfaces import DiffusionPipeline, PixelModelInputs +from max.pipelines.lib.interfaces.diffusion_pipeline import ( + max_compile, +) +from max.profiler import Tracer, traced + +from ..autoencoders.autoencoder_kl_qwen_image import AutoencoderKLQwenImageModel +from ..qwen2_5vl.encoder import ( + Qwen25VLEncoderModel, + Qwen25VLMultimodalEncoderModel, +) +from .model import QwenImageEditTransformerModel + + +@dataclass(kw_only=True) +class QwenImageEditModelInputs(PixelModelInputs): + """QwenImage-edit-specific PixelModelInputs. + + For image editing the recommended usage is + ``--guidance-scale 1.0 --true-cfg-scale 4.0``. + ``guidance_scale`` is unused (model is not guidance-distilled); + ``true_cfg_scale`` drives the two-pass CFG behavior. 
+ """ + + width: int = 1024 + height: int = 1024 + guidance_scale: float = 1.0 + true_cfg_scale: float = 4.0 + num_inference_steps: int = 50 + num_images_per_prompt: int = 1 + prompt_images: list[npt.NDArray[np.uint8]] | None = None + vae_condition_images: list[npt.NDArray[np.uint8]] | None = None + + +@dataclass +class QwenImageEditPipelineOutput: + """Container for QwenImage edit pipeline results.""" + + images: np.ndarray | list + + +@dataclass +class QwenImageEditCache: + """Runtime cache for reusable edit-path buffers.""" + + sigmas: dict[str, Buffer] = field(default_factory=dict) + text_ids: dict[str, Buffer] = field(default_factory=dict) + shape_carriers: dict[int, Buffer] = field(default_factory=dict) + cfg_scales: dict[float, Buffer] = field(default_factory=dict) + noise_token_counts: dict[int, Buffer] = field(default_factory=dict) + condition_image_ids: dict[tuple[int, int, int], Buffer] = field( + default_factory=dict + ) + latent_image_ids: dict[tuple[int, int, int], Buffer] = field( + default_factory=dict + ) + prompt_tokens: dict[tuple[int, ...], Buffer] = field(default_factory=dict) + + +class QwenImageEditPipeline(DiffusionPipeline): + """Diffusion pipeline for QwenImage image editing. + + Wires together: + - Qwen2.5-VL prompt encoder + - QwenImage edit transformer denoiser + - QwenImage 3D VAE (with latents_mean/std normalization) + - Image-conditioning path (VAE encode -> normalize -> patchify -> concat) + """ + + vae: AutoencoderKLQwenImageModel + text_encoder: Qwen25VLEncoderModel + transformer: QwenImageEditTransformerModel + + components = { + "vae": AutoencoderKLQwenImageModel, + "text_encoder": Qwen25VLEncoderModel, + "transformer": QwenImageEditTransformerModel, + } + + # NOTE: + # `prompt_encoder` is intentionally not part of `components`. + # + # QwenImageEdit needs a multimodal prompt path that reuses the already-loaded + # `text_encoder` and layers a vision encoder + prompt/image merge logic on top. 
+ # That makes it closer to an edit-specific helper than an independent pipeline + # submodel. Keeping it out of `components` avoids adding special loading rules + # to the shared DiffusionPipeline base just for this dependency shape. + prompt_encoder: Qwen25VLMultimodalEncoderModel | None = None + _prompt_encoder_config: dict[str, Any] | None = None + _prompt_encoder_weight_paths: list[str] | None = None + + def __init__(self, *args: Any, **kwargs: Any) -> None: + if args and len(args) >= 4: + self._weight_paths = args[3] + else: + self._weight_paths = kwargs.get("weight_paths", []) + super().__init__(*args, **kwargs) + + def init_remaining_components(self) -> None: + """Initialize derived attributes that depend on loaded components.""" + self.vae_scale_factor = 8 + + self._compile_runtime_helpers() + self.cache: QwenImageEditCache = QwenImageEditCache() + + diffusers_config = self.pipeline_config.model.diffusers_config + components_config = diffusers_config.get("components", {}) + self._prompt_encoder_config = components_config.get( + "text_encoder", {} + ).get("config_dict", {}) + + relative_paths = self._resolve_relative_component_paths() + text_encoder_rel_paths = relative_paths.get("text_encoder", []) + self._prompt_encoder_weight_paths = self._resolve_absolute_paths( + self._weight_paths, text_encoder_rel_paths + ) + + def _compile_runtime_helpers(self) -> None: + """Compile the helper graphs used by the edit pipeline runtime.""" + + def duplicate_batch(value: TensorValue) -> TensorValue: + return ops.concat([value, value], axis=0) + + def concat_sequence_pair( + left: TensorValue, + right: TensorValue, + ) -> TensorValue: + return ops.concat([left, right], axis=1) + + device = self.transformer.devices[0] + self.cached_patchify_and_pack = max_compile( + self._patchify_and_pack, + input_types=[ + TensorType( + DType.float32, + shape=["batch", "channels", "height", 2, "width", 2], + device=device, + ) + ], + ) + + self.cached_prepare_scheduler = max_compile( + 
self.prepare_scheduler, + input_types=[ + TensorType( + DType.float32, + shape=["num_sigmas"], + device=device, + ) + ], + ) + + dtype = self.transformer.config.dtype + self.cached_scheduler_step = max_compile( + self.scheduler_step, + input_types=[ + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType( + dtype, + shape=["batch", "pred_seq", "channels"], + device=device, + ), + TensorType(DType.float32, shape=[1], device=device), + TensorType( + DType.int64, + shape=["batch", "seq", 3], + device=device, + ), + ], + ) + + packed_channels = self.transformer.config.in_channels + self.cached_postprocess_latents = max_compile( + self._postprocess_latents, + input_types=[ + TensorType( + dtype, + shape=["batch", "height", "width", packed_channels], + device=device, + ), + TensorType(dtype, shape=[16], device=device), + TensorType(dtype, shape=[16], device=device), + ], + ) + + self.cached_cfg_blend = max_compile( + self._cfg_blend, + input_types=[ + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType(DType.float32, shape=[1], device=device), + ], + ) + + self.cached_reshape_latents = max_compile( + self._reshape_latents, + input_types=[ + TensorType( + dtype, + shape=["batch", "seq", packed_channels], + device=device, + ), + TensorType(DType.float32, shape=["packed_h"], device=CPU()), + TensorType(DType.float32, shape=["packed_w"], device=CPU()), + ], + ) + + text_dtype = self.text_encoder.config.dtype + text_device = self.text_encoder.devices[0] + hidden_size = self.text_encoder.config.hidden_size + self.cached_trim_prompt_embeddings = max_compile( + self._trim_prompt_embeddings, + input_types=[ + TensorType( + text_dtype, + shape=["seq", hidden_size], + device=text_device, + ) + ], + ) + + self.cached_duplicate_prompt_embeddings = max_compile( + duplicate_batch, + input_types=[ + TensorType( + text_dtype, + shape=[1, 
"trimmed_seq_len", hidden_size], + device=text_device, + ) + ], + ) + + vae_dtype = self.vae.config.dtype + vae_device = self.vae.devices[0] + self.cached_reshape_vae_latents = max_compile( + self._reshape_vae_latents, + input_types=[ + TensorType( + vae_dtype, + shape=["batch", "channels", "height", "width"], + device=vae_device, + ) + ], + ) + + z_dim = self.vae.config.z_dim + self.cached_normalize_and_pack_image_latent = max_compile( + self._normalize_and_pack_image_latent, + input_types=[ + TensorType( + vae_dtype, + shape=["batch", z_dim, "height", 2, "width", 2], + device=vae_device, + ), + TensorType(vae_dtype, shape=[z_dim], device=vae_device), + TensorType(vae_dtype, shape=[z_dim], device=vae_device), + ], + ) + + self.cached_concat_image_latents = max_compile( + self.concat_image_latents, + input_types=[ + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType( + dtype, shape=["batch", "img_seq", "channels"], device=device + ), + TensorType( + DType.int64, shape=["batch", "seq", 3], device=device + ), + TensorType( + DType.int64, shape=["batch", "img_seq", 3], device=device + ), + ], + ) + + self.cached_concat_image_sequences = max_compile( + concat_sequence_pair, + input_types=[ + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType( + dtype, shape=["batch", "img_seq", "channels"], device=device + ), + ], + ) + + self.cached_concat_image_ids = max_compile( + concat_sequence_pair, + input_types=[ + TensorType( + DType.int64, shape=["batch", "seq", 3], device=device + ), + TensorType( + DType.int64, shape=["batch", "img_seq", 3], device=device + ), + ], + ) + + self.cached_duplicate_condition_latents = max_compile( + duplicate_batch, + input_types=[ + TensorType( + dtype, + shape=[1, "img_seq", packed_channels], + device=device, + ) + ], + ) + + self.cached_duplicate_condition_ids = max_compile( + duplicate_batch, + input_types=[ + TensorType(DType.int64, shape=[1, "img_seq", 3], 
device=device) + ], + ) + + self.cached_extract_noise_latents = max_compile( + self._extract_noise_latents, + input_types=[ + TensorType( + dtype, + shape=["batch", "seq", packed_channels], + device=device, + ), + TensorType(DType.int64, shape=[1], device=CPU()), + ], + ) + + def _init_prompt_encoder(self) -> None: + if self.prompt_encoder is not None: + return + + # NOTE: + # This is a local assembly step, not a normal ComponentModel load. + # + # The edit prompt encoder depends on the already-instantiated + # `self.text_encoder`, reuses the text-encoder weight set, and adds the + # Qwen2.5-VL vision path needed for multimodal prompt encoding. If we + # tried to model it as a regular pipeline component, the shared loader + # would need special-case dependency wiring for "component B depends on + # loaded component A", which is more confusing than keeping the assembly + # here in the edit pipeline. + from max.graph.weights import load_weights + + if self._prompt_encoder_config is None: + raise ValueError("prompt encoder config is not initialized") + if self._prompt_encoder_weight_paths is None: + raise ValueError("prompt encoder weight paths are not initialized") + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained( + self.pipeline_config.model.model_path, + subfolder="tokenizer", + ) + + self.prompt_encoder = Qwen25VLMultimodalEncoderModel( + text_encoder=self.text_encoder, + config=self._prompt_encoder_config, + encoding=self.pipeline_config.model.quantization_encoding, + devices=self.devices, + weights=load_weights(self._prompt_encoder_weight_paths), + session=self.session, + tokenizer=tokenizer, + ) + + def _get_prompt_encoder(self) -> Qwen25VLMultimodalEncoderModel: + # NOTE: + # We only need the multimodal prompt path when edit images are present. + # Text-only prompt encoding stays on `self.text_encoder`, so we avoid + # paying the extra vision-side setup cost unless the request actually + # uses image conditioning. 
+ if self.prompt_encoder is None: + self._init_prompt_encoder() + if self.prompt_encoder is None: + raise ValueError("failed to initialize prompt_encoder") + return self.prompt_encoder + + def _encode_prompt( + self, + *, + tokens: TokenBuffer, + prompt_images: list[npt.NDArray[np.uint8]], + num_images_per_prompt: int, + prompt_encoder: Qwen25VLMultimodalEncoderModel | None, + ) -> Buffer: + if prompt_images: + assert prompt_encoder is not None + return prompt_encoder.encode( + tokens=tokens, + images=prompt_images, + num_images_per_prompt=num_images_per_prompt, + ) + + return self.prepare_prompt_embeddings( + tokens=tokens, + num_images_per_prompt=num_images_per_prompt, + ) + + @staticmethod + def _resolve_condition_images( + model_inputs: QwenImageEditModelInputs, + ) -> tuple[list[npt.NDArray[np.uint8]], list[npt.NDArray[np.uint8]]]: + prompt_images = ( + model_inputs.prompt_images or model_inputs.input_images or [] + ) + vae_condition_images = ( + model_inputs.vae_condition_images or model_inputs.input_images or [] + ) + return prompt_images, vae_condition_images + + def _prepare_negative_prompt_embeddings( + self, + *, + model_inputs: QwenImageEditModelInputs, + prompt_images: list[npt.NDArray[np.uint8]], + prompt_encoder: Qwen25VLMultimodalEncoderModel | None, + ) -> Buffer | None: + if ( + model_inputs.true_cfg_scale <= 1.0 + or model_inputs.negative_tokens is None + ): + return None + + return self._encode_prompt( + tokens=model_inputs.negative_tokens, + prompt_images=prompt_images, + num_images_per_prompt=model_inputs.num_images_per_prompt, + prompt_encoder=prompt_encoder, + ) + + def _prepare_condition_latents( + self, + *, + vae_condition_images: list[npt.NDArray[np.uint8]], + batch_size: int, + device: Device, + ) -> tuple[Buffer | None, Buffer | None]: + if not vae_condition_images: + return None, None + + image_bufs = [ + self._numpy_image_to_buffer(image) for image in vae_condition_images + ] + return self.prepare_image_latents( + images=image_bufs, 
+ batch_size=batch_size, + device=device, + ) + + def _prepare_text_ids_for_embeddings( + self, + *, + embeddings: Buffer, + batch_size: int, + device: Device, + max_vid_index: int, + ) -> Buffer: + seq_len = embeddings.shape[1] + cache_key = f"{batch_size}_{seq_len}_{max_vid_index}" + if cache_key not in self.cache.text_ids: + self.cache.text_ids[cache_key] = self._prepare_text_ids( + batch_size, seq_len, device, max_vid_index + ) + return self.cache.text_ids[cache_key] + + def prepare_inputs(self, context: PixelContext) -> QwenImageEditModelInputs: # type: ignore[override] + """Convert a PixelContext into QwenImageEditModelInputs.""" + return QwenImageEditModelInputs.from_context(context) + + def _patchify_and_pack(self, latents: TensorValue) -> TensorValue: + """(B,C,H//2,2,W//2,2) → (B, H//2*W//2, C*4)""" + latents = ops.cast(latents, self.transformer.config.dtype) + batch = latents.shape[0] + c = latents.shape[1] + h2 = latents.shape[2] + w2 = latents.shape[4] + latents = ops.permute(latents, (0, 1, 3, 5, 2, 4)) + latents = ops.reshape(latents, (batch, c * 4, h2, w2)) + latents = ops.reshape(latents, (batch, c * 4, h2 * w2)) + return ops.permute(latents, (0, 2, 1)) + + def _postprocess_latents( + self, + latents_bhwc: TensorValue, + latents_mean: TensorValue, + latents_std: TensorValue, + ) -> TensorValue: + """Unpatchify (B,H,W,C*4) → (B,z_dim,H*2,W*2) and denormalize.""" + batch = latents_bhwc.shape[0] + h = latents_bhwc.shape[1] + w = latents_bhwc.shape[2] + c = latents_bhwc.shape[3] + z_dim = c // 4 + latents = ops.permute(latents_bhwc, (0, 3, 1, 2)) + latents = ops.reshape(latents, (batch, z_dim, 2, 2, h, w)) + latents = ops.permute(latents, (0, 1, 4, 2, 5, 3)) + latents = ops.reshape(latents, (batch, z_dim, h * 2, w * 2)) + mean_r = ops.reshape(latents_mean, (1, z_dim, 1, 1)) + std_r = ops.reshape(latents_std, (1, z_dim, 1, 1)) + return latents * std_r + mean_r + + def _normalize_and_pack_image_latent( + self, + image_latents: TensorValue, + 
def _cfg_blend(
    self,
    cond_pred: TensorValue,
    uncond_pred: TensorValue,
    cfg_scale: TensorValue,
) -> TensorValue:
    """Blend conditional and unconditional predictions for true CFG.

    Computes ``uncond + scale * (cond - uncond)``. ``cfg_scale`` is cast
    to the dtype of ``cond_pred`` before it is applied.
    """
    guidance = ops.cast(cfg_scale, cond_pred.dtype)
    # Scaled step from the unconditional towards the conditional output.
    guided_delta = guidance * (cond_pred - uncond_pred)
    return uncond_pred + guided_delta
def prepare_prompt_embeddings(
    self,
    tokens: TokenBuffer,
    num_images_per_prompt: int = 1,
) -> Buffer:
    """Encode a text-only prompt and replicate it once per output image.

    The last hidden state of ``self.text_encoder`` is trimmed by
    ``PROMPT_TEMPLATE_DROP_IDX`` leading positions and given a leading
    batch axis. One and two images per prompt use compiled graphs; larger
    counts are replicated on the host and uploaded again in the text
    encoder's configured dtype.

    Args:
        tokens: Prompt tokens; ``tokens.array`` is flattened to 1-D.
        num_images_per_prompt: Number of copies along the batch axis.

    Returns:
        Embeddings with ``num_images_per_prompt`` entries on the batch axis.

    Raises:
        ValueError: If ``num_images_per_prompt`` is less than 1.
    """
    if num_images_per_prompt < 1:
        raise ValueError(
            "num_images_per_prompt must be >= 1, got "
            f"{num_images_per_prompt}."
        )

    # Upper bound on cached token buffers. Without it every distinct prompt
    # would pin one device buffer for the lifetime of the pipeline.
    max_cached_prompts = 256

    device = self.text_encoder.devices[0]
    text_input_ids_np = np.asarray(tokens.array).flatten()
    token_key = tuple(int(token) for token in text_input_ids_np.tolist())

    prompt_cache = self.cache.prompt_tokens
    token_buf = prompt_cache.get(token_key)
    if token_buf is None:
        token_buf = Buffer.from_dlpack(
            np.ascontiguousarray(text_input_ids_np)
        ).to(device)
        # Dicts keep insertion order, so the first key is the oldest entry.
        while len(prompt_cache) >= max_cached_prompts:
            prompt_cache.pop(next(iter(prompt_cache)))
        prompt_cache[token_key] = token_buf

    hidden_states_all = self.text_encoder(token_buf)
    hs_buf = hidden_states_all[-1]

    trimmed = self.cached_trim_prompt_embeddings(hs_buf)
    if num_images_per_prompt == 1:
        return trimmed
    if num_images_per_prompt == 2:
        return self.cached_duplicate_prompt_embeddings(trimmed)

    # Host fallback for three or more copies: no compiled graph exists.
    text_dtype = self.text_encoder.config.dtype
    hs_cpu = hs_buf.to(CPU())
    if text_dtype == DType.bfloat16:
        # NumPy has no bfloat16: reinterpret the bits as uint16 and widen
        # them into the upper half of a float32.
        hs_u16 = np.from_dlpack(
            hs_cpu.view(dtype=DType.uint16, shape=hs_cpu.shape)
        )
        hs_np = (hs_u16.astype(np.uint32) << 16).view(np.float32)
    else:
        hs_np = np.from_dlpack(hs_cpu).astype(np.float32)

    hs_np = hs_np[self.PROMPT_TEMPLATE_DROP_IDX :]
    hs_np = hs_np[np.newaxis, :, :]
    hs_np = np.broadcast_to(
        hs_np,
        (num_images_per_prompt, hs_np.shape[1], hs_np.shape[2]),
    ).copy()

    if text_dtype == DType.bfloat16:
        result_u16 = float32_to_bfloat16_as_uint16(
            np.ascontiguousarray(hs_np)
        )
        buf = Buffer.from_numpy(result_u16).to(device)
        return buf.view(dtype=DType.bfloat16, shape=hs_np.shape)
    if text_dtype == DType.float16:
        # Keep the dtype identical to the compiled 1x/2x paths instead of
        # silently promoting the embeddings to float32.
        hs_np = hs_np.astype(np.float16)
    return Buffer.from_numpy(np.ascontiguousarray(hs_np)).to(device)
int, height: int, width: int, device: Device + ) -> Buffer: + """Create 3D image position IDs in (T, H, W) format.""" + t_coords = np.zeros((height, width), dtype=np.int64) + h_c = np.arange(height, dtype=np.int64) - (height - height // 2) + w_c = np.arange(width, dtype=np.int64) - (width - width // 2) + h_coords, w_coords = np.meshgrid(h_c, w_c, indexing="ij") + coords = np.stack([t_coords, h_coords, w_coords], axis=-1).reshape( + -1, 3 + ) + image_ids = np.broadcast_to( + coords[np.newaxis, :, :], + (batch_size, coords.shape[0], coords.shape[1]), + ).copy() + return Buffer.from_dlpack(image_ids).to(device) + + @staticmethod + def _prepare_condition_image_ids( + batch_size: int, + height: int, + width: int, + device: Device, + image_index: int = 0, + ) -> Buffer: + """Condition-image IDs with T=image_index+1 (noise tokens use T=0). + + For multi-image editing each condition image needs a distinct T + coordinate so the transformer can distinguish them via RoPE: + noise → T=0, first image → T=1, second image → T=2, etc. 
def preprocess_latents(
    self,
    latents: npt.NDArray[np.float32],
    latent_image_ids: npt.NDArray[np.float32],
) -> tuple[Buffer, Buffer]:
    """Patchify initial noise latents and upload their position IDs.

    ``latents`` of shape (batch, channels, height, width) are split into
    2x2 spatial patches and packed by the compiled patchify graph.
    ``latent_image_ids`` is converted to int64 and cached on the device
    under the key ``(batch, height, width)``.

    NOTE(review): once a key is cached, the IDs passed on later calls with
    the same latent shape are ignored; this presumably relies on the IDs
    being a pure function of that shape — confirm against the caller.

    Args:
        latents: Initial noise latents.
        latent_image_ids: Position IDs for the noise tokens.

    Returns:
        ``(packed_latents, latent_image_ids)``.

    Raises:
        ValueError: If ``latents`` is not 4-D, or its height or width is odd.
    """
    # The compiled patchify graph is declared with float32 inputs, so coerce
    # here rather than rely on the caller's array dtype.
    latents_np = np.asarray(latents, dtype=np.float32)
    if latents_np.ndim != 4:
        raise ValueError(
            "latents must have shape (batch, channels, height, width); "
            f"got shape {latents_np.shape}."
        )
    b, c, h, w = latents_np.shape
    if h % 2 or w % 2:
        raise ValueError(
            "latent height and width must be even for 2x2 patchification; "
            f"got {h}x{w}."
        )

    latents_6d = latents_np.reshape(b, c, h // 2, 2, w // 2, 2)
    device = self.transformer.devices[0]
    latents_packed = self.cached_patchify_and_pack(
        Buffer.from_dlpack(np.ascontiguousarray(latents_6d)).to(device)
    )

    ids_key = (b, h, w)
    if ids_key not in self.cache.latent_image_ids:
        self.cache.latent_image_ids[ids_key] = Buffer.from_dlpack(
            np.ascontiguousarray(latent_image_ids, dtype=np.int64)
        ).to(device)
    ids_buf = self.cache.latent_image_ids[ids_key]
    return latents_packed, ids_buf
def prepare_image_latents(
    self, images: list[Buffer], batch_size: int, device: Device
) -> tuple[Buffer, Buffer]:
    """Encode condition images into packed latents and position IDs.

    Every image is encoded with ``_encode_single_image`` using its list
    position as ``image_index``. The per-image results are concatenated
    along the sequence axis and then replicated along the batch axis when
    ``batch_size`` is greater than one.

    Args:
        images: Condition images already converted to VAE input buffers.
        batch_size: Number of batch entries required by the denoiser.
        device: Device on which the results must live.

    Returns:
        ``(image_latents, image_ids)``.
    """
    encoded = [
        self._encode_single_image(img, device, image_index=idx)
        for idx, img in enumerate(images)
    ]

    image_latents, image_ids = encoded[0]
    for next_latents, next_ids in encoded[1:]:
        image_latents = self.cached_concat_image_sequences(
            image_latents, next_latents
        )
        image_ids = self.cached_concat_image_ids(image_ids, next_ids)

    if batch_size <= 1:
        return image_latents, image_ids

    if batch_size == 2:
        # Doubling has dedicated compiled graphs; stay on the device.
        return (
            self.cached_duplicate_condition_latents(image_latents),
            self.cached_duplicate_condition_ids(image_ids),
        )

    # Host fallback for larger batches. ``_to_numpy`` is used for the
    # download because it already copes with dtypes NumPy cannot import
    # directly (bfloat16); the latents are then converted back to the VAE
    # dtype so they keep the dtype the single-batch path produces.
    lat_np = self._to_numpy(image_latents)
    tiled_latents = np.broadcast_to(
        lat_np,
        (batch_size, lat_np.shape[1], lat_np.shape[2]),
    ).copy()
    vae_dtype = self.vae.config.dtype
    if vae_dtype == DType.bfloat16:
        latent_bits = float32_to_bfloat16_as_uint16(tiled_latents)
        image_latents = (
            Buffer.from_numpy(latent_bits)
            .to(device)
            .view(dtype=DType.bfloat16, shape=tiled_latents.shape)
        )
    else:
        if vae_dtype == DType.float16:
            tiled_latents = tiled_latents.astype(np.float16)
        image_latents = Buffer.from_dlpack(tiled_latents).to(device)

    # Position IDs are int64, which NumPy imports without conversion.
    ids_np = np.from_dlpack(image_ids.to(CPU()))
    image_ids = Buffer.from_dlpack(
        np.broadcast_to(
            ids_np,
            (batch_size, ids_np.shape[1], ids_np.shape[2]),
        ).copy()
    ).to(device)

    return image_latents, image_ids
@traced
def execute(  # type: ignore[override]
    self,
    model_inputs: QwenImageEditModelInputs,
    output_type: Literal["np", "latent"] = "np",
) -> QwenImageEditPipelineOutput:
    """Run the QwenImageEdit denoising loop and decode outputs.

    Steps: encode the positive (and optionally negative) prompt, patchify
    the noise latents, encode and append condition-image latents, run one
    Euler step per inference step (with a second transformer pass and a
    blend when true CFG is active), strip the condition tokens again and
    decode each batch entry.

    True CFG only runs when ``true_cfg_scale > 1.0`` and negative tokens
    are present; otherwise a single transformer pass is used per step.

    NOTE(review): sigmas are cached by ``(num_inference_steps, noise token
    count)``; this presumably assumes the schedule depends on nothing
    else — confirm against the code that fills ``model_inputs.sigmas``.

    Args:
        model_inputs: Prepared inputs for one request.
        output_type: ``"np"`` to decode through the VAE, ``"latent"`` to
            return the packed latents unchanged.

    Returns:
        Output whose ``images`` list has one entry per batch element.
    """
    device = self.transformer.devices[0]

    # Phase 1: prompt, latent, and conditioning preparation.
    prompt_images, vae_condition_images = self._resolve_condition_images(
        model_inputs
    )
    has_images = bool(prompt_images)
    # The multimodal encoder is only built when prompt images exist.
    prompt_encoder = self._get_prompt_encoder() if has_images else None

    prompt_embeds = self._encode_prompt(
        tokens=model_inputs.tokens,
        prompt_images=prompt_images,
        num_images_per_prompt=model_inputs.num_images_per_prompt,
        prompt_encoder=prompt_encoder,
    )
    batch_size = int(prompt_embeds.shape[0])

    do_true_cfg = model_inputs.true_cfg_scale > 1.0
    negative_prompt_embeds = self._prepare_negative_prompt_embeddings(
        model_inputs=model_inputs,
        prompt_images=prompt_images,
        prompt_encoder=prompt_encoder,
    )

    latents, latent_image_ids = self.preprocess_latents(
        model_inputs.latents, model_inputs.latent_image_ids
    )
    # Remember how many leading tokens are noise so the condition tokens
    # appended below can be sliced off again after the loop.
    noise_token_count_value = int(latents.shape[1])
    if noise_token_count_value not in self.cache.noise_token_counts:
        self.cache.noise_token_counts[noise_token_count_value] = (
            Buffer.from_numpy(
                np.array([noise_token_count_value], dtype=np.int64)
            )
        )
    noise_token_count = self.cache.noise_token_counts[
        noise_token_count_value
    ]

    image_latents, image_latent_ids = self._prepare_condition_latents(
        vae_condition_images=vae_condition_images,
        batch_size=batch_size,
        device=device,
    )

    h_latent = model_inputs.height // (self.vae_scale_factor * 2)
    w_latent = model_inputs.width // (self.vae_scale_factor * 2)
    max_vid_index = max(h_latent // 2, w_latent // 2)

    text_ids = self._prepare_text_ids_for_embeddings(
        embeddings=prompt_embeds,
        batch_size=batch_size,
        device=device,
        max_vid_index=max_vid_index,
    )

    # Phase 2: classifier-free guidance setup.
    negative_text_ids: Buffer | None = None
    if do_true_cfg and negative_prompt_embeds is not None:
        negative_text_ids = self._prepare_text_ids_for_embeddings(
            embeddings=negative_prompt_embeds,
            batch_size=batch_size,
            device=device,
            max_vid_index=max_vid_index,
        )

    # Phase 3: scheduler and loop-invariant inputs.
    num_inference_steps = model_inputs.num_inference_steps
    sigmas_key = f"{num_inference_steps}_{latents.shape[1]}"
    if sigmas_key not in self.cache.sigmas:
        self.cache.sigmas[sigmas_key] = Buffer.from_dlpack(
            model_inputs.sigmas
        ).to(device)
    with Tracer("prepare_scheduler"):
        all_timesteps, all_dts = self.cached_prepare_scheduler(
            self.cache.sigmas[sigmas_key]
        )
        timesteps_seq = all_timesteps.driver_tensor
        dts_seq = all_dts.driver_tensor

    cfg_scale_buf: Buffer | None = None
    if do_true_cfg:
        cfg_scale = float(model_inputs.true_cfg_scale)
        if cfg_scale not in self.cache.cfg_scales:
            self.cache.cfg_scales[cfg_scale] = Buffer.from_dlpack(
                np.array([cfg_scale], dtype=np.float32)
            ).to(device)
        cfg_scale_buf = self.cache.cfg_scales[cfg_scale]

    latents_in = latents
    ids_in = latent_image_ids
    if image_latents is not None and image_latent_ids is not None:
        latents_in, ids_in = self.cached_concat_image_latents(
            latents,
            image_latents,
            latent_image_ids,
            image_latent_ids,
        )

    # Phase 4: denoising loop.
    with Tracer("denoising_loop"):
        for i in range(num_inference_steps):
            with Tracer(f"denoising_step_{i}"):
                timestep = timesteps_seq[i : i + 1]
                dt = dts_seq[i : i + 1]

                with Tracer("transformer_pos"):
                    noise_pred = self.transformer(
                        latents_in,
                        prompt_embeds,
                        timestep,
                        ids_in,
                        text_ids,
                    )[0]

                if (
                    do_true_cfg
                    and negative_prompt_embeds is not None
                    and negative_text_ids is not None
                    and cfg_scale_buf is not None
                ):
                    with Tracer("transformer_neg"):
                        noise_pred_uncond = self.transformer(
                            latents_in,
                            negative_prompt_embeds,
                            timestep,
                            ids_in,
                            negative_text_ids,
                        )[0]
                    with Tracer("cfg_blend"):
                        noise_pred = self.cached_cfg_blend(
                            noise_pred, noise_pred_uncond, cfg_scale_buf
                        )

                with Tracer("scheduler_step"):
                    latents_in = self.cached_scheduler_step(
                        latents_in,
                        noise_pred,
                        dt,
                        ids_in,
                    )

    latents = self.cached_extract_noise_latents(
        latents_in, noise_token_count
    )

    # Phase 5: decode outputs.
    image_list = []
    if batch_size == 1:
        image_list.append(
            self.decode_latents(
                latents,
                model_inputs.height,
                model_inputs.width,
                output_type,
            )
        )
    else:
        # Split the batch on the host. ``_to_numpy`` yields float32, so each
        # slice is converted back to the transformer dtype before upload;
        # the decode graphs are compiled for that dtype and would not accept
        # float32 latents from a bfloat16/float16 model.
        lat_np = self._to_numpy(latents)
        for b in range(batch_size):
            latents_b = self._latent_slice_to_device(
                np.ascontiguousarray(lat_np[b : b + 1]), device
            )
            image_list.append(
                self.decode_latents(
                    latents_b,
                    model_inputs.height,
                    model_inputs.width,
                    output_type,
                )
            )

    return QwenImageEditPipelineOutput(images=image_list)


def _latent_slice_to_device(
    self, latents_np: np.ndarray, device: Device
) -> Buffer:
    """Upload float32 host latents in the transformer's configured dtype."""
    latent_dtype = self.transformer.config.dtype
    if latent_dtype == DType.bfloat16:
        # NumPy has no bfloat16: upload the raw bits, then reinterpret.
        latent_bits = float32_to_bfloat16_as_uint16(latents_np)
        return (
            Buffer.from_numpy(latent_bits)
            .to(device)
            .view(dtype=DType.bfloat16, shape=latents_np.shape)
        )
    if latent_dtype == DType.float16:
        latents_np = latents_np.astype(np.float16)
    return Buffer.from_dlpack(np.ascontiguousarray(latents_np)).to(device)
- input_image: Optional input image for image-to-image generation (PIL.Image.Image). + input_images: Optional list of input images for image-to-image generation. model_name: Name of the model being used. """ @@ -729,8 +729,19 @@ class PixelContext: num_images_per_prompt: int = field(default=1) input_image: npt.NDArray[np.uint8] | None = field(default=None) """Input image as numpy array (H, W, C) in uint8 format for image-to-image generation.""" - image: npt.NDArray[np.uint8] | None = field(default=None) - """Decoded output image (H, W, C) uint8 [0, 255]. Set after generation completes.""" + input_images: list[npt.NDArray[np.uint8]] | None = field(default=None) + """Input images as list of numpy arrays (H, W, C) in uint8 format for image-to-image generation.""" + prompt_images: list[npt.NDArray[np.uint8]] | None = field(default=None) + """Optional prompt-conditioning images prepared by the tokenizer.""" + vae_condition_images: list[npt.NDArray[np.uint8]] | None = field( + default=None + ) + """Optional VAE-conditioning images prepared by the tokenizer. + + Qwen image edit keeps prompt-conditioning images and VAE-conditioning + images separate because the multimodal prompt encoder and the VAE latent + conditioning path use different resize targets. 
+ """ output_format: str = field(default="jpeg") """Image encoding format for the output (e.g., 'jpeg', 'png', 'webp').""" status: GenerationStatus = field(default=GenerationStatus.ACTIVE) @@ -751,24 +762,16 @@ def reset(self) -> None: """Resets the context's state.""" self.status = GenerationStatus.ACTIVE - def update(self, image: npt.NDArray[np.uint8]) -> None: - """Update the context with the decoded uint8 image output.""" - self.image = image + def update(self, latents: npt.NDArray[Any]) -> None: + """Update the context with newly generated latents/image data.""" + self.latents = latents def to_generation_output(self) -> GenerationOutput: """Convert this context to a GenerationOutput object.""" - if self.image is None: - raise ValueError( - "No decoded image available; generation may not have completed." - ) return GenerationOutput( request_id=self.request_id, final_status=self.status, - output=[ - OutputImageContent.from_numpy( - self.image, format=self.output_format - ) - ], + output=[OutputImageContent.from_numpy(self.latents, format=self.output_format)], ) diff --git a/max/python/max/pipelines/lib/registry.py b/max/python/max/pipelines/lib/registry.py index 0eb101e6a12..a4f2f5e8686 100644 --- a/max/python/max/pipelines/lib/registry.py +++ b/max/python/max/pipelines/lib/registry.py @@ -802,11 +802,26 @@ def retrieve_factory( "tokenizer_2" in diffusers_config["components"] ) + # Determine tokenizer max_length based on pipeline type + pipeline_class_name = ( + diffusers_config.get("_class_name", "") + if diffusers_config + else "" + ) + if pipeline_class_name in { + "QwenImagePipeline", + "QwenImageEditPipeline", + "QwenImageEditPlusPipeline", + }: + # QwenImage uses Qwen2 tokenizer with chat template (34 prefix tokens) + primary_max_length = 1024 + 34 + else: + primary_max_length = 77 # Standard for CLIP tokenizer_kwargs = { "model_path": pipeline_config.model.model_path, "pipeline_config": pipeline_config, "subfolder": "tokenizer", - "max_length": max_length, + 
"max_length": primary_max_length, "revision": pipeline_config.model.huggingface_model_revision, "trust_remote_code": pipeline_config.model.trust_remote_code, }