Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 86 additions & 27 deletions max/examples/diffusion/simple_offline_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@
"Flux2Pipeline_ModuleV3",
"Flux2KleinPipeline_ModuleV3",
}
QWEN_IMAGE_ARCH_NAMES = {
"QwenImagePipeline",
"QwenImageEditPipeline",
"QwenImageEditPlusPipeline",
}
QWEN_IMAGE_EDIT_ARCH_NAMES = {
"QwenImageEditPipeline",
"QwenImageEditPlusPipeline",
}
QWEN_DEFAULT_GUIDANCE_SCALE = 1.0
QWEN_DEFAULT_TRUE_CFG_SCALE = 4.0


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
Expand Down Expand Up @@ -158,8 +169,21 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser.add_argument(
"--guidance-scale",
type=float,
default=3.5,
help="Guidance scale for classifier-free guidance. Set to 1.0 to disable CFG.",
default=None,
help=(
"Guidance scale for classifier-free guidance. "
"If omitted, defaults to 1.0 for QwenImage family and 3.5 otherwise."
),
)
parser.add_argument(
"--true-cfg-scale",
type=float,
default=None,
help=(
"True classifier-free guidance scale. "
"If omitted, defaults to 4.0 for QwenImage family when negative prompt is provided, "
"and 1.0 otherwise."
),
)
parser.add_argument(
"--seed",
Expand Down Expand Up @@ -188,8 +212,9 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser.add_argument(
"--input-image",
type=str,
action="append",
default=None,
help="Input image for image-to-image generation.",
help="Input image for image-to-image generation. Can be specified multiple times.",
)
parser.add_argument(
"--profile-timings",
Expand Down Expand Up @@ -263,16 +288,19 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
assert args.num_inference_steps > 0, (
"num-inference-steps must be a positive integer."
)
assert args.guidance_scale > 0.0, "guidance-scale must be positive."
if args.residual_threshold is not None:
if args.guidance_scale is not None:
assert args.guidance_scale > 0.0, "guidance-scale must be positive."
if args.true_cfg_scale is not None:
assert args.true_cfg_scale > 0.0, "true-cfg-scale must be positive."
if hasattr(args, 'residual_threshold') and args.residual_threshold is not None:
assert args.residual_threshold >= 0.0, (
"residual-threshold must be non-negative."
)
if args.taylorseer_cache_interval is not None:
if hasattr(args, 'taylorseer_cache_interval') and args.taylorseer_cache_interval is not None:
assert args.taylorseer_cache_interval >= 1, (
"taylorseer-cache-interval must be >= 1."
)
if args.taylorseer_warmup_steps is not None:
if hasattr(args, 'taylorseer_warmup_steps') and args.taylorseer_warmup_steps is not None:
assert args.taylorseer_warmup_steps >= 1, (
"taylorseer-warmup-steps must be >= 1."
)
Expand Down Expand Up @@ -381,6 +409,8 @@ async def generate_image(args: argparse.Namespace) -> None:
)
if arch.name in _FLUX2_ARCH_NAMES:
max_length = 512
elif arch.name in QWEN_IMAGE_ARCH_NAMES:
max_length = 512
print(f"Using max length: {max_length} for tokenizer")

if (
Expand Down Expand Up @@ -433,27 +463,50 @@ async def generate_image(args: argparse.Namespace) -> None:
print(f"Generating image for prompt: '{args.prompt}'")

# Step 4: Create an OpenResponsesRequest
# Load input image if provided and convert to data URI
input_image_data_uri = load_image_as_data_uri(args.input_image)
# Load input images if provided and convert to data URIs
input_image_data_uris: list[str] = []
if args.input_image:
for img_path in args.input_image:
uri = load_image_as_data_uri(img_path)
if uri is not None:
input_image_data_uris.append(uri)

is_qwen_image_family = arch.name in QWEN_IMAGE_ARCH_NAMES
guidance_scale = args.guidance_scale
if guidance_scale is None:
guidance_scale = (
QWEN_DEFAULT_GUIDANCE_SCALE if is_qwen_image_family else 3.5
)

true_cfg_scale = args.true_cfg_scale
if true_cfg_scale is None:
if is_qwen_image_family and args.negative_prompt is not None:
true_cfg_scale = QWEN_DEFAULT_TRUE_CFG_SCALE
else:
true_cfg_scale = 1.0

# Create request with structured message if image is provided
if input_image_data_uri:
# Create request with structured message if images are provided
if input_image_data_uris:
# Image-to-image: Use structured message with InputImageContent + InputTextContent
image_content_items: list[InputImageContent | InputTextContent] = [
InputImageContent(
type="input_image",
image_url=uri,
)
for uri in input_image_data_uris
]
image_content_items.append(
InputTextContent(
type="input_text",
text=args.prompt,
)
)
body = OpenResponsesRequestBody(
model=args.model,
input=[
UserMessage(
role="user",
content=[
InputImageContent(
type="input_image",
image_url=input_image_data_uri,
),
InputTextContent(
type="input_text",
text=args.prompt,
),
],
content=image_content_items,
)
],
seed=args.seed,
Expand All @@ -463,7 +516,8 @@ async def generate_image(args: argparse.Namespace) -> None:
height=args.height,
width=args.width,
steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
guidance_scale=guidance_scale,
true_cfg_scale=true_cfg_scale,
)
),
)
Expand All @@ -479,15 +533,17 @@ async def generate_image(args: argparse.Namespace) -> None:
height=args.height,
width=args.width,
steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
guidance_scale=guidance_scale,
true_cfg_scale=true_cfg_scale,
)
),
)

request = OpenResponsesRequest(request_id=RequestID(), body=body)

print(
f"Parameters: steps={args.num_inference_steps}, guidance={args.guidance_scale}"
"Parameters: "
f"steps={args.num_inference_steps}, guidance={guidance_scale}, true_cfg={true_cfg_scale}"
)

# Step 5: Create a PixelContext object from the request
Expand Down Expand Up @@ -545,16 +601,19 @@ async def generate_image(args: argparse.Namespace) -> None:
height=args.height,
width=args.width,
steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
guidance_scale=guidance_scale,
true_cfg_scale=true_cfg_scale,
)
),
)
request_warmup = OpenResponsesRequest(
request_id=RequestID(), body=body_warmup
)
input_image = Image.open(args.input_image) if args.input_image else None
warmup_image = (
Image.open(args.input_image[0]) if args.input_image else None
)
context_warmup = await tokenizer.new_context(
request_warmup, input_image=input_image
request_warmup, input_image=warmup_image
)
inputs_warmup = PixelGenerationInputs[PixelContext](
batch={context_warmup.request_id: context_warmup}
Expand Down
5 changes: 5 additions & 0 deletions max/python/max/pipelines/architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def register_all_models() -> None:
from .qwen3_embedding import qwen3_embedding_arch
from .qwen3_embedding_modulev3 import qwen3_embedding_modulev3_arch
from .qwen3vl_moe import qwen3vl_arch, qwen3vl_moe_arch
from .qwen_image import qwen_image_arch
from .qwen_image_edit import qwen_image_edit_arch, qwen_image_edit_plus_arch
from .unified_eagle_llama3 import unified_eagle_llama3_arch
from .unified_mtp_deepseekV3 import unified_mtp_deepseekV3_arch

Expand Down Expand Up @@ -131,6 +133,9 @@ def register_all_models() -> None:
qwen3_embedding_modulev3_arch,
qwen3vl_arch,
qwen3vl_moe_arch,
qwen_image_arch,
qwen_image_edit_arch,
qwen_image_edit_plus_arch,
unified_eagle_llama3_arch,
unified_mtp_deepseekV3_arch,
]
Expand Down
17 changes: 17 additions & 0 deletions max/python/max/pipelines/architectures/qwen_image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from .arch import qwen_image_arch
from .model import QwenImageTransformerModel

__all__ = ["QwenImageTransformerModel", "qwen_image_arch"]
61 changes: 61 additions & 0 deletions max/python/max/pipelines/architectures/qwen_image/arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from __future__ import annotations

from dataclasses import dataclass

from max.graph.weights import WeightsFormat
from max.interfaces import PipelineTask
from max.pipelines.core import PixelContext
from max.pipelines.lib import (
PixelGenerationTokenizer,
SupportedArchitecture,
)
from max.pipelines.lib.config import PipelineConfig
from max.pipelines.lib.interfaces import ArchConfig
from typing_extensions import Self

from .pipeline_qwen_image import QwenImagePipeline


@dataclass(kw_only=True)
class QwenImageArchConfig(ArchConfig):
"""Pipeline-level config for QwenImage (implements ArchConfig; no KV cache)."""

pipeline_config: PipelineConfig

def get_max_seq_len(self) -> int:
return 0 # Not used for pixel generation.

@classmethod
def initialize(cls, pipeline_config: PipelineConfig) -> Self:
if len(pipeline_config.model.device_specs) != 1:
raise ValueError("QwenImage is only supported on a single device")
return cls(pipeline_config=pipeline_config)


qwen_image_arch = SupportedArchitecture(
name="QwenImagePipeline",
task=PipelineTask.PIXEL_GENERATION,
default_encoding="bfloat16",
supported_encodings={"bfloat16": []},
example_repo_ids=[
"Qwen/Qwen-Image-2512",
],
pipeline_model=QwenImagePipeline, # type: ignore[arg-type]
context_type=PixelContext,
default_weights_format=WeightsFormat.safetensors,
tokenizer=PixelGenerationTokenizer,
config=QwenImageArchConfig,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from .embeddings import QwenImagePosEmbed, QwenImageTimestepProjEmbeddings
from .qwen_image_attention import (
QwenImageFeedForward,
QwenImageTransformerBlock,
)

__all__ = [
"QwenImageFeedForward",
"QwenImagePosEmbed",
"QwenImageTimestepProjEmbeddings",
"QwenImageTransformerBlock",
]
Loading
Loading