From 3bb103f9d499aca28f6574b39d483ded6e0d1613 Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Fri, 16 Jan 2026 09:29:55 +0000 Subject: [PATCH 01/18] feat: add entrypoints for diffusion pipeline --- max/examples/diffusion/offline_generation.py | 61 ++++++++++ max/python/max/entrypoints/BUILD.bazel | 36 ++++++ max/python/max/entrypoints/cli/generate.py | 39 ++++++ max/python/max/entrypoints/diffusion.py | 57 +++++++++ max/python/max/entrypoints/pipelines.py | 113 ++++++++++++++++++ .../max/entrypoints/pipelines_diffusion.py | 29 +++++ 6 files changed, 335 insertions(+) create mode 100644 max/examples/diffusion/offline_generation.py create mode 100644 max/python/max/entrypoints/diffusion.py create mode 100644 max/python/max/entrypoints/pipelines_diffusion.py diff --git a/max/examples/diffusion/offline_generation.py b/max/examples/diffusion/offline_generation.py new file mode 100644 index 00000000000..56bfce230c0 --- /dev/null +++ b/max/examples/diffusion/offline_generation.py @@ -0,0 +1,61 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import argparse +import os +from pathlib import Path + +from max.entrypoints.diffusion import DiffusionPipeline +from max.pipelines import PipelineConfig + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-path", type=str, default="black-forest-labs/FLUX.1-dev" + ) + parser.add_argument("--use-torch-randn", action="store_true") + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + model_path = args.model_path + if args.use_torch_randn: + # NOTE: Use torch randn for latent initialization. + # Currently, It's not possible to set seed for Max random generation, + # so, use torch randn to test different seeds. + os.environ["USE_TORCH_RANDN"] = "1" + os.environ["SEED"] = str(args.seed) + pipeline_config = PipelineConfig(model_path=model_path) + pipe = DiffusionPipeline(pipeline_config) + + prompt = "A cat holding a sign that says hello world" + print(f"Prompt: {prompt}") + + result = pipe( + prompt=prompt, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=3.5, + ) + + images = result.images + + output_path = Path("output.png") + output_path.parent.mkdir(parents=True, exist_ok=True) + images[0].save(output_path) + print(f"Image saved to: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/max/python/max/entrypoints/BUILD.bazel b/max/python/max/entrypoints/BUILD.bazel index 6bef0b817c7..caf523b6598 100644 --- a/max/python/max/entrypoints/BUILD.bazel +++ b/max/python/max/entrypoints/BUILD.bazel @@ -93,6 +93,42 @@ modular_py_binary( ], ) +modular_py_binary( + name = "pipelines_diffusion", + srcs = [ + "pipelines_diffusion.py", + ], + data = [ + "@nvshmem_prebuilt//:host", + ], + env = { + "OTEL_EXPORTER_OTLP_METRICS_DEFAULT_HISTOGRAM_AGGREGATION": "base2_exponential_bucket_histogram", + "MODULAR_SHMEM_LIB_DIR": "../+http_archive+nvshmem_prebuilt", + }, + mojo_deps = select({ + "//:emit_mojo_enabled": PROD_MOJOPKGS, + "//conditions:default": [], + }), + deps = [ + # Provides the `max.entrypoints.pipelines` module for the wrapper to import. + ":_pipelines", + ":entrypoints", + "//max/python/max:_core", + "//max/python/max/benchmark:benchmark_serving_lib", + "//max/python/max/interfaces", + "//max/python/max/pipelines", + "//max/python/max/serve:config", + "//max/python/max/serve/telemetry", + requirement("typing-extensions"), + requirement("click"), + ] + select({ + "//:nvidia_gpu": [ + requirement("torch"), + ], + "//conditions:default": [], + }), +) + modular_py_binary( name = "replay_recording", srcs = ["replay_recording.py"], diff --git a/max/python/max/entrypoints/cli/generate.py b/max/python/max/entrypoints/cli/generate.py index 9dd3c93d5ad..01f38fc0086 100644 --- a/max/python/max/entrypoints/cli/generate.py +++ b/max/python/max/entrypoints/cli/generate.py @@ -19,6 +19,7 @@ import dataclasses import logging from collections.abc import Iterable +from pathlib import Path from typing import Any import requests @@ -148,3 +149,41 @@ def generate_text_for_pipeline( print_tokens=True, ) ) + + +def generate_image( + pipeline_config: PipelineConfig, + prompt: str, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + num_images_per_prompt: int, + output: Path, +) -> None: + from ..diffusion import DiffusionPipeline + + pipeline = DiffusionPipeline(pipeline_config) + result = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + ) + + images = result.images + assert images, "Expected at least one generated image." + + output.parent.mkdir(parents=True, exist_ok=True) + if num_images_per_prompt == 1: + images[0].save(output) + logger.info(f"Image saved to: {output}") + else: + stem = output.stem + suffix = output.suffix + for i, image in enumerate(images): + numbered_path = output.parent / f"{stem}_{i + 1}{suffix}" + image.save(numbered_path) + logger.info(f"{len(images)} images saved to: {output.parent}") diff --git a/max/python/max/entrypoints/diffusion.py b/max/python/max/entrypoints/diffusion.py new file mode 100644 index 00000000000..10e9c15719b --- /dev/null +++ b/max/python/max/entrypoints/diffusion.py @@ -0,0 +1,57 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.interfaces import ( + ImageGenerationInputs, + ImageGenerationOutput, + PipelineTask, +) +from max.pipelines.lib import PIPELINE_REGISTRY, PipelineConfig + + +class DiffusionPipeline: + """Entrypoint for image-generation diffusion pipelines.""" + + def __init__(self, pipeline_config: PipelineConfig) -> None: + # NOTE: Currently, this entrypoint is implemented minimally + # for offline image generation. + # It will be developed further to support serving as well. + self.pipeline_config = pipeline_config + _, model_factory = PIPELINE_REGISTRY.retrieve_factory( + pipeline_config, + task=PipelineTask.IMAGE_GENERATION, + ) + self.pipeline = model_factory() + + def __call__( + self, + prompt: str, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 3.5, + num_images_per_prompt: int = 1, + ) -> ImageGenerationOutput: + """Generate images from a prompt with the configured pipeline.""" + # TODO: consider all possible diffusion tasks, + # e.g. T2I, I2I, T2V, I2V, V2V. + inputs = ImageGenerationInputs( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + ) + pipeline_output: ImageGenerationOutput = self.pipeline.execute(inputs) + return pipeline_output diff --git a/max/python/max/entrypoints/pipelines.py b/max/python/max/entrypoints/pipelines.py index c8dcada7415..7b315b44aea 100644 --- a/max/python/max/entrypoints/pipelines.py +++ b/max/python/max/entrypoints/pipelines.py @@ -18,6 +18,7 @@ import os import sys from collections.abc import Callable, Sequence +from pathlib import Path from typing import Any, TypeVar import click @@ -384,6 +385,118 @@ def cli_pipeline( ) +@main.group(name="diffusion", cls=ModelGroup) +def diffusion_group() -> None: + """Commands for diffusion-based image/video generation pipelines.""" + + +@diffusion_group.command(name="generate", cls=WithLazyPipelineOptions) +@click.option( + "--prompt", + type=str, + default="A cat holding a sign that says hello world", + help="The text prompt to use for image generation.", +) +@click.option( + "--height", + type=click.IntRange(min=64), + default=1024, + show_default=True, + help="Generated image height in pixels.", +) +@click.option( + "--width", + type=click.IntRange(min=64), + default=1024, + show_default=True, + help="Generated image width in pixels.", +) +@click.option( + "--num-inference-steps", + type=click.IntRange(min=1), + default=50, + show_default=True, + help="Number of denoising steps to run.", +) +@click.option( + "--guidance-scale", + type=float, + default=3.5, + show_default=True, + help="Classifier-free guidance scale.", +) +@click.option( + "--num-images-per-prompt", + type=click.IntRange(min=1), + default=1, + show_default=True, + help="Number of images to generate for a single prompt.", +) +@click.option( + "--output", + type=click.Path(dir_okay=False, path_type=Path), + default="output.png", + show_default=True, + help="Output image path (numbered if multiple images are generated).", +) +@click.option( + "--use-torch-randn/--no-use-torch-randn", + default=False, + show_default=True, + help=( + "Use torch-based random latents (set USE_TORCH_RANDN and SEED env vars)." + ), +) +@click.option( + "--seed", + type=int, + default=42, + show_default=True, + help="Random seed for torch-based latent initialization.", +) +def diffusion_generate( + prompt: str, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + num_images_per_prompt: int, + output: Path, + use_torch_randn: bool, + seed: int, + **config_kwargs: Any, +) -> None: + """Generate images using a diffusion pipeline.""" + from max.entrypoints.cli.generate import generate_image + from max.pipelines import PipelineConfig + + if use_torch_randn: + os.environ["USE_TORCH_RANDN"] = "1" + os.environ["SEED"] = str(seed) + + pipeline_config = PipelineConfig(**config_kwargs) + pipeline_config.log_basic_config() + + try: + generate_image( + pipeline_config=pipeline_config, + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + output=output, + ) + except Exception as exc: + logger.exception( + "Diffusion generation failed for model %s with prompt %r", + pipeline_config.model.model_path, + prompt, + ) + raise click.ClickException("Diffusion generation failed.") from exc + + @main.command(name="encode", cls=WithLazyPipelineOptions) @click.option( "--prompt", diff --git a/max/python/max/entrypoints/pipelines_diffusion.py b/max/python/max/entrypoints/pipelines_diffusion.py new file mode 100644 index 00000000000..2d94c5bb423 --- /dev/null +++ b/max/python/max/entrypoints/pipelines_diffusion.py @@ -0,0 +1,29 @@ +"""Diffusion-only CLI wrapper. + +This exists so Bazel can keep `//max/python/max/entrypoints:pipelines` lean, +while allowing `//max/python/max/entrypoints:pipelines_diffusion` to pull in +extra runtime deps. +""" + +from __future__ import annotations + +import sys + + +def main() -> None: + # Import the main pipelines CLI and dispatch into the `diffusion` group. + # + # NOTE: `max.entrypoints.pipelines.main` is a click command object. Calling it + # with `args=[...]` is equivalent to invoking the CLI with those argv tokens. + import max.entrypoints.pipelines as pipelines_cli + + pipelines_cli.main( + prog_name="pipelines", + args=["diffusion", *sys.argv[1:]], + ) + + +if __name__ == "__main__": + main() + + From 5c4e3ca52b8ac339a41a5fbe197dd57a2c0baeeb Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Fri, 16 Jan 2026 09:34:27 +0000 Subject: [PATCH 02/18] feat: add interfaces for image generation --- max/python/max/interfaces/__init__.py | 4 + .../interfaces/pipeline_variants/__init__.py | 6 + .../pipeline_variants/image_generation.py | 40 ++++ max/python/max/interfaces/task.py | 2 + .../lib/pipeline_variants/__init__.py | 1 + .../lib/pipeline_variants/image_generation.py | 217 ++++++++++++++++++ max/python/max/pipelines/lib/registry.py | 190 +++++++++------ 7 files changed, 384 insertions(+), 76 deletions(-) create mode 100644 max/python/max/interfaces/pipeline_variants/image_generation.py create mode 100644 max/python/max/pipelines/lib/pipeline_variants/image_generation.py diff --git a/max/python/max/interfaces/__init__.py b/max/python/max/interfaces/__init__.py index 3d3062549e5..30910f28851 100644 --- a/max/python/max/interfaces/__init__.py +++ b/max/python/max/interfaces/__init__.py @@ -49,6 +49,8 @@ EmbeddingsGenerationInputs, EmbeddingsGenerationOutput, ImageContentPart, + ImageGenerationInputs, + ImageGenerationOutput, ImageMetadata, TextContentPart, TextGenerationContext, @@ -109,6 +111,8 @@ def create_text_pipeline() -> Pipeline[TextGenerationInputs, TextGenerationOutpu "EmbeddingsGenerationOutput", "GenerationStatus", "ImageContentPart", + "ImageGenerationInputs", + "ImageGenerationOutput", "ImageMetadata", "LoRAOperation", "LoRARequest", diff --git a/max/python/max/interfaces/pipeline_variants/__init__.py b/max/python/max/interfaces/pipeline_variants/__init__.py index d5a7b10d1f3..073a7ad7a76 100644 --- a/max/python/max/interfaces/pipeline_variants/__init__.py +++ b/max/python/max/interfaces/pipeline_variants/__init__.py @@ -24,6 +24,10 @@ EmbeddingsGenerationInputs, EmbeddingsGenerationOutput, ) +from .image_generation import ( + ImageGenerationInputs, + ImageGenerationOutput, +) from .text_generation import ( BatchType, ImageContentPart, @@ -54,6 +58,8 @@ "EmbeddingsGenerationInputs", "EmbeddingsGenerationOutput", "ImageContentPart", + "ImageGenerationInputs", + "ImageGenerationOutput", "ImageMetadata", "TextContentPart", "TextGenerationContext", diff --git a/max/python/max/interfaces/pipeline_variants/image_generation.py b/max/python/max/interfaces/pipeline_variants/image_generation.py new file mode 100644 index 00000000000..dfa55f2a2b6 --- /dev/null +++ b/max/python/max/interfaces/pipeline_variants/image_generation.py @@ -0,0 +1,40 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from dataclasses import dataclass + +from max.interfaces.pipeline import PipelineInputs +from PIL.Image import Image + + +@dataclass(eq=True) +class ImageGenerationInputs(PipelineInputs): + """Inputs for image-generation pipelines.""" + + # NOTE: Current implementation only considers offline generation without + # request scheduling. `ImageGenerationContext` should be used once + # request scheduling is implemented. + prompt: str + height: int + width: int + num_inference_steps: int + guidance_scale: float + num_images_per_prompt: int + + +@dataclass(kw_only=True) +class ImageGenerationOutput: + """Output container for generated images.""" + + images: list[Image] + """List of generated images.""" diff --git a/max/python/max/interfaces/task.py b/max/python/max/interfaces/task.py index 477b77451e7..0422ebf657c 100644 --- a/max/python/max/interfaces/task.py +++ b/max/python/max/interfaces/task.py @@ -58,6 +58,8 @@ class PipelineTask(str, Enum): """Task for generating audio.""" SPEECH_TOKEN_GENERATION = "speech_token_generation" """Task for generating speech tokens.""" + IMAGE_GENERATION = "image_generation" + """Task for generating images.""" @property def output_type( diff --git a/max/python/max/pipelines/lib/pipeline_variants/__init__.py b/max/python/max/pipelines/lib/pipeline_variants/__init__.py index 991ed73eea4..2c04acdbcdd 100644 --- a/max/python/max/pipelines/lib/pipeline_variants/__init__.py +++ b/max/python/max/pipelines/lib/pipeline_variants/__init__.py @@ -11,4 +11,5 @@ # limitations under the License. # ===----------------------------------------------------------------------=== # +from .image_generation import ImageGenerationPipeline from .text_generation import TextGenerationPipeline diff --git a/max/python/max/pipelines/lib/pipeline_variants/image_generation.py b/max/python/max/pipelines/lib/pipeline_variants/image_generation.py new file mode 100644 index 00000000000..bbb74c0cae3 --- /dev/null +++ b/max/python/max/pipelines/lib/pipeline_variants/image_generation.py @@ -0,0 +1,217 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +import fnmatch +import logging +import os +import re +from pathlib import Path +from typing import TYPE_CHECKING + +import huggingface_hub +import requests +from huggingface_hub.utils import EntryNotFoundError, OfflineModeIsEnabled +from max.config import load_config +from max.interfaces import ( + ImageGenerationInputs, + ImageGenerationOutput, + Pipeline, + RequestID, +) +from requests.exceptions import HTTPError + +from ..config_enums import RepoType +from ..interfaces import DiffusionPipeline + +if TYPE_CHECKING: + from ..config import PipelineConfig + +logger = logging.getLogger(__name__) + + +class ImageGenerationPipeline( + Pipeline[ImageGenerationInputs, ImageGenerationOutput], +): + """Pipeline wrapper for diffusion image generation.""" + + def __init__( + self, + pipeline_config: PipelineConfig, + diffusion_pipeline: type[DiffusionPipeline], + ) -> None: + # Download checkpoints if required + # NOTE: Unlike TextGenerationPipeline where each file, + # such as configs and weights, are downloaded individually, + # DiffusionPipeline downloads the entire snapshot at once, + # since it normally contains multiple components. + pretrained_model_name_or_path = ( + pipeline_config.model.huggingface_model_repo.repo_id + ) + if ( + pipeline_config.model.huggingface_model_repo.repo_type + == RepoType.online + ): + cached_folder = self.download( + pretrained_model_name_or_path, + config_name=diffusion_pipeline.config_name, + force_download=pipeline_config.model.force_download, + revision=pipeline_config.model.huggingface_model_revision, + ) + else: + cached_folder = pretrained_model_name_or_path + + self._diffusion_pipeline = diffusion_pipeline( + pipeline_config, cached_folder + ) + + def download( + self, + pretrained_model_name: str | os.PathLike, + config_name: str | None, + force_download: bool = False, + revision: str | None = None, + ) -> str: + """Download the pipeline components from the Hugging Face Hub. + + Args: + pretrained_model_name: Model identifier. + config_name: Pipeline config filename in the repo. + force_download: Whether to force download. + revision: Model revision. + + Returns: + Path to the downloaded model folder. + """ + try: + info = huggingface_hub.model_info( + pretrained_model_name, revision=revision + ) + except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e: + logger.warning( + f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache." + ) + model_info_call_error = ( + e # save error to reraise it if model is not cached locally + ) + + if config_name is None: + raise ValueError( + f"config_name for {pretrained_model_name} pipeline is not set. " + "Please set proper config file name from huggingface hub." + ) + try: + config_file = huggingface_hub.hf_hub_download( + pretrained_model_name, + config_name, + revision=revision, + force_download=force_download, + ) + except EntryNotFoundError as e: + raise ValueError( + f"config file {config_name} not found for {pretrained_model_name} pipeline. " + "Please check if the config file name is correct." + ) from e + + config_dict = load_config(config_file) + ignore_filenames = config_dict.pop("_ignore_files", []) + + filenames = {sibling.rfilename for sibling in info.siblings} + filenames = set(filenames) - set(ignore_filenames) + + ignore_patterns = [ + "*.bin", + "*.msgpack", + "*.onnx", + "*.pb", + "*.bin.index.*json", + "*.msgpack.index.*json", + "*.onnx.index.*json", + "*.pb.index.*json", + ] + + allow_patterns = ["*/*"] + allow_patterns += [ + "scheduler_config.json", + "config.json", + config_name, + ] + re_ignore_pattern = [ + re.compile(fnmatch.translate(p)) for p in ignore_patterns + ] + re_allow_pattern = [ + re.compile(fnmatch.translate(p)) for p in allow_patterns + ] + + expected_files = [ + f + for f in filenames + if not any(p.match(f) for p in re_ignore_pattern) + ] + expected_files = [ + f + for f in expected_files + if any(p.match(f) for p in re_allow_pattern) + ] + + snapshot_folder = Path(config_file).parent + pipeline_is_cached = all( + (snapshot_folder / f).is_file() for f in expected_files + ) + + if pipeline_is_cached and not force_download: + # if the pipeline is cached, we can directly return it + # else call snapshot_download + return snapshot_folder + + # download all allow_patterns - ignore_patterns + try: + cached_folder = huggingface_hub.snapshot_download( + pretrained_model_name, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + return cached_folder + + except FileNotFoundError: + # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache. + # This can happen in two cases: + # 1. If the user passed `local_files_only=True` => we raise the error directly + # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error + if model_info_call_error is None: + # 1. user passed `local_files_only=True` + raise + else: + # 2. we forced `local_files_only=True` when `model_info` failed + raise OSError( + f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred" + " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace" + " above." + ) from model_info_call_error + + def execute(self, inputs: ImageGenerationInputs) -> ImageGenerationOutput: + outputs = self._diffusion_pipeline( + prompt=inputs.prompt, + height=inputs.height, + width=inputs.width, + num_inference_steps=inputs.num_inference_steps, + guidance_scale=inputs.guidance_scale, + num_images_per_prompt=inputs.num_images_per_prompt, + ) + return ImageGenerationOutput(images=outputs.images) + + def release(self, request_id: RequestID) -> None: + pass diff --git a/max/python/max/pipelines/lib/registry.py b/max/python/max/pipelines/lib/registry.py index 566bcb6d5a0..3445e372971 100644 --- a/max/python/max/pipelines/lib/registry.py +++ b/max/python/max/pipelines/lib/registry.py @@ -16,6 +16,7 @@ from __future__ import annotations import functools +import json import logging from collections.abc import Callable from dataclasses import dataclass, field @@ -38,6 +39,7 @@ from transformers import ( AutoConfig, AutoTokenizer, + PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, ) @@ -49,8 +51,9 @@ from .audio_generator_pipeline import AudioGeneratorPipeline from .config_enums import RopeType, SupportedEncoding from .embeddings_pipeline import EmbeddingsPipeline -from .hf_utils import HuggingFaceRepo +from .hf_utils import HuggingFaceRepo, get_model_index_path_for_diffusers from .interfaces import PipelineModel +from .pipeline_variants.image_generation import ImageGenerationPipeline from .pipeline_variants.text_generation import TextGenerationPipeline from .speculative_decoding import ( EAGLESpeculativeDecodingPipeline, @@ -74,6 +77,7 @@ def get_pipeline_for_task( | type[StandaloneSpeculativeDecodingPipeline] | type[SpeechTokenGenerationPipeline] | type[EAGLESpeculativeDecodingPipeline] + | type[ImageGenerationPipeline] ): if task == PipelineTask.TEXT_GENERATION: if pipeline_config._speculative is not None: @@ -100,6 +104,8 @@ def get_pipeline_for_task( return AudioGeneratorPipeline elif task == PipelineTask.SPEECH_TOKEN_GENERATION: return SpeechTokenGenerationPipeline + elif task == PipelineTask.IMAGE_GENERATION: + return ImageGenerationPipeline @dataclass(frozen=False) @@ -290,7 +296,7 @@ def retrieve_architecture( def get_active_huggingface_config( self, huggingface_repo: HuggingFaceRepo - ) -> AutoConfig: + ) -> AutoConfig | PretrainedConfig: """Retrieves or creates a cached HuggingFace AutoConfig for the given model configuration. @@ -311,7 +317,22 @@ def get_active_huggingface_config( Returns: AutoConfig: The HuggingFace configuration object for the model. """ - if huggingface_repo not in self._cached_huggingface_configs: + model_index_path = get_model_index_path_for_diffusers(huggingface_repo) + + if model_index_path is not None: + with open(model_index_path, encoding="utf-8") as f: + model_index = json.load(f) + + class_name = model_index.get("_class_name") + if not class_name or not isinstance(class_name, str): + raise ValueError( + f"Diffusers-style repository '{huggingface_repo.repo_id}' is missing a valid '_class_name' in model_index.json" + ) + + self._cached_huggingface_configs[huggingface_repo] = ( + PretrainedConfig(architectures=[class_name]) + ) + else: self._cached_huggingface_configs[huggingface_repo] = ( AutoConfig.from_pretrained( huggingface_repo.repo_id, @@ -444,87 +465,104 @@ def retrieve_factory( assert arch is not None devices = load_devices(pipeline_config.model.device_specs) - max_length = arch.pipeline_model.calculate_max_seq_len( - pipeline_config, huggingface_config=huggingface_config - ) - - # Old Mistral model like Mistral-7B-Instruct-v0.3 uses LlamaTokenizer - # and suffers from the whitespace decoding bug. So, we enable the fix - # for only MistralModel in order to avoid any issues with performance - # for rest of the models. This can be applied more generically once - # we have more time verifying this for all the models. - # More information: - # https://linear.app/modularml/issue/AIPIPE-197/add-support-for-mistral-7b-instruct-v03 - # TODO: remove this pipeline_model.__name__ check - if ( - arch.pipeline_model.__name__ in ("MistralModel", "Phi3Model") - and arch.tokenizer is TextTokenizer - ): - text_tokenizer = cast(type[TextTokenizer], arch.tokenizer) - tokenizer = text_tokenizer( - pipeline_config.model.model_path, - pipeline_config=pipeline_config, - revision=pipeline_config.model.huggingface_model_revision, - max_length=max_length, - trust_remote_code=pipeline_config.model.trust_remote_code, - enable_llama_whitespace_fix=True, - chat_template=pipeline_config.retrieve_chat_template(), - context_validators=arch.context_validators, + if task != PipelineTask.IMAGE_GENERATION: + max_length = arch.pipeline_model.calculate_max_seq_len( + pipeline_config, huggingface_config=huggingface_config ) - else: - tokenizer = arch.tokenizer( - model_path=pipeline_config.model.model_path, - pipeline_config=pipeline_config, - revision=pipeline_config.model.huggingface_model_revision, - max_length=max_length, - trust_remote_code=pipeline_config.model.trust_remote_code, - chat_template=pipeline_config.retrieve_chat_template(), - context_validators=arch.context_validators, + + # Old Mistral model like Mistral-7B-Instruct-v0.3 uses LlamaTokenizer + # and suffers from the whitespace decoding bug. So, we enable the fix + # for only MistralModel in order to avoid any issues with performance + # for rest of the models. This can be applied more generically once + # we have more time verifying this for all the models. + # More information: + # https://linear.app/modularml/issue/AIPIPE-197/add-support-for-mistral-7b-instruct-v03 + # TODO: remove this pipeline_model.__name__ check + if ( + arch.pipeline_model.__name__ in ("MistralModel", "Phi3Model") + and arch.tokenizer is TextTokenizer + ): + text_tokenizer = cast(type[TextTokenizer], arch.tokenizer) + tokenizer = text_tokenizer( + pipeline_config.model.model_path, + pipeline_config=pipeline_config, + revision=pipeline_config.model.huggingface_model_revision, + max_length=max_length, + trust_remote_code=pipeline_config.model.trust_remote_code, + enable_llama_whitespace_fix=True, + chat_template=pipeline_config.retrieve_chat_template(), + context_validators=arch.context_validators, + ) + else: + tokenizer = arch.tokenizer( + model_path=pipeline_config.model.model_path, + pipeline_config=pipeline_config, + revision=pipeline_config.model.huggingface_model_revision, + max_length=max_length, + trust_remote_code=pipeline_config.model.trust_remote_code, + chat_template=pipeline_config.retrieve_chat_template(), + context_validators=arch.context_validators, + ) + # Cast tokenizer to the proper type for text generation pipeline compatibility + typed_tokenizer = cast( + PipelineTokenizer[ + Any, npt.NDArray[np.integer[Any]], TextGenerationRequest + ], + tokenizer, ) - # Cast tokenizer to the proper type for text generation pipeline compatibility - typed_tokenizer = cast( - PipelineTokenizer[ - Any, npt.NDArray[np.integer[Any]], TextGenerationRequest - ], - tokenizer, - ) - # For speculative decoding, retrieve draft model's architecture - factory_kwargs: dict[str, Any] = { - "pipeline_config": pipeline_config, - "pipeline_model": arch.pipeline_model, - "eos_token_id": tokenizer.eos, - "weight_adapters": arch.weight_adapters, - "tokenizer": typed_tokenizer, - } - - # If using speculative decoding, add draft model-specific parameters - if pipeline_config.draft_model_config is not None: - draft_arch = self.retrieve_architecture( - huggingface_repo=pipeline_config.draft_model_config.huggingface_weight_repo, - use_module_v3=pipeline_config.use_module_v3, + # For speculative decoding, retrieve draft model's architecture + factory_kwargs: dict[str, Any] = { + "pipeline_config": pipeline_config, + "pipeline_model": arch.pipeline_model, + "eos_token_id": tokenizer.eos, + "weight_adapters": arch.weight_adapters, + "tokenizer": typed_tokenizer, + } + + # If using speculative decoding, add draft model-specific parameters + if pipeline_config.draft_model_config is not None: + draft_arch = self.retrieve_architecture( + huggingface_repo=pipeline_config.draft_model_config.huggingface_weight_repo, + use_module_v3=pipeline_config.use_module_v3, + ) + if draft_arch is None: + raise ValueError( + f"MAX-Optimized architecture not found for draft model " + f"'{pipeline_config.draft_model_config.model_path}'" + ) + factory_kwargs["draft_pipeline_model"] = ( + draft_arch.pipeline_model + ) + factory_kwargs["draft_weight_adapters"] = ( + draft_arch.weight_adapters + ) + + pipeline_factory = cast( + Callable[[], PipelineTypes], + functools.partial( # type: ignore + pipeline_class, **factory_kwargs + ), ) - if draft_arch is None: + + if tokenizer.eos is None: raise ValueError( - f"MAX-Optimized architecture not found for draft model " - f"'{pipeline_config.draft_model_config.model_path}'" + "tokenizer.eos value is None, tokenizer configuration is incomplete." ) - factory_kwargs["draft_pipeline_model"] = draft_arch.pipeline_model - factory_kwargs["draft_weight_adapters"] = draft_arch.weight_adapters - - pipeline_factory = cast( - Callable[[], PipelineTypes], - functools.partial( # type: ignore - pipeline_class, **factory_kwargs - ), - ) - if tokenizer.eos is None: - raise ValueError( - "tokenizer.eos value is None, tokenizer configuration is incomplete." + return tokenizer, pipeline_factory + else: + factory_kwargs = { + "pipeline_config": pipeline_config, + "diffusion_pipeline": arch.pipeline_model, + } + pipeline_factory = cast( + Callable[[], PipelineTypes], + functools.partial( # type: ignore + pipeline_class, **factory_kwargs + ), ) - - return tokenizer, pipeline_factory + return None, pipeline_factory def retrieve_context_type( self, pipeline_config: PipelineConfig From 3776b06f30a9ac789407c0c51101a0178bfa9ecf Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Sat, 17 Jan 2026 10:28:06 +0000 Subject: [PATCH 03/18] chore: allow negative prompt --- max/python/max/interfaces/pipeline_variants/image_generation.py | 2 ++ .../max/pipelines/lib/pipeline_variants/image_generation.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/max/python/max/interfaces/pipeline_variants/image_generation.py b/max/python/max/interfaces/pipeline_variants/image_generation.py index dfa55f2a2b6..d761298a7d6 100644 --- a/max/python/max/interfaces/pipeline_variants/image_generation.py +++ b/max/python/max/interfaces/pipeline_variants/image_generation.py @@ -25,6 +25,8 @@ class ImageGenerationInputs(PipelineInputs): # request scheduling. `ImageGenerationContext` should be used once # request scheduling is implemented. prompt: str + negative_prompt: str | None + true_cfg_scale: float height: int width: int num_inference_steps: int diff --git a/max/python/max/pipelines/lib/pipeline_variants/image_generation.py b/max/python/max/pipelines/lib/pipeline_variants/image_generation.py index bbb74c0cae3..5841263e36f 100644 --- a/max/python/max/pipelines/lib/pipeline_variants/image_generation.py +++ b/max/python/max/pipelines/lib/pipeline_variants/image_generation.py @@ -205,6 +205,8 @@ def download( def execute(self, inputs: ImageGenerationInputs) -> ImageGenerationOutput: outputs = self._diffusion_pipeline( prompt=inputs.prompt, + negative_prompt=inputs.negative_prompt, + true_cfg_scale=inputs.true_cfg_scale, height=inputs.height, width=inputs.width, num_inference_steps=inputs.num_inference_steps, From 2f99887560f82a75fabf74af7aba16276075f207 Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Fri, 16 Jan 2026 09:41:14 +0000 Subject: [PATCH 04/18] feat: add model definitions for components of Flux1 pipeline --- max/python/max/dtype/__init__.py | 1 + max/python/max/dtype/dtype_extension.py | 56 ++ max/python/max/nn/norm/group_norm.py | 5 +- max/python/max/nn/norm/layer_norm.py | 45 +- .../max/pipelines/architectures/__init__.py | 2 + .../architectures/autoencoder_kl/__init__.py | 14 + .../autoencoder_kl/autoencoder_kl.py | 750 ++++++++++++++++ .../autoencoder_kl/layers/__init__.py | 14 + .../autoencoder_kl/layers/upsampling.py | 168 ++++ .../architectures/autoencoder_kl/model.py | 74 ++ .../autoencoder_kl/model_config.py | 66 ++ .../pipelines/architectures/clip/__init__.py | 14 + .../max/pipelines/architectures/clip/clip.py | 517 +++++++++++ .../max/pipelines/architectures/clip/model.py | 70 ++ .../architectures/clip/model_config.py | 63 ++ .../pipelines/architectures/flux1/__init__.py | 14 + .../max/pipelines/architectures/flux1/arch.py | 38 + .../pipelines/architectures/flux1/flux1.py | 544 ++++++++++++ .../architectures/flux1/layers/__init__.py | 12 + .../architectures/flux1/layers/activations.py | 56 ++ .../architectures/flux1/layers/embeddings.py | 471 ++++++++++ .../flux1/layers/flux_attention.py | 474 ++++++++++ .../flux1/layers/normalizations.py | 254 ++++++ .../pipelines/architectures/flux1/model.py | 77 ++ .../architectures/flux1/model_config.py | 59 ++ .../architectures/flux1/weight_adapters.py | 30 + .../pipelines/architectures/t5/__init__.py | 14 + .../max/pipelines/architectures/t5/model.py | 64 ++ .../architectures/t5/model_config.py | 69 ++ .../max/pipelines/architectures/t5/t5.py | 823 ++++++++++++++++++ .../max/pipelines/lib/interfaces/max_model.py | 45 + 31 files changed, 4888 insertions(+), 15 deletions(-) create mode 100644 max/python/max/dtype/dtype_extension.py create mode 100644 max/python/max/pipelines/architectures/autoencoder_kl/__init__.py create mode 100644 max/python/max/pipelines/architectures/autoencoder_kl/autoencoder_kl.py create mode 100644 max/python/max/pipelines/architectures/autoencoder_kl/layers/__init__.py create mode 100644 max/python/max/pipelines/architectures/autoencoder_kl/layers/upsampling.py create mode 100644 max/python/max/pipelines/architectures/autoencoder_kl/model.py create mode 100644 max/python/max/pipelines/architectures/autoencoder_kl/model_config.py create mode 100644 max/python/max/pipelines/architectures/clip/__init__.py create mode 100644 max/python/max/pipelines/architectures/clip/clip.py create mode 100644 max/python/max/pipelines/architectures/clip/model.py create mode 100644 max/python/max/pipelines/architectures/clip/model_config.py create mode 100644 max/python/max/pipelines/architectures/flux1/__init__.py create mode 100644 max/python/max/pipelines/architectures/flux1/arch.py create mode 100644 max/python/max/pipelines/architectures/flux1/flux1.py create mode 100644 max/python/max/pipelines/architectures/flux1/layers/__init__.py create mode 100644 max/python/max/pipelines/architectures/flux1/layers/activations.py create mode 100644 max/python/max/pipelines/architectures/flux1/layers/embeddings.py create mode 100644 max/python/max/pipelines/architectures/flux1/layers/flux_attention.py create mode 100644 max/python/max/pipelines/architectures/flux1/layers/normalizations.py create mode 100644 max/python/max/pipelines/architectures/flux1/model.py create mode 100644 max/python/max/pipelines/architectures/flux1/model_config.py create mode 100644 max/python/max/pipelines/architectures/flux1/weight_adapters.py create mode 100644 max/python/max/pipelines/architectures/t5/__init__.py create mode 100644 max/python/max/pipelines/architectures/t5/model.py create mode 100644 max/python/max/pipelines/architectures/t5/model_config.py create mode 100644 max/python/max/pipelines/architectures/t5/t5.py create mode 100644 max/python/max/pipelines/lib/interfaces/max_model.py diff --git a/max/python/max/dtype/__init__.py b/max/python/max/dtype/__init__.py index 864514236b8..ad702907d8c 100644 --- a/max/python/max/dtype/__init__.py +++ b/max/python/max/dtype/__init__.py @@ -11,4 +11,5 @@ # limitations under the License. # ===----------------------------------------------------------------------=== # +from . import dtype_extension from .dtype import DType diff --git a/max/python/max/dtype/dtype_extension.py b/max/python/max/dtype/dtype_extension.py new file mode 100644 index 00000000000..fba3de7e83f --- /dev/null +++ b/max/python/max/dtype/dtype_extension.py @@ -0,0 +1,56 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Extension for max.dtype to support additional attributes.""" + +from numpy import finfo as np_finfo + +from .dtype import DType + + +class finfo: + """A numerical properties of a floating point max.dtype.DType. + + This class mimics torch.finfo behavior without torch dependency, + including support for bfloat16. + + NOTE: Currently, it's applied through patching. + This extension is better to be implemented in dtype library itself. + """ + + def __init__(self, dtype: DType): + """Initialize finfo for a given max.dtype.DType. + + Args: + dtype: The data type to get limits for. + """ + if dtype == DType.bfloat16: + self.min = -3.38953e38 + self.max = 3.38953e38 + self.bits = 16 + self.eps = 0.0078125 + self.resolution = 0.01 + self.tiny = 1.17549e-38 + self.dtype = "bfloat16" + else: + np_finfo_obj = np_finfo(dtype.to_numpy()) + self.min = float(np_finfo_obj.min) + self.max = float(np_finfo_obj.max) + self.bits = np_finfo_obj.bits + self.eps = float(np_finfo_obj.eps) + self.resolution = float(np_finfo_obj.resolution) + self.tiny = float(np_finfo_obj.tiny) + self.dtype = str(np_finfo_obj.dtype) + + +DType.finfo = finfo diff --git a/max/python/max/nn/norm/group_norm.py b/max/python/max/nn/norm/group_norm.py index f1241d1a9aa..7ffc38e3ea6 100644 --- a/max/python/max/nn/norm/group_norm.py +++ b/max/python/max/nn/norm/group_norm.py @@ -45,6 +45,7 @@ def __init__( eps: float = 1e-5, affine: bool = True, device: DeviceRef = DeviceRef.GPU(), + dtype: DType = DType.float32, ) -> None: super().__init__() self.num_groups = num_groups @@ -65,13 +66,13 @@ def __init__( self.weight = Weight( name="weight", shape=(self.num_channels,), - dtype=DType.float32, + dtype=dtype, device=device, ) self.bias = Weight( name="bias", shape=(self.num_channels,), - dtype=DType.float32, + dtype=dtype, device=device, ) diff --git a/max/python/max/nn/norm/layer_norm.py b/max/python/max/nn/norm/layer_norm.py index 47a328c0faa..3e247bf7689 100644 --- a/max/python/max/nn/norm/layer_norm.py +++ b/max/python/max/nn/norm/layer_norm.py @@ -36,37 +36,56 @@ def __init__( dtype: DType, eps: float = 1e-5, use_bias: bool = True, + keep_dtype: bool = False, + elementwise_affine: bool = True, ) -> None: super().__init__() self.devices = devices - self.weight = Weight("weight", dtype, (dims,), device=self.devices[0]) - self.bias = ( - Weight("bias", dtype, (dims,), device=self.devices[0]) - if use_bias - else None - ) + if elementwise_affine: + self.weight = Weight( + "weight", dtype, (dims,), device=self.devices[0] + ) + self.bias = ( + Weight("bias", dtype, (dims,), device=self.devices[0]) + if use_bias + else None + ) + else: + self.weight = None + self.bias = None self.eps = eps self.dim = dims self.dtype = dtype + self.keep_dtype = keep_dtype self._sharding_strategy: ShardingStrategy | None = None def __call__(self, input: TensorValue): # TODO: AIPIPE-95 Replace with a broadcasting rmo.layer_norm bias = ( - ops.cast(self.bias, DType.float32) + self.bias if self.bias # If bias wasn't passed then use bias-less layer norm (beta = 0). else ops.broadcast_to( - ops.constant(0.0, DType.float32, self.weight.device), + ops.constant(0.0, self.dtype, input.device), + shape=(input.shape[-1],), + ) + ) + gamma = ( + self.weight + if self.weight + else ops.broadcast_to( + ops.constant(1.0, self.dtype, input.device), shape=(input.shape[-1],), ) ) - return ops.layer_norm( - input.cast(DType.float32), - gamma=ops.cast(self.weight, DType.float32), - beta=bias, + + output = ops.layer_norm( + input=input if self.keep_dtype else input.cast(DType.float32), + gamma=gamma if self.keep_dtype else ops.cast(gamma, DType.float32), + beta=bias if self.keep_dtype else ops.cast(bias, DType.float32), epsilon=self.eps, - ).cast(input.dtype) + ) + return output if self.keep_dtype else output.cast(input.dtype) @property def sharding_strategy(self) -> ShardingStrategy | None: diff --git a/max/python/max/pipelines/architectures/__init__.py b/max/python/max/pipelines/architectures/__init__.py index 3745c11317b..cdc222baa83 100644 --- a/max/python/max/pipelines/architectures/__init__.py +++ b/max/python/max/pipelines/architectures/__init__.py @@ -28,6 +28,7 @@ def register_all_models() -> None: from .deepseekV3 import deepseekV3_arch from .eagle_llama3 import eagle_llama_arch from .exaone import exaone_arch + from .flux1 import flux1_arch from .gemma3 import gemma3_arch from .gemma3multimodal import gemma3_multimodal_arch from .gpt_oss import gpt_oss_arch @@ -54,6 +55,7 @@ def register_all_models() -> None: deepseekV2_arch, deepseekV3_arch, eagle_llama_arch, + flux1_arch, gemma3_arch, gemma3_multimodal_arch, granite_arch, diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/__init__.py b/max/python/max/pipelines/architectures/autoencoder_kl/__init__.py new file mode 100644 index 00000000000..e18af050cf2 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .model import AutoencoderKLModel diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/autoencoder_kl.py b/max/python/max/pipelines/architectures/autoencoder_kl/autoencoder_kl.py new file mode 100644 index 00000000000..ef80846d949 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/autoencoder_kl.py @@ -0,0 +1,750 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass + +import max.nn as nn +from max.dtype import DType +from max.graph import DeviceRef, TensorType, TensorValue, ops +from max.nn import GroupNorm +from max.nn.layer.layer_list import LayerList + +from .layers import Upsample2D +from .model_config import AutoencoderKLConfig + + +class ResnetBlock2D(nn.Module): + """Residual block for 2D VAE decoder. + + This module implements a residual block with two convolutional layers, + group normalization, and optional shortcut connection. It supports + time embedding conditioning and configurable activation functions. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int | None, + groups: int, + groups_out: int, + eps: float = 1e-6, + non_linearity: str = "silu", + use_conv_shortcut: bool = False, + conv_shortcut_bias: bool = True, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize ResnetBlock2D module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + temb_channels: Number of time embedding channels (None if not used). + groups: Number of groups for first GroupNorm. + groups_out: Number of groups for second GroupNorm. + eps: Epsilon value for GroupNorm layers. + non_linearity: Activation function name (e.g., "silu"). + use_conv_shortcut: Whether to use convolutional shortcut. + conv_shortcut_bias: Whether to use bias in shortcut convolution. + device: Device reference for module placement. + dtype: Data type for module parameters. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.use_conv_shortcut = use_conv_shortcut + + self.norm1 = GroupNorm( + num_groups=groups, + num_channels=in_channels, + eps=eps, + affine=True, + device=device, + dtype=dtype, + ) + + self.conv1 = nn.Conv2d( + kernel_size=3, + in_channels=in_channels, + out_channels=out_channels, + dtype=dtype, + stride=1, + padding=1, + dilation=1, + num_groups=1, + has_bias=True, + device=device, + permute=True, + ) + + self.norm2 = GroupNorm( + num_groups=groups_out, + num_channels=out_channels, + eps=eps, + affine=True, + device=device, + dtype=dtype, + ) + + self.conv2 = nn.Conv2d( + kernel_size=3, + in_channels=out_channels, + out_channels=out_channels, + dtype=dtype, + stride=1, + padding=1, + dilation=1, + num_groups=1, + has_bias=True, + device=device, + permute=True, + ) + + self.conv_shortcut = None + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d( + kernel_size=1, + in_channels=in_channels, + out_channels=out_channels, + dtype=dtype, + stride=1, + padding=0, + dilation=1, + num_groups=1, + has_bias=conv_shortcut_bias, + device=device, + permute=True, + ) + elif in_channels != out_channels: + self.conv_shortcut = nn.Conv2d( + kernel_size=1, + in_channels=in_channels, + out_channels=out_channels, + dtype=dtype, + stride=1, + padding=0, + dilation=1, + num_groups=1, + has_bias=conv_shortcut_bias, + device=device, + permute=True, + ) + + def __call__( + self, x: TensorValue, temb: TensorValue | None = None + ) -> TensorValue: + """Apply ResnetBlock2D forward pass. + + Args: + x: Input tensor of shape [N, C, H, W]. + temb: Optional time embedding tensor (currently unused). + + Returns: + Output tensor of shape [N, C_out, H, W] with residual connection. + """ + shortcut = ( + self.conv_shortcut(x) if self.conv_shortcut is not None else x + ) + + h = ops.silu(self.norm1(x)) + h = self.conv1(h) + + h = ops.silu(self.norm2(h)) + h = self.conv2(h) + + return h + shortcut + + +class UpDecoderBlock2D(nn.Module): + """Upsampling decoder block for 2D VAE. + + This module consists of multiple ResNet blocks followed by an optional + upsampling layer. It progressively increases spatial resolution while + processing features through residual connections. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: int | None = None, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize UpDecoderBlock2D module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + resolution_idx: Optional resolution index for tracking. + dropout: Dropout rate (currently unused). + num_layers: Number of ResNet blocks in this decoder block. + resnet_eps: Epsilon value for ResNet GroupNorm layers. + resnet_time_scale_shift: Time embedding scale/shift mode. + resnet_act_fn: Activation function for ResNet blocks. + resnet_groups: Number of groups for ResNet GroupNorm. + resnet_pre_norm: Whether to apply normalization before ResNet. + output_scale_factor: Scaling factor for output (currently unused). + add_upsample: Whether to add upsampling layer after ResNet blocks. + temb_channels: Number of time embedding channels (None if not used). + device: Device reference for module placement. + dtype: Data type for module parameters. + """ + super().__init__() + resnets_list = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnet = ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=resnet_groups, + groups_out=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + use_conv_shortcut=False, + conv_shortcut_bias=True, + device=device, + dtype=dtype, + ) + resnets_list.append(resnet) + self.resnets = LayerList(resnets_list) + + if add_upsample: + upsampler = Upsample2D( + channels=out_channels, + use_conv=True, + out_channels=out_channels, + name="conv", + kernel_size=3, + padding=1, + bias=True, + interpolate=True, + device=device, + dtype=dtype, + ) + self.upsamplers = LayerList([upsampler]) + else: + self.upsamplers = None + + def __call__( + self, hidden_states: TensorValue, temb: TensorValue | None = None + ) -> TensorValue: + """Apply UpDecoderBlock2D forward pass. + + Args: + hidden_states: Input tensor of shape [N, C_in, H, W]. + temb: Optional time embedding tensor. + + Returns: + Output tensor of shape [N, C_out, H*2, W*2] (if upsampling) or + [N, C_out, H, W] (if no upsampling). + """ + # Process through all resnet blocks + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + + # Apply upsampling if configured (compile-time decision) + if self.upsamplers is not None: + hidden_states = self.upsamplers[0](hidden_states) + + return hidden_states + + +class VAEAttention(nn.Module): + """Spatial attention module for VAE models. + + This module performs self-attention on 2D spatial features by: + 1. Converting [N, C, H, W] to [N, H*W, C] sequence format + 2. Applying scaled dot-product attention (optimized for small sequences) + 3. Converting back to [N, C, H, W] format + + Note: Manual attention is used instead of flash_attention_gpu because + VAE attention typically has small sequence lengths (H*W) where flash + attention overhead outweighs benefits. + """ + + def __init__( + self, + query_dim: int, + heads: int, + dim_head: int, + num_groups: int = 32, + eps: float = 1e-6, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize VAE attention module. + + Args: + query_dim: Dimension of query (number of channels). + heads: Number of attention heads. + dim_head: Dimension of each attention head. + num_groups: Number of groups for GroupNorm. + eps: Epsilon value for GroupNorm. + device: Device reference. + dtype: Data type. + """ + super().__init__() + self.query_dim = query_dim + self.heads = heads + self.dim_head = dim_head + self.inner_dim = heads * dim_head + + self.group_norm = GroupNorm( + num_groups=num_groups, + num_channels=query_dim, + eps=eps, + affine=True, + device=device, + dtype=dtype, + ) + + self.to_q = nn.Linear( + query_dim, self.inner_dim, has_bias=True, device=device, dtype=dtype + ) + self.to_k = nn.Linear( + query_dim, self.inner_dim, has_bias=True, device=device, dtype=dtype + ) + self.to_v = nn.Linear( + query_dim, self.inner_dim, has_bias=True, device=device, dtype=dtype + ) + self.to_out = LayerList( + [ + nn.Linear( + self.inner_dim, + query_dim, + has_bias=True, + device=device, + dtype=dtype, + ) + ] + ) + + self.scale = 1.0 / math.sqrt(dim_head) + + def __call__(self, x: TensorValue) -> TensorValue: + """Apply spatial attention to 2D image tensor. + + Args: + x: Input tensor of shape [N, C, H, W]. + + Returns: + Output tensor of shape [N, C, H, W] with residual connection. + """ + residual = x + + x = self.group_norm(x) + + n, c, h, w = x.shape + seq_len = h * w + + x = ops.reshape(x, (n, c, seq_len)) + x = ops.permute(x, (0, 2, 1)) + + q = self.to_q(x) + k = self.to_k(x) + v = self.to_v(x) + + q = ops.reshape(q, (n, seq_len, self.heads, self.dim_head)) + q = ops.permute(q, (0, 2, 1, 3)) + k = ops.reshape(k, (n, seq_len, self.heads, self.dim_head)) + k = ops.permute(k, (0, 2, 1, 3)) + v = ops.reshape(v, (n, seq_len, self.heads, self.dim_head)) + v = ops.permute(v, (0, 2, 1, 3)) + + attn = q @ ops.permute(k, (0, 1, 3, 2)) * self.scale + attn = ops.softmax(attn, axis=-1) + out = attn @ v + + out = ops.permute(out, (0, 2, 1, 3)) + out = ops.reshape(out, (n, seq_len, self.inner_dim)) + + out = self.to_out[0](out) + + out = ops.permute(out, (0, 2, 1)) + out = ops.reshape(out, (n, c, h, w)) + + return residual + out + + +class MidBlock2D(nn.Module): + """Internal MAX module for MidBlock2D graph generation.""" + + def __init__( + self, + in_channels: int, + temb_channels: int | None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize MidBlock2D module.""" + super().__init__() + resnets_list = [] + attentions_list = [] + + resnet = ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + groups=resnet_groups, + groups_out=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + use_conv_shortcut=False, + conv_shortcut_bias=True, + device=device, + dtype=dtype, + ) + resnets_list.append(resnet) + + for _i in range(num_layers): + if add_attention: + attn = VAEAttention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + num_groups=resnet_groups, + eps=resnet_eps, + device=device, + dtype=dtype, + ) + attentions_list.append(attn) + else: + attentions_list.append(None) + + resnet = ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + groups=resnet_groups, + groups_out=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + use_conv_shortcut=False, + conv_shortcut_bias=True, + device=device, + dtype=dtype, + ) + resnets_list.append(resnet) + + self.resnets = LayerList(resnets_list) + self.attentions = ( + LayerList(attentions_list) if attentions_list else None + ) + + def __call__( + self, hidden_states: TensorValue, temb: TensorValue | None = None + ) -> TensorValue: + """Apply MidBlock2D forward pass. + + Args: + hidden_states: Input tensor of shape [N, C, H, W]. + temb: Optional time embedding tensor. + + Returns: + Output tensor of shape [N, C, H, W] with same spatial dimensions. + """ + hidden_states = self.resnets[0](hidden_states, temb) + + for i in range(len(self.resnets) - 1): + if self.attentions is not None and self.attentions[i] is not None: + hidden_states = self.attentions[i](hidden_states) + hidden_states = self.resnets[i + 1](hidden_states, temb) + + return hidden_states + + +@dataclass +class DecoderOutput: + r"""Output of decoding method. + + Args: + sample (`TensorValue` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: TensorValue + commit_loss: TensorValue | None = None + + +class Decoder(nn.Module): + """VAE decoder for generating images from latent representations. + + This decoder progressively upsamples latent features through multiple + decoder blocks, applying ResNet layers and attention mechanisms to + reconstruct high-resolution images from compressed latent codes. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", + mid_block_add_attention: bool = True, + use_post_quant_conv: bool = True, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize Decoder module. + + Args: + in_channels: Number of input channels (latent channels). + out_channels: Number of output channels (image channels). + up_block_types: Tuple of upsampling block types. + block_out_channels: Tuple of channel counts for each decoder block. + layers_per_block: Number of ResNet layers per decoder block. + norm_num_groups: Number of groups for GroupNorm layers. + act_fn: Activation function name (e.g., "silu"). + norm_type: Normalization type ("group" or "spatial"). + mid_block_add_attention: Whether to add attention in middle block. + use_post_quant_conv: Whether to use post-quantization convolution. + device: Device reference for module placement. + dtype: Data type for module parameters. + """ + super().__init__() + self.layers_per_block = layers_per_block + self.session = None + self.in_channels = in_channels + self.device = device + self.dtype = dtype + + self.post_quant_conv = None + if use_post_quant_conv: + self.post_quant_conv = nn.Conv2d( + kernel_size=1, + in_channels=in_channels, + out_channels=in_channels, + dtype=dtype, + stride=1, + padding=0, + dilation=1, + num_groups=1, + has_bias=True, + device=device, + permute=True, + ) + + self.conv_in = nn.Conv2d( + kernel_size=3, + in_channels=in_channels, + out_channels=block_out_channels[-1], + dtype=dtype, + stride=1, + padding=1, + dilation=1, + num_groups=1, + has_bias=True, + device=device, + permute=True, + ) + + temb_channels = in_channels if norm_type == "spatial" else None + self.mid_block = MidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=temb_channels, + dropout=0.0, + num_layers=1, + resnet_eps=1e-6, + resnet_time_scale_shift=( + "default" if norm_type == "group" else norm_type + ), + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + add_attention=mid_block_add_attention, + attention_head_dim=block_out_channels[-1], + output_scale_factor=1.0, + device=device, + dtype=dtype, + ) + + up_blocks_list = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "UpDecoderBlock2D": + up_block = UpDecoderBlock2D( + in_channels=prev_output_channel, + out_channels=output_channel, + resolution_idx=i, + dropout=0.0, + num_layers=self.layers_per_block + 1, + resnet_eps=1e-6, + resnet_time_scale_shift=norm_type, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + output_scale_factor=1.0, + add_upsample=not is_final_block, + temb_channels=temb_channels, + device=device, + dtype=dtype, + ) + up_blocks_list.append(up_block) + else: + raise ValueError(f"Unsupported up_block_type: {up_block_type}") + + prev_output_channel = output_channel + + self.up_blocks = LayerList(up_blocks_list) + + if norm_type == "spatial": + raise NotImplementedError("SpatialNorm not implemented in MAX VAE") + else: + self.conv_norm_out = GroupNorm( + num_groups=norm_num_groups, + num_channels=block_out_channels[0], + eps=1e-6, + affine=True, + device=device, + dtype=dtype, + ) + + self.conv_out = nn.Conv2d( + kernel_size=3, + in_channels=block_out_channels[0], + out_channels=out_channels, + dtype=dtype, + stride=1, + padding=1, + dilation=1, + num_groups=1, + has_bias=True, + device=device, + permute=True, + ) + + def __call__( + self, z: TensorValue, temb: TensorValue | None = None + ) -> TensorValue: + """Apply Decoder forward pass. + + Args: + z: Input latent tensor of shape [N, C_latent, H_latent, W_latent]. + temb: Optional time embedding tensor. + + Returns: + Decoded image tensor of shape [N, C_out, H, W] where H and W are + upsampled from H_latent and W_latent. + """ + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + sample = self.conv_in(z) + sample = self.mid_block(sample, temb) + + for up_block in self.up_blocks: + sample = up_block(sample, temb) + + sample = self.conv_norm_out(sample) + sample = ops.silu(sample) + sample = self.conv_out(sample) + + return sample + + def input_types(self) -> tuple[TensorType, ...]: + """Define input tensor types for the decoder model. + + Returns: + Tuple of TensorType specifications for decoder input. + """ + latent_type = TensorType( + self.dtype, + shape=[ + "batch_size", + self.in_channels, + "latent_height", + "latent_width", + ], + device=self.device, + ) + + return (latent_type,) + + +class AutoencoderKL(nn.Module): + r"""A VAE model with KL loss for encoding images into latents and decoding latent representations into images.""" + + def __init__( + self, + config: AutoencoderKLConfig, + ): + """Initialize VAE AutoencoderKL model. + + Args: + config: Autoencoder configuration containing channel sizes, block + structure, normalization settings, and device/dtype information. + """ + super().__init__() + self.decoder = Decoder( + in_channels=config.latent_channels, + out_channels=config.out_channels, + up_block_types=config.up_block_types, + block_out_channels=config.block_out_channels, + layers_per_block=config.layers_per_block, + norm_num_groups=config.norm_num_groups, + act_fn=config.act_fn, + norm_type="group", + mid_block_add_attention=config.mid_block_add_attention, + use_post_quant_conv=config.use_post_quant_conv, + device=config.device, + dtype=config.dtype, + ) + + def __call__(self, *args, **kwargs): + pass diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/layers/__init__.py b/max/python/max/pipelines/architectures/autoencoder_kl/layers/__init__.py new file mode 100644 index 00000000000..ba1d3f82854 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/layers/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .upsampling import Upsample2D diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/layers/upsampling.py b/max/python/max/pipelines/architectures/autoencoder_kl/layers/upsampling.py new file mode 100644 index 00000000000..35a526b32b0 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/layers/upsampling.py @@ -0,0 +1,168 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Upsampling utilities for MAX framework.""" + +import max.nn as nn +from max.dtype import DType +from max.experimental import tensor +from max.graph import DeviceRef, TensorValue, ops + + +class Interpolate2DNearest(nn.Module): + """2D nearest-neighbor upsampling module. + + This is a workaround implementation because MAX framework does not have + a native `interpolate` operation. The workaround uses reshape and broadcast + operations to achieve nearest-neighbor upsampling by a factor of 2. + + Note: + This workaround can be removed once MAX framework adds native interpolate support. + """ + + def __init__( + self, + scale_factor: int = 2, + device: DeviceRef = None, + dtype: DType = None, + ): + """Initialize 2D nearest-neighbor interpolation module. + + Args: + scale_factor: Upsampling factor (currently only 2 is supported). + device: Device reference for creating intermediate tensors. + dtype: Data type for intermediate tensors. + """ + super().__init__() + if scale_factor != 2: + raise NotImplementedError( + f"Only scale_factor=2 is currently supported, got {scale_factor}" + ) + + self.scale_factor = scale_factor + self.device = device + self.dtype = dtype + + def __call__(self, x: TensorValue) -> TensorValue: + """Upsample a 2D tensor using nearest-neighbor interpolation. + + Args: + x: Input tensor of shape [N, C, H, W]. + + Returns: + Upsampled tensor of shape [N, C, H*scale_factor, W*scale_factor]. + """ + n, c, h, w = x.shape + target_shape = [n, c, h * self.scale_factor, w * self.scale_factor] + + x_reshaped = ops.reshape(x, [n, c, h, 1, w, 1]) + + ones = tensor.Tensor.ones( + shape=(1, 1, 1, self.scale_factor, 1, self.scale_factor), + dtype=self.dtype, + device=self.device, + ) + x_expanded = x_reshaped * ones + + x = ops.reshape(x_expanded, target_shape) + + return x + + +class Upsample2D(nn.Module): + """2D upsampling module with optional convolution. + + This module performs 2D upsampling using nearest-neighbor interpolation + (via Interpolate2DNearest workaround) followed by an optional convolution layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: int | None = None, + name: str = "conv", + kernel_size: int | None = None, + padding: int = 1, + bias: bool = True, + interpolate: bool = True, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize 2D upsampling module. + + Args: + channels: Number of input channels. + use_conv: Whether to apply a convolution after upsampling. + use_conv_transpose: Whether to use transposed convolution (not supported yet). + out_channels: Number of output channels. If None, uses channels. + name: Name for the convolution layer (unused, kept for compatibility). + kernel_size: Kernel size for the convolution. + padding: Padding for the convolution. + bias: Whether to use bias in the convolution. + interpolate: Whether to perform interpolation upsampling. + device: Device reference. + dtype: Data type. + """ + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.interpolate = interpolate + self.device = device + self.dtype = dtype + + self.interpolate_module = None + if interpolate: + self.interpolate_module = Interpolate2DNearest( + scale_factor=2, device=device, dtype=dtype + ) + + self.conv = None + if use_conv_transpose: + raise NotImplementedError( + "Upsample2D does not support use_conv_transpose=True yet." + ) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + self.conv = nn.Conv2d( + kernel_size=kernel_size, + in_channels=self.channels, + out_channels=self.out_channels, + dtype=dtype, + stride=1, + padding=padding, + has_bias=bias, + device=device, + permute=True, + ) + + def __call__(self, x: TensorValue) -> TensorValue: + """Apply 2D upsampling with optional convolution. + + Args: + x: Input tensor of shape [N, C, H, W]. + + Returns: + Upsampled tensor, optionally convolved. + """ + if self.interpolate_module is not None: + x = self.interpolate_module(x) + + if self.use_conv: + x = self.conv(x) + + return x diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/model.py b/max/python/max/pipelines/architectures/autoencoder_kl/model.py new file mode 100644 index 00000000000..04ba6006d3d --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/model.py @@ -0,0 +1,74 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.driver import CPU, Accelerator, Device, Tensor +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.max_model import MaxModel + +from .autoencoder_kl import AutoencoderKL +from .model_config import AutoencoderKLConfig + + +class AutoencoderKLModel(MaxModel): + config_name = AutoencoderKLConfig.config_name + + def __init__( + self, + config: dict, + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + ) -> None: + super().__init__(config, encoding, devices, weights) + self.config = AutoencoderKLConfig.generate( + config, + encoding, + devices, + ) + self.load_model() + + def load_model(self) -> None: + autoencoder_kl = AutoencoderKL(self.config) + + if self.config.device.is_cpu(): + session = InferenceSession([CPU()]) + else: + session = InferenceSession([Accelerator()]) + + self.load_decoder(session, autoencoder_kl) + + def load_decoder( + self, session: InferenceSession, autoencoder_kl: AutoencoderKL + ) -> Model: + state_dict = { + key: value.data() + for key, value in self.weights.items() + if not key.startswith("encoder.") + } + autoencoder_kl.load_state_dict(state_dict) + with Graph( + "autoencoder_kl_decoder", + input_types=autoencoder_kl.decoder.input_types(), + ) as graph: + outputs = autoencoder_kl.decoder(*graph.inputs) + graph.output(outputs) + compiled_graph = graph + self.decode_session = session.load( + compiled_graph, weights_registry=autoencoder_kl.state_dict() + ) + + def decode(self, *args, **kwargs) -> list[Tensor]: + return self.decode_session.execute(*args, **kwargs) diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/model_config.py b/max/python/max/pipelines/architectures/autoencoder_kl/model_config.py new file mode 100644 index 00000000000..f3b28565fc4 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/model_config.py @@ -0,0 +1,66 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from typing import ClassVar + +from max.driver import Device +from max.dtype import DType +from max.graph import DeviceRef +from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding +from pydantic import Field + + +class AutoencoderKLConfigBase(MAXModelConfigBase): + in_channels: int = 3 + out_channels: int = 3 + down_block_types: list[str] = Field(default_factory=list, max_length=4) + up_block_types: list[str] = Field(default_factory=list, max_length=4) + block_out_channels: list[int] = Field(default_factory=list, max_length=4) + layers_per_block: int = 1 + act_fn: str = "silu" + latent_channels: int = 4 + norm_num_groups: int = 32 + sample_size: int = 32 + scaling_factor: float = 0.18215 + shift_factor: float | None = None + latents_mean: tuple[float] | None = None + latents_std: tuple[float] | None = None + force_upcast: bool = True + use_quant_conv: bool = True + use_post_quant_conv: bool = True + mid_block_add_attention: bool = True + device: DeviceRef = Field(default_factory=DeviceRef.CPU) + dtype: DType = DType.bfloat16 + + +class AutoencoderKLConfig(AutoencoderKLConfigBase): + config_name: ClassVar[str] = "config.json" + + @staticmethod + def generate( + config_dict: dict, + encoding: SupportedEncoding, + devices: list[Device], + ) -> AutoencoderKLConfigBase: + init_dict = { + key: value + for key, value in config_dict.items() + if key in AutoencoderKLConfigBase.__annotations__ + } + init_dict.update( + { + "dtype": encoding.dtype, + "device": DeviceRef.from_device(devices[0]), + } + ) + return AutoencoderKLConfigBase(**init_dict) diff --git a/max/python/max/pipelines/architectures/clip/__init__.py b/max/python/max/pipelines/architectures/clip/__init__.py new file mode 100644 index 00000000000..32cb1def84e --- /dev/null +++ b/max/python/max/pipelines/architectures/clip/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .model import ClipModel diff --git a/max/python/max/pipelines/architectures/clip/clip.py b/max/python/max/pipelines/architectures/clip/clip.py new file mode 100644 index 00000000000..95642bfbd93 --- /dev/null +++ b/max/python/max/pipelines/architectures/clip/clip.py @@ -0,0 +1,517 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from functools import partial + +import max.nn as nn +from max.dtype import DType +from max.graph import TensorType, TensorValue, ops +from max.nn import LayerNorm, Module + +from .model_config import ClipConfig + + +class CLIPTextEmbeddings(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP text embeddings. + + Args: + config: CLIP configuration for embedding dimensions and device/dtype. + """ + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.position_embedding = nn.Embedding( + config.max_position_embeddings, + self.embed_dim, + device=config.device, + dtype=config.dtype, + ) + self.token_embedding = nn.Embedding( + config.vocab_size, + self.embed_dim, + device=config.device, + dtype=config.dtype, + ) + + def __call__( + self, + input_ids: TensorValue | None = None, + position_ids: TensorValue | None = None, + inputs_embeds: TensorValue | None = None, + ) -> TensorValue: + """Apply embeddings to input tokens. + + Args: + input_ids: Input token IDs. + position_ids: Position IDs. + inputs_embeds: Pre-computed input embeddings. + + Returns: + Combined embeddings. + """ + if input_ids is None and inputs_embeds is None: + raise ValueError( + "You have to specify either input_ids or inputs_embeds" + ) + + if input_ids is not None: + seq_length = input_ids.shape[-1] + else: + seq_length = inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = ops.range( + 0, + seq_length, + step=1, + dtype=DType.int32, + device=self.config.device, + ) + position_ids = ops.unsqueeze(position_ids, 0) + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class CLIPAttention(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP attention module. + + Args: + config: CLIP configuration for attention dimensions and device/dtype. + """ + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear( + self.embed_dim, + self.embed_dim, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + self.v_proj = nn.Linear( + self.embed_dim, + self.embed_dim, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + self.q_proj = nn.Linear( + self.embed_dim, + self.embed_dim, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + self.out_proj = nn.Linear( + self.embed_dim, + self.embed_dim, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + + def __call__( + self, + hidden_states: TensorValue, + attention_mask: TensorValue | None = None, + causal_attention_mask: TensorValue | None = None, + ) -> TensorValue: + """Apply multi-head attention. + + Args: + hidden_states: Input hidden states. + attention_mask: Attention mask. + causal_attention_mask: Causal attention mask. + + Returns: + Attention output. + """ + batch_size, seq_length, embed_dim = hidden_states.shape + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = ops.reshape( + query, (batch_size, seq_length, self.num_heads, self.head_dim) + ) + query = ops.transpose(query, 1, 2) + + key = ops.reshape( + key, (batch_size, seq_length, self.num_heads, self.head_dim) + ) + key = ops.transpose(key, 1, 2) + + value = ops.reshape( + value, (batch_size, seq_length, self.num_heads, self.head_dim) + ) + value = ops.transpose(value, 1, 2) + + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask + + attn_weights = ( + ops.matmul(query, ops.transpose(key, -1, -2)) * self.scale + ) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = ops.softmax( + ops.cast(attn_weights, DType.float32), axis=-1 + ) + attn_weights = ops.cast(attn_weights, hidden_states.dtype) + + attn_output = ops.matmul(attn_weights, value) + attn_output = ops.transpose(attn_output, 1, 2) + attn_output = ops.reshape( + attn_output, (batch_size, seq_length, embed_dim) + ) + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class CLIPMLP(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP MLP. + + Args: + config: CLIP configuration for MLP dimensions and device/dtype. + """ + super().__init__() + self.config = config + self.fc1 = nn.Linear( + config.hidden_size, + config.intermediate_size, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + self.fc2 = nn.Linear( + config.intermediate_size, + config.hidden_size, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + self.act_fn = partial(ops.gelu, approximate="quick") + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Apply MLP block. + + Args: + hidden_states: Input hidden states. + + Returns: + Output hidden states. + """ + hidden_states = self.fc1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class CLIPEncoderLayer(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP encoder layer. + + Args: + config: CLIP configuration for encoder layer structure. + """ + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = LayerNorm( + self.embed_dim, + eps=config.layer_norm_eps, + devices=[config.device], + dtype=config.dtype, + keep_dtype=True, + ) + self.mlp = CLIPMLP(config) + self.layer_norm2 = LayerNorm( + self.embed_dim, + eps=config.layer_norm_eps, + devices=[config.device], + dtype=config.dtype, + keep_dtype=True, + ) + + def __call__( + self, + hidden_states: TensorValue, + attention_mask: TensorValue, + causal_attention_mask: TensorValue, + ) -> TensorValue: + """Apply encoder layer. + + Args: + hidden_states: Input hidden states. + attention_mask: Attention mask. + causal_attention_mask: Causal attention mask. + + Returns: + Output hidden states. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPEncoder(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP encoder. + + Args: + config: CLIP configuration for encoder depth and dimensions. + """ + super().__init__() + self.layers = nn.LayerList( + [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + def __call__( + self, + inputs_embeds: TensorValue, + attention_mask: TensorValue | None = None, + causal_attention_mask: TensorValue | None = None, + ) -> TensorValue: + """Apply encoder (stack of layers). + + Args: + inputs_embeds: Input embeddings. + attention_mask: Attention mask. + causal_attention_mask: Causal attention mask. + + Returns: + Encoded hidden states. + """ + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + return hidden_states + + +class CLIPTextTransformer(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP text transformer. + + Args: + config: CLIP configuration for embeddings, encoder, and device/dtype. + """ + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = LayerNorm( + self.embed_dim, + eps=config.layer_norm_eps, + devices=[config.device], + dtype=config.dtype, + keep_dtype=True, + ) + self.eos_token_id = config.eos_token_id + + def _create_causal_mask(self, input_shape: tuple[int, int]) -> TensorValue: + """Create causal mask for the transformer. + + Args: + input_shape: Shape of the input tensor. + + Returns: + Causal mask tensor. + """ + _, seq_length = input_shape + + rows = ops.range( + 0, seq_length, step=1, dtype=DType.int32, device=self.config.device + ) + rows = ops.unsqueeze(rows, 1) + cols = ops.range( + 0, seq_length, step=1, dtype=DType.int32, device=self.config.device + ) + cols = ops.unsqueeze(cols, 0) + mask = ops.greater(cols, rows) + mask_float = mask.cast(self.config.dtype) + + min_val = DType.finfo(self.config.dtype).min + + causal_mask = mask_float * min_val + causal_mask = ops.unsqueeze(causal_mask, 0) + causal_mask = ops.unsqueeze(causal_mask, 1) + return causal_mask + + def __call__( + self, + input_ids: TensorValue | None = None, + attention_mask: TensorValue | None = None, + position_ids: TensorValue | None = None, + ) -> TensorValue: + """Apply text transformer. + + Args: + input_ids: Input token IDs. + attention_mask: Attention mask. + position_ids: Position IDs. + + Returns: + Tuple of (last_hidden_state, pooled_output). + """ + if input_ids is None: + raise ValueError("You have to specify input_ids") + + hidden_states = self.embeddings( + input_ids=input_ids, position_ids=position_ids + ) + + input_shape = input_ids.shape + causal_attention_mask = self._create_causal_mask(input_shape) + + if attention_mask is not None: + inverted_mask = ( + 1.0 - attention_mask.cast(hidden_states.dtype) + ) * DType.finfo(hidden_states.dtype).min + attention_mask = ops.unsqueeze(inverted_mask, 1) + attention_mask = ops.unsqueeze(attention_mask, 1) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + + last_hidden_state = self.final_layer_norm(encoder_outputs) + + if self.eos_token_id == 2: + eos_token_indices = ops.argmax(input_ids, axis=-1).cast(DType.int32) + else: + eos_token_indices = ops.argmax( + ops.equal(input_ids, self.eos_token_id).cast(DType.int32), + axis=-1, + ).cast(DType.int32) + + pooled_output = ops.gather_nd( + last_hidden_state, eos_token_indices, batch_dims=1 + ) + + return last_hidden_state, pooled_output + + +class CLIPTextModel(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP text model with MAX. + + Args: + config: CLIP configuration for vocabulary size, dimensions, and + device/dtype settings. + """ + super().__init__() + self.text_model = CLIPTextTransformer(config) + self.device = config.device + + def input_types(self) -> tuple[TensorType, ...]: + """Define input tensor types for the model. + + Returns: + Tuple of TensorType specifications for model inputs. + """ + return ( + TensorType( + DType.int64, + shape=["batch_size", "sequence_length"], + device=self.device, + ), + ) + + def __call__( + self, + input_ids: TensorValue | None = None, + attention_mask: TensorValue | None = None, + position_ids: TensorValue | None = None, + ) -> tuple[TensorValue, TensorValue]: + """Apply CLIP text model forward pass. + + Args: + input_ids: Input token IDs. + attention_mask: Attention mask. + position_ids: Position IDs. + + Returns: + Tuple of (last_hidden_state, pooled_output). + """ + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) diff --git a/max/python/max/pipelines/architectures/clip/model.py b/max/python/max/pipelines/architectures/clip/model.py new file mode 100644 index 00000000000..df9f467591b --- /dev/null +++ b/max/python/max/pipelines/architectures/clip/model.py @@ -0,0 +1,70 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.driver import CPU, Accelerator, Device +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.max_model import MaxModel + +from .clip import CLIPTextModel +from .model_config import ClipConfig + + +class ClipModel(MaxModel): + config_name = ClipConfig.config_name + + def __init__( + self, + config: dict, + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + ) -> None: + super().__init__( + config, + encoding, + devices, + weights, + ) + self.config = ClipConfig.generate( + config, + encoding, + devices, + ) + self.load_model() + + def load_model(self) -> Model: + clip = CLIPTextModel(self.config) + + if self.config.device.is_cpu(): + session = InferenceSession([CPU()]) + else: + session = InferenceSession([Accelerator()]) + state_dict = {key: value.data() for key, value in self.weights.items()} + clip.load_state_dict(state_dict) + with Graph("clip_text_model", input_types=clip.input_types()) as graph: + outputs = clip( + *graph.inputs, + attention_mask=None, + position_ids=None, + ) + graph.output(*outputs) + compiled_graph = graph + self.session = session.load( + compiled_graph, weights_registry=clip.state_dict() + ) + + def __call__(self, *args, **kwargs): + return self.session.execute(*args, **kwargs) diff --git a/max/python/max/pipelines/architectures/clip/model_config.py b/max/python/max/pipelines/architectures/clip/model_config.py new file mode 100644 index 00000000000..ccf9ba35083 --- /dev/null +++ b/max/python/max/pipelines/architectures/clip/model_config.py @@ -0,0 +1,63 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from typing import ClassVar + +from max.driver import Device +from max.dtype import DType +from max.graph import DeviceRef +from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding +from pydantic import Field + + +class ClipConfigBase(MAXModelConfigBase): + vocab_size: int = 49408 + hidden_size: int = 512 + intermediate_size: int = 2048 + projection_dim: int = 512 + num_hidden_layers: int = 12 + num_attention_heads: int = 8 + max_position_embeddings: int = 77 + hidden_act: str = "quick_gelu" + layer_norm_eps: float = 1e-5 + attention_dropout: float = 0.0 + initializer_range: float = 0.02 + initializer_factor: float = 1.0 + pad_token_id: int = 1 + bos_token_id: int = 49406 + eos_token_id: int = 49407 + dtype: DType = DType.bfloat16 + device: DeviceRef = Field(default_factory=DeviceRef.GPU) + + +class ClipConfig(ClipConfigBase): + config_name: ClassVar[str] = "config.json" + + @staticmethod + def generate( + config_dict: dict, + encoding: SupportedEncoding, + devices: list[Device], + ) -> ClipConfigBase: + init_dict = { + key: value + for key, value in config_dict.items() + if key in ClipConfigBase.__annotations__ + } + init_dict.update( + { + "dtype": encoding.dtype, + "device": DeviceRef.from_device(devices[0]), + } + ) + return ClipConfigBase(**init_dict) diff --git a/max/python/max/pipelines/architectures/flux1/__init__.py b/max/python/max/pipelines/architectures/flux1/__init__.py new file mode 100644 index 00000000000..2325700031e --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .arch import flux1_arch diff --git a/max/python/max/pipelines/architectures/flux1/arch.py b/max/python/max/pipelines/architectures/flux1/arch.py new file mode 100644 index 00000000000..aea17022398 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/arch.py @@ -0,0 +1,38 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.graph.weights import WeightsFormat +from max.interfaces import BaseContext, PipelineTask +from max.pipelines.lib import ( + SupportedArchitecture, + SupportedEncoding, + TextTokenizer, +) + +from .pipeline_flux import FluxPipeline + +# TODO(minkyu): revisit default_encoding, supported_encodings, tokenizer. +flux1_arch = SupportedArchitecture( + name="FluxPipeline", + task=PipelineTask.IMAGE_GENERATION, + default_encoding=SupportedEncoding.bfloat16, + supported_encodings={SupportedEncoding.bfloat16: []}, + example_repo_ids=[ + "black-forest-labs/FLUX.1-dev", + "black-forest-labs/FLUX.1-schnell", + ], + pipeline_model=FluxPipeline, + tokenizer=TextTokenizer, + context_type=BaseContext, + default_weights_format=WeightsFormat.safetensors, +) diff --git a/max/python/max/pipelines/architectures/flux1/flux1.py b/max/python/max/pipelines/architectures/flux1/flux1.py new file mode 100644 index 00000000000..ed553ef159a --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/flux1.py @@ -0,0 +1,544 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import logging +import os +from collections.abc import Generator +from contextlib import contextmanager +from os import PathLike +from typing import Any + +import max.nn as nn +from max.driver import DLPackArray +from max.dtype import DType +from max.graph import DeviceRef, TensorType, TensorValue, ops +from max.graph.weights import SafetensorWeights +from max.nn import LayerNorm, Module + +from .layers.embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, +) +from .layers.flux_attention import FeedForward, FluxAttention, FluxPosEmbed +from .layers.normalizations import ( + AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle, +) +from .model_config import FluxConfig + +logger = logging.getLogger(__name__) + + +def get_weight_registry_from_diffusers( + safe_tensor_folder: PathLike, +) -> dict[str, DLPackArray]: + weight_files = [ + os.path.join(safe_tensor_folder, f) + for f in os.listdir(safe_tensor_folder) + if f.endswith(".safetensors") + ] + weights = SafetensorWeights(weight_files) + return {name: weight.data().data for name, weight in weights.items()} + + +class FluxSingleTransformerBlock(Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize Flux single transformer block. + + Args: + dim: Dimension of the input/output. + num_attention_heads: Number of attention heads. + attention_head_dim: Dimension of each attention head. + mlp_ratio: Ratio for MLP hidden dimension. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim, device=device, dtype=dtype) + self.proj_mlp = nn.Linear( + dim, self.mlp_hidden_dim, has_bias=True, device=device, dtype=dtype + ) + self.act_mlp = ops.gelu + self.proj_out = nn.Linear( + dim + self.mlp_hidden_dim, + dim, + has_bias=True, + device=device, + dtype=dtype, + ) + self.attn = FluxAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=1e-6, + pre_only=True, + device=device, + dtype=dtype, + ) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + temb: TensorValue, + image_rotary_emb: tuple[TensorValue, TensorValue] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[TensorValue, TensorValue]: + """Apply single transformer block with attention and MLP. + + Args: + hidden_states: Input hidden states. + encoder_hidden_states: Encoder hidden states for cross-attention. + temb: Time embedding. + image_rotary_emb: Optional rotary position embeddings. + joint_attention_kwargs: Optional attention kwargs. + + Returns: + Tuple of (encoder_hidden_states, hidden_states). + """ + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = ops.concat( + [encoder_hidden_states, hidden_states], axis=1 + ) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp( + self.proj_mlp(norm_hidden_states), approximate="tanh" + ) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = ops.concat([attn_output, mlp_hidden_states], axis=2) + gate = ops.unsqueeze(gate, 1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == DType.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = ( + hidden_states[:, :text_seq_len], + hidden_states[:, text_seq_len:], + ) + return encoder_hidden_states, hidden_states + + +class FluxTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize Flux transformer block. + + Args: + dim: Dimension of the input/output. + num_attention_heads: Number of attention heads. + attention_head_dim: Dimension of each attention head. + qk_norm: Type of normalization for query and key ("rms_norm"). + eps: Epsilon for normalization layers. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + + self.norm1 = AdaLayerNormZero(dim, device=device, dtype=dtype) + self.norm1_context = AdaLayerNormZero(dim, device=device, dtype=dtype) + + self.attn = FluxAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + eps=eps, + device=device, + dtype=dtype, + ) + + self.norm2 = LayerNorm( + dim, + eps=1e-6, + devices=[device], + dtype=dtype, + keep_dtype=True, + elementwise_affine=False, + ) + self.ff = FeedForward( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + device=device, + dtype=dtype, + ) + + self.norm2_context = LayerNorm( + dim, + eps=1e-6, + devices=[device], + dtype=dtype, + keep_dtype=True, + elementwise_affine=False, + ) + self.ff_context = FeedForward( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + device=device, + dtype=dtype, + ) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + temb: TensorValue, + image_rotary_emb: tuple[TensorValue, TensorValue] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[TensorValue, TensorValue]: + """Apply transformer block with dual-stream attention and feedforward. + + Args: + hidden_states: Input hidden states. + encoder_hidden_states: Encoder hidden states for cross-attention. + temb: Time embedding. + image_rotary_emb: Optional rotary position embeddings. + joint_attention_kwargs: Optional attention kwargs. + + Returns: + Tuple of (encoder_hidden_states, hidden_states). + """ + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.norm1(hidden_states, emb=temb) + ) + + ( + norm_encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1_context(encoder_hidden_states, emb=temb) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + attn_output, context_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = ops.unsqueeze(gate_msa, 1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + ff_output = self.ff(norm_hidden_states) + ff_output = ops.unsqueeze(gate_mlp, 1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = ops.unsqueeze(c_gate_msa, 1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + + c_shift_mlp[:, None] + ) + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = ( + encoder_hidden_states + + ops.unsqueeze(c_gate_mlp, 1) * context_ff_output + ) + if encoder_hidden_states.dtype == DType.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FluxTransformer2DModel(nn.Module): + def __init__( + self, + config: FluxConfig, + ): + """Initialize Flux Transformer 2D model. + + Args: + config: Flux configuration containing model dimensions, attention + settings, and device/dtype information. + """ + super().__init__() + patch_size = config.patch_size + in_channels = config.in_channels + out_channels = config.out_channels + num_layers = config.num_layers + num_single_layers = config.num_single_layers + attention_head_dim = config.attention_head_dim + num_attention_heads = config.num_attention_heads + joint_attention_dim = config.joint_attention_dim + pooled_projection_dim = config.pooled_projection_dim + guidance_embeds = config.guidance_embeds + axes_dims_rope = config.axes_dims_rope + device = config.device + dtype = config.dtype + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + self.guidance_embeds = guidance_embeds + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings + if guidance_embeds + else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, + pooled_projection_dim=pooled_projection_dim, + device=device, + dtype=dtype, + ) + self.context_embedder = nn.Linear( + joint_attention_dim, + self.inner_dim, + has_bias=True, + device=device, + dtype=dtype, + ) + self.x_embedder = nn.Linear( + in_channels, + self.inner_dim, + has_bias=True, + device=device, + dtype=dtype, + ) + + self.transformer_blocks = nn.Sequential( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + device=device, + dtype=dtype, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.Sequential( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + device=device, + dtype=dtype, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, eps=1e-6, device=device, dtype=dtype + ) + self.proj_out = nn.Linear( + self.inner_dim, + patch_size * patch_size * self.out_channels, + has_bias=True, + device=device, + dtype=dtype, + ) + + self.gradient_checkpointing = False + + self.max_device = device + self.max_dtype = dtype + self.in_channels = in_channels + self.joint_attention_dim = joint_attention_dim + self.pooled_projection_dim = pooled_projection_dim + + self._cache_context_warning_shown = False + + def input_types(self) -> tuple[TensorType, ...]: + """Define input tensor types for the model. + + Returns: + Tuple of TensorType specifications for all model inputs. + """ + hidden_states_type = TensorType( + self.max_dtype, + shape=["batch_size", "image_seq_len", self.in_channels], + device=self.max_device, + ) + encoder_hidden_states_type = TensorType( + self.max_dtype, + shape=["batch_size", "text_seq_len", self.joint_attention_dim], + device=self.max_device, + ) + pooled_projections_type = TensorType( + self.max_dtype, + shape=["batch_size", self.pooled_projection_dim], + device=self.max_device, + ) + timestep_type = TensorType( + DType.float32, shape=["batch_size"], device=self.max_device + ) + img_ids_type = TensorType( + self.max_dtype, shape=["image_seq_len", 3], device=self.max_device + ) + txt_ids_type = TensorType( + self.max_dtype, shape=["text_seq_len", 3], device=self.max_device + ) + guidance_type = TensorType( + self.max_dtype, shape=["batch_size"], device=self.max_device + ) + + return ( + hidden_states_type, + encoder_hidden_states_type, + pooled_projections_type, + timestep_type, + img_ids_type, + txt_ids_type, + guidance_type, + ) + + @contextmanager + def cache_context(self, name: str) -> Generator[None, None, None]: + """Context manager for cache control (not implemented in MAX). + + Args: + name: Name of the cache context. + + Yields: + None. + """ + if not self._cache_context_warning_shown: + logger.warning( + "cache_context is not implemented in MAX FluxTransformer2DModel. " + "Caching optimizations are disabled." + ) + self._cache_context_warning_shown = True + yield + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue = None, + pooled_projections: TensorValue = None, + timestep: TensorValue = None, + img_ids: TensorValue = None, + txt_ids: TensorValue = None, + guidance: TensorValue = None, + joint_attention_kwargs: dict[str, Any] | None = None, + controlnet_block_samples: Any | None = None, + controlnet_single_block_samples: Any | None = None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> tuple[TensorValue]: + """Apply Flux Transformer 2D model forward pass. + + Args: + hidden_states: Input latent hidden states. + encoder_hidden_states: Text encoder hidden states. + pooled_projections: Pooled text embeddings. + timestep: Diffusion timestep. + img_ids: Image position IDs. + txt_ids: Text position IDs. + guidance: Guidance scale values. + joint_attention_kwargs: Additional attention arguments. + controlnet_block_samples: Optional controlnet block samples. + controlnet_single_block_samples: Optional controlnet single block samples. + return_dict: Whether to return as dictionary. + controlnet_blocks_repeat: Whether to repeat controlnet blocks. + + Returns: + Tuple containing output tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + + hidden_states = self.x_embedder(hidden_states) + + timestep = ops.cast(timestep, hidden_states.dtype) + timestep = timestep * 1000.0 + if guidance is not None: + guidance = guidance.cast(hidden_states.dtype) * 1000.0 + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if not self.guidance_embeds + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + ids = ops.concat((txt_ids, img_ids), axis=0) + image_rotary_emb = self.pos_embed(ids) + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + for block in self.single_transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + return (output,) diff --git a/max/python/max/pipelines/architectures/flux1/layers/__init__.py b/max/python/max/pipelines/architectures/flux1/layers/__init__.py new file mode 100644 index 00000000000..75c4f824f20 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/__init__.py @@ -0,0 +1,12 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # diff --git a/max/python/max/pipelines/architectures/flux1/layers/activations.py b/max/python/max/pipelines/architectures/flux1/layers/activations.py new file mode 100644 index 00000000000..c5fc5a80aca --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/activations.py @@ -0,0 +1,56 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import max.nn as nn +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops + + +class GELU(nn.Module): + def __init__( + self, + dim_in: int, + dim_out: int, + approximate: str = "none", + bias: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize GELU activation layer with linear projection. + + Args: + dim_in: Input dimension. + dim_out: Output dimension. + approximate: Approximation type for GELU ("none" or "tanh"). + bias: Whether to include bias in the linear projection. + device: Device to place the layer on. + dtype: Data type for the layer. + """ + super().__init__() + self.proj = nn.Linear( + dim_in, dim_out, has_bias=bias, dtype=dtype, device=device + ) + self.approximate = approximate + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Apply GELU activation to the input. + + Args: + hidden_states: Input tensor. + + Returns: + Output tensor after linear projection and GELU activation. + """ + hidden_states = self.proj(hidden_states) + hidden_states = ops.gelu(hidden_states, approximate=self.approximate) + return hidden_states diff --git a/max/python/max/pipelines/architectures/flux1/layers/embeddings.py b/max/python/max/pipelines/architectures/flux1/layers/embeddings.py new file mode 100644 index 00000000000..cb50aa3f83b --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/embeddings.py @@ -0,0 +1,471 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import math + +from max import nn +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops + + +def apply_rotary_emb( + x: TensorValue, + freqs_cis: tuple[TensorValue, TensorValue], + sequence_dim: int = 2, +) -> TensorValue: + """Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query or key 'x' tensors using the provided frequency + tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped + for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors. + + Args: + x: Query or key tensor to apply rotary embeddings. Shape depends on + caller; the last dimension is split into complex pairs. + freqs_cis: Precomputed cosine/sine frequency tensors for complex + exponentials. Shape ([S, D], [S, D]). + sequence_dim: Dimension representing the sequence (1 or 2). + + Returns: + Tensor: Tensor with rotary embeddings applied. + """ + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + # Used for flux, cogvideox, hunyuan-dit + half_last_dim = x.shape[-1] // 2 + chunks = ops.chunk( + x.reshape(list(x.shape[:-1]) + [half_last_dim, 2]), chunks=2, axis=-1 + ) + x_real = ops.squeeze(chunks[0], axis=-1) + x_imag = ops.squeeze(chunks[1], axis=-1) + # Stack and flatten: [B, S, H, D//2] -> [B, S, H, D//2, 2] -> [B, S, H, D] + x_rotated_stacked = ops.stack([-x_imag, x_real], axis=-1) + batch_sz = x_rotated_stacked.shape[0] + seq_len = x_rotated_stacked.shape[1] + heads = x_rotated_stacked.shape[2] + flattened_last_dim = x_rotated_stacked.shape[3] * x_rotated_stacked.shape[4] + x_rotated = ops.reshape( + x_rotated_stacked, (batch_sz, seq_len, heads, flattened_last_dim) + ) + + out = ( + x.cast(DType.float32) * cos + x_rotated.cast(DType.float32) * sin + ).cast(x.dtype) + + return out + + +def get_timestep_embedding( + timesteps: TensorValue, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> TensorValue: + """Create sinusoidal timestep embeddings. + + Matches the implementation in Diffusers/DDPM. + """ + half_dim = embedding_dim // 2 + + # Create exponent: -math.log(max_period) * arange(0, half_dim) + # ops.range creates a sequence tensor + exponent = ops.range( + 0, half_dim, step=1, dtype=DType.float32, device=timesteps.device + ) + exponent = exponent * (-math.log(max_period)) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = ops.exp(exponent) + + # emb = timesteps[:, None].float() * emb[None, :] + timesteps_f32 = timesteps.cast(DType.float32) + timesteps_dim = timesteps_f32.shape[0] + emb_dim = emb.shape[0] + emb = timesteps_f32.reshape((timesteps_dim, 1)) * emb.reshape((1, emb_dim)) + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = ops.concat([ops.sin(emb), ops.cos(emb)], axis=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = ops.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1) + + # zero pad if embedding_dim is odd (rare case) + if embedding_dim % 2 == 1: + # Pad with one zero column at the end + zeros = ops.zeros((emb.shape[0], 1), dtype=emb.dtype, device=emb.device) + emb = ops.concat([emb, zeros], axis=-1) + + return emb + + +class Timesteps(nn.Module): + def __init__( + self, + num_channels: int, + flip_sin_to_cos: bool, + downscale_freq_shift: float, + scale: int = 1, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.float32, + ): + """Initialize Timesteps embedding module. + + Args: + num_channels: Number of channels in the embedding. + flip_sin_to_cos: Whether to flip sine and cosine embeddings. + downscale_freq_shift: Frequency downscaling shift parameter. + scale: Scaling factor for embeddings. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def __call__(self, timesteps: TensorValue) -> TensorValue: + """Generate timestep embeddings. + + Args: + timesteps: Input timestep values. + + Returns: + Timestep embeddings. + """ + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int | None = None, + post_act_fn: str | None = None, + cond_proj_dim: int | None = None, + sample_proj_bias: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize TimestepEmbedding module. + + Args: + in_channels: Number of input channels. + time_embed_dim: Dimension of the time embedding. + act_fn: Activation function to use ("silu", "swish", or "gelu"). + out_dim: Optional output dimension. Defaults to time_embed_dim if None. + post_act_fn: Optional post-activation function. + cond_proj_dim: Optional conditional projection dimension. + sample_proj_bias: Whether to use bias in projection layers. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + + self.linear_1 = nn.Linear( + in_channels, + time_embed_dim, + has_bias=sample_proj_bias, + device=device, + dtype=dtype, + ) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear( + cond_proj_dim, + in_channels, + has_bias=False, + device=device, + dtype=dtype, + ) + else: + self.cond_proj = None + if act_fn == "silu" or act_fn == "swish": + self.act_fn = ops.silu + elif act_fn == "gelu": + self.act_fn = ops.gelu + else: + raise ValueError(f"Invalid activation function: {act_fn}") + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + + self.linear_2 = nn.Linear( + time_embed_dim, + time_embed_dim_out, + has_bias=sample_proj_bias, + device=device, + dtype=dtype, + ) + + if post_act_fn is None: + self.post_act_fn = None + elif post_act_fn == "silu" or post_act_fn == "swish": + self.post_act_fn = ops.silu + elif post_act_fn == "gelu": + self.post_act_fn = ops.gelu + else: + raise ValueError(f"Invalid post activation function: {post_act_fn}") + + def __call__( + self, sample: TensorValue, condition: TensorValue | None = None + ) -> TensorValue: + """Generate timestep embeddings with optional conditioning. + + Args: + sample: Input sample tensor. + condition: Optional conditioning tensor. + + Returns: + Timestep embeddings. + """ + if condition is not None and self.cond_proj is not None: + sample = sample + self.cond_proj(condition) + + sample = self.linear_1(sample) + + sample = self.act_fn(sample) + + sample = self.linear_2(sample) + + if self.post_act_fn is not None: + sample = self.post_act_fn(sample) + + return sample + + +class PixArtAlphaTextProjection(nn.Module): + """Projects caption embeddings. Also handles dropout for classifier-free guidance.""" + + def __init__( + self, + in_features: int, + hidden_size: int, + out_features: int | None = None, + act_fn: str = "gelu_tanh", + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize PixArtAlpha text projection module. + + Args: + in_features: Number of input features. + hidden_size: Size of the hidden layer. + out_features: Number of output features. Defaults to hidden_size if None. + act_fn: Activation function to use ("gelu_tanh" or "silu"). + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear( + in_features, hidden_size, has_bias=True, device=device, dtype=dtype + ) + self.linear_2 = nn.Linear( + hidden_size, out_features, has_bias=True, device=device, dtype=dtype + ) + if act_fn == "gelu_tanh": + self.act_fn = ops.gelu(approximate="tanh") + elif act_fn == "silu": + self.act_fn = ops.silu + else: + raise ValueError(f"Invalid activation function: {act_fn}") + + def __call__(self, caption: TensorValue) -> TensorValue: + """Project caption embeddings. + + Args: + caption: Input caption embeddings. + + Returns: + Projected caption embeddings. + """ + hidden_states = self.linear_1(caption) + + hidden_states = self.act_fn(hidden_states) + + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class CombinedTimestepTextProjEmbeddings(nn.Module): + def __init__( + self, + embedding_dim: int, + pooled_projection_dim: int, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize combined timestep and text projection embeddings module. + + Args: + embedding_dim: Dimension of the embedding. + pooled_projection_dim: Dimension of the pooled projection. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + + self.time_proj = Timesteps( + num_channels=256, + flip_sin_to_cos=True, + downscale_freq_shift=0, + device=device, + dtype=dtype, + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=embedding_dim, + device=device, + dtype=dtype, + ) + self.text_embedder = PixArtAlphaTextProjection( + pooled_projection_dim, + embedding_dim, + act_fn="silu", + device=device, + dtype=dtype, + ) + + def __call__( + self, timestep: TensorValue, pooled_projection: TensorValue + ) -> TensorValue: + """Combine timestep and text embeddings. + + Args: + timestep: Input timestep values. + pooled_projection: Pooled text projection. + + Returns: + Combined conditioning embeddings. + """ + # Timestep projection and embedding + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.cast(pooled_projection.dtype) + ) + + # Text projection + pooled_projections = self.text_embedder(pooled_projection) + + # Combine + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): + def __init__( + self, + embedding_dim: int, + pooled_projection_dim: int, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize combined timestep, guidance, and text projection embeddings module. + + Args: + embedding_dim: Dimension of the embedding. + pooled_projection_dim: Dimension of the pooled projection. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + + self.time_proj = Timesteps( + num_channels=256, + flip_sin_to_cos=True, + downscale_freq_shift=0, + device=device, + dtype=dtype, + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=embedding_dim, + device=device, + dtype=dtype, + ) + self.guidance_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=embedding_dim, + device=device, + dtype=dtype, + ) + self.text_embedder = PixArtAlphaTextProjection( + pooled_projection_dim, + embedding_dim, + act_fn="silu", + device=device, + dtype=dtype, + ) + + def __call__( + self, + timestep: TensorValue, + guidance: TensorValue, + pooled_projection: TensorValue, + ) -> TensorValue: + """Combine timestep, guidance, and text embeddings. + + Args: + timestep: Input timestep values. + guidance: Guidance values. + pooled_projection: Pooled text projection. + + Returns: + Combined conditioning embeddings. + """ + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.cast(pooled_projection.dtype) + ) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder( + guidance_proj.cast(pooled_projection.dtype) + ) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning diff --git a/max/python/max/pipelines/architectures/flux1/layers/flux_attention.py b/max/python/max/pipelines/architectures/flux1/layers/flux_attention.py new file mode 100644 index 00000000000..188050390e4 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/flux_attention.py @@ -0,0 +1,474 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import math + +import max.nn as nn +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops +from max.nn import ( + Linear, + Module, + RMSNorm, +) +from max.nn.attention.mask_config import MHAMaskVariant +from max.nn.kernels import flash_attention_gpu + +from .activations import GELU +from .embeddings import apply_rotary_emb + + +class FluxAttention(Module): + """Flux attention mechanism with QK normalization and optional dual stream.""" + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int | None = None, + context_pre_only: bool | None = None, + pre_only: bool = False, + elementwise_affine: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize Flux attention module. + + Args: + query_dim: Dimension of query vectors. + heads: Number of attention heads. + dim_head: Dimension of each attention head. + dropout: Dropout probability. + bias: Whether to use bias in projections. + added_kv_proj_dim: Optional dimension for additional key/value projections. + added_proj_bias: Whether to use bias in additional projections. + out_bias: Whether to use bias in output projection. + eps: Epsilon for normalization layers. + out_dim: Optional output dimension. + context_pre_only: Whether to use context pre-processing only. + pre_only: Whether to use pre-processing only. + elementwise_affine: Whether to use elementwise affine in normalization. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + self.dtype = dtype + self.device = device + + self.norm_q = RMSNorm( + dim_head, + dtype=self.dtype, + eps=eps, + multiply_before_cast=elementwise_affine, + ) + self.norm_k = RMSNorm( + dim_head, + dtype=self.dtype, + eps=eps, + multiply_before_cast=elementwise_affine, + ) + self.to_q = Linear( + query_dim, + self.inner_dim, + has_bias=bias, + dtype=self.dtype, + device=self.device, + ) + self.to_k = Linear( + query_dim, + self.inner_dim, + has_bias=bias, + dtype=self.dtype, + device=self.device, + ) + self.to_v = Linear( + query_dim, + self.inner_dim, + has_bias=bias, + dtype=self.dtype, + device=self.device, + ) + + if not self.pre_only: + layers = [] + layers.append( + Linear( + self.inner_dim, + self.out_dim, + has_bias=out_bias, + dtype=self.dtype, + device=self.device, + ) + ) + # layers.append(Dropout(dropout)) # There is no Dropout in MAX + self.to_out = nn.Sequential(layers) + + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, dtype=self.dtype, eps=eps) + self.norm_added_k = RMSNorm(dim_head, dtype=self.dtype, eps=eps) + self.add_q_proj = Linear( + added_kv_proj_dim, + self.inner_dim, + has_bias=added_proj_bias, + dtype=self.dtype, + device=self.device, + ) + self.add_k_proj = Linear( + added_kv_proj_dim, + self.inner_dim, + has_bias=added_proj_bias, + dtype=self.dtype, + device=self.device, + ) + self.add_v_proj = Linear( + added_kv_proj_dim, + self.inner_dim, + has_bias=added_proj_bias, + dtype=self.dtype, + device=self.device, + ) + self.to_add_out = Linear( + self.inner_dim, + query_dim, + has_bias=out_bias, + dtype=self.dtype, + device=self.device, + ) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue = None, + attention_mask: TensorValue | None = None, + image_rotary_emb: tuple[TensorValue, TensorValue] | None = None, + ) -> TensorValue: + """Apply Flux attention to hidden states. + + Args: + hidden_states: Input hidden states. + encoder_hidden_states: Optional encoder hidden states for cross-attention. + attention_mask: Optional attention mask. + image_rotary_emb: Optional rotary embeddings for position encoding. + + Returns: + Output hidden states after attention, or tuple of (hidden_states, encoder_hidden_states) if encoder states provided. + """ + batch_size = hidden_states.shape[0] + + # get qkv projections + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + seq_len = query.shape[1] + query = ops.reshape( + query, (batch_size, seq_len, self.heads, self.head_dim) + ) + key = ops.reshape(key, (batch_size, seq_len, self.heads, self.head_dim)) + value = ops.reshape( + value, (batch_size, seq_len, self.heads, self.head_dim) + ) + + query = self.norm_q(query) + key = self.norm_k(key) + + encoder_query = encoder_key = encoder_value = None + if ( + encoder_hidden_states is not None + and self.added_kv_proj_dim is not None + ): + encoder_query = self.add_q_proj(encoder_hidden_states) + encoder_key = self.add_k_proj(encoder_hidden_states) + encoder_value = self.add_v_proj(encoder_hidden_states) + + query = self.norm_q(query) + key = self.norm_k(key) + + if ( + encoder_hidden_states is not None + and self.added_kv_proj_dim is not None + ): + encoder_seq_len = encoder_query.shape[1] + encoder_query = ops.reshape( + encoder_query, + (batch_size, encoder_seq_len, self.heads, self.head_dim), + ) + encoder_key = ops.reshape( + encoder_key, + (batch_size, encoder_seq_len, self.heads, self.head_dim), + ) + encoder_value = ops.reshape( + encoder_value, + (batch_size, encoder_seq_len, self.heads, self.head_dim), + ) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = ops.concat([encoder_query, query], axis=1) + key = ops.concat([encoder_key, key], axis=1) + value = ops.concat([encoder_value, value], axis=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = flash_attention_gpu( + query, + key, + value, + mask_variant=MHAMaskVariant.NULL_MASK, + scale=math.sqrt(1.0 / self.head_dim), + ) + + total_seq_len = hidden_states.shape[1] + hidden_states = ops.reshape( + hidden_states, + (batch_size, total_seq_len, self.heads * self.head_dim), + ) + + if encoder_hidden_states is not None: + encoder_seq_len = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :encoder_seq_len, :] + hidden_states = hidden_states[:, encoder_seq_len:, :] + + hidden_states = self.to_out(hidden_states) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + return hidden_states + + +class FeedForward(Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim: int | None = None, + bias: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize FeedForward module. + + Args: + dim: Input dimension. + dim_out: Optional output dimension. Defaults to dim if None. + mult: Multiplier for hidden dimension. + dropout: Dropout probability. + activation_fn: Activation function to use ("gelu" or "gelu-approximate"). + final_dropout: Whether to apply dropout at the end. + inner_dim: Optional inner dimension. Computed as dim * mult if None. + bias: Whether to use bias in linear layers. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias, device=device, dtype=dtype) + if activation_fn == "gelu-approximate": + act_fn = GELU( + dim, + inner_dim, + approximate="tanh", + bias=bias, + device=device, + dtype=dtype, + ) + else: + raise NotImplementedError( + f"Activation function {activation_fn} is not implemented" + ) + + self.net = nn.Sequential( + [ + act_fn, + Linear( + inner_dim, + dim_out, + has_bias=bias, + dtype=dtype, + device=device, + ), + ] + ) + + def __call__( + self, hidden_states: TensorValue, *args, **kwargs + ) -> TensorValue: + """Apply feedforward network to hidden states. + + Args: + hidden_states: Input hidden states. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). + + Returns: + Output hidden states after feedforward network. + """ + return self.net(hidden_states) + + +class FluxPosEmbed(nn.Module): + """Flux Position Embedding module for 3D rotary position embeddings. + + This module computes separate rotary embeddings for each spatial dimension + (typically time, height, width) and concatenates them. + + Args: + theta: Base value for frequency computation (typically 10000) + axes_dim: List of dimensions for each axis (e.g., [16, 56, 56] for time, height, width) + """ + + def __init__( + self, theta: int = 10000, axes_dim: tuple[int, int, int] = (16, 56, 56) + ): + """Initialize Flux position embedding module. + + Args: + theta: Base value for frequency computation (typically 10000). + axes_dim: Dimensions for each axis (e.g., [16, 56, 56] for time, height, width). + """ + super().__init__() + self.theta = float(theta) + self.axes_dim = list(axes_dim) + + def _get_1d_rotary_pos_embed( + self, dim: int, pos: TensorValue, device: DeviceRef + ) -> tuple[TensorValue, TensorValue]: + """Compute 1D rotary position embeddings for a single axis. + + Args: + dim: Dimension of the embedding (should be even) + pos: Position indices, shape [batch_size] + device: Device to compute on + + Returns: + Tuple of (freqs_cos, freqs_sin), each with shape [batch_size, dim] + """ + # Ensure dim is even + assert dim % 2 == 0, f"dim must be even, got {dim}" + + # Cast position to float32 for computation + pos = ops.cast(pos, DType.float32) + + # Compute frequencies: 1.0 / (theta ** (arange(0, dim, 2) / dim)) + # Shape: [dim/2] + arange_vals = ops.range( + 0, dim, step=2, dtype=DType.float32, device=device + ) + exponents = arange_vals / float(dim) + + # theta ** exponents + theta_tensor = ops.constant(self.theta, DType.float32, device=device) + theta_powered = ops.pow(theta_tensor, exponents) + + # 1.0 / theta_powered + freqs = 1.0 / theta_powered # Shape: [dim/2] + + # Outer product: pos [batch_size] x freqs [dim/2] = [batch_size, dim/2] + freqs_outer = ops.outer(pos, freqs) + + # Compute cos and sin + freqs_cos_half = ops.cos(freqs_outer) # [batch_size, dim/2] + freqs_sin_half = ops.sin(freqs_outer) # [batch_size, dim/2] + + # Repeat interleave to get full dimension + # repeat_interleave(2, dim=1): [a, b, c] -> [a, a, b, b, c, c] + # Since repeat_interleave is not supported on GPU, we use reshape + tile + + # 1. Unsqueeze: [batch_size, dim/2] -> [batch_size, dim/2, 1] + freqs_cos_expanded = ops.unsqueeze(freqs_cos_half, axis=2) + freqs_sin_expanded = ops.unsqueeze(freqs_sin_half, axis=2) + + # 2. Concat to duplicate: [batch_size, dim/2, 1] -> [batch_size, dim/2, 2] + freqs_cos_tiled = ops.concat( + [freqs_cos_expanded, freqs_cos_expanded], axis=2 + ) + freqs_sin_tiled = ops.concat( + [freqs_sin_expanded, freqs_sin_expanded], axis=2 + ) + + # 3. Reshape to flatten: [batch_size, dim/2, 2] -> [batch_size, dim] + flattened_dim = freqs_cos_tiled.shape[1] * freqs_cos_tiled.shape[2] + freqs_cos = ops.reshape( + freqs_cos_tiled, (freqs_cos_tiled.shape[0], flattened_dim) + ) + freqs_sin = ops.reshape( + freqs_sin_tiled, (freqs_sin_tiled.shape[0], flattened_dim) + ) + + return freqs_cos, freqs_sin + + def __call__(self, ids: TensorValue) -> tuple[TensorValue, TensorValue]: + """Forward pass to compute rotary position embeddings. + + Args: + ids: Position indices tensor with shape [batch_size, n_axes] + where n_axes is the number of spatial dimensions (e.g., 3 for time/height/width) + + Returns: + Tuple of (freqs_cos, freqs_sin) with concatenated embeddings from all axes + """ + # Get number of axes from the last dimension + n_axes = ids.shape[-1] + device = ids.device + + cos_out = [] + sin_out = [] + + # Compute embeddings for each axis + for i in range(int(n_axes)): + # Extract position for this axis: ids[:, i] + pos = ids[:, i] + + # Compute 1D rotary embeddings for this axis + cos_embed, sin_embed = self._get_1d_rotary_pos_embed( + dim=self.axes_dim[i], pos=pos, device=device + ) + + cos_out.append(cos_embed) + sin_out.append(sin_embed) + + # Concatenate embeddings from all axes along the last dimension + freqs_cos = ops.concat(cos_out, axis=-1) # [batch_size, sum(axes_dim)] + freqs_sin = ops.concat(sin_out, axis=-1) # [batch_size, sum(axes_dim)] + + return freqs_cos, freqs_sin diff --git a/max/python/max/pipelines/architectures/flux1/layers/normalizations.py b/max/python/max/pipelines/architectures/flux1/layers/normalizations.py new file mode 100644 index 00000000000..6755d56125d --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/normalizations.py @@ -0,0 +1,254 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + + +import max.nn as nn +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops +from max.nn import LayerNorm, RMSNorm + + +class AdaLayerNormZeroSingle(nn.Module): + def __init__( + self, + embedding_dim: int, + norm_type: str = "layer_norm", + bias: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize adaptive layer normalization zero single module. + + Args: + embedding_dim: Size of each embedding vector. + norm_type: Type of normalization to use ("layer_norm"). + bias: Whether to use bias in linear projection. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + self.linear = nn.Linear( + embedding_dim, + 3 * embedding_dim, + has_bias=bias, + device=device, + dtype=dtype, + ) + if norm_type == "layer_norm": + self.norm = LayerNorm( + embedding_dim, + use_bias=False, + eps=1e-6, + devices=[device], + dtype=dtype, + keep_dtype=True, + elementwise_affine=False, + ) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def __call__( + self, x: TensorValue, emb: TensorValue | None = None + ) -> TensorValue: + """Apply adaptive layer normalization. + + Args: + x: Input tensor. + emb: Optional embedding tensor for conditioning. + + Returns: + Tuple of normalized tensor and gate values. + """ + emb = self.linear(ops.silu(emb)) + shift_msa, scale_msa, gate_msa = ops.chunk(emb, 3, axis=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa + + +class AdaLayerNormZero(nn.Module): + r"""Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__( + self, + embedding_dim: int, + num_embeddings: int | None = None, + norm_type: str = "layer_norm", + bias: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize adaptive layer normalization zero module. + + Args: + embedding_dim: Size of each embedding vector. + num_embeddings: Optional size of the embeddings dictionary. + norm_type: Type of normalization to use ("layer_norm" or "fp32_layer_norm"). + bias: Whether to use bias in linear projection. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + if num_embeddings is not None: + # self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + raise NotImplementedError( + "CombinedTimestepLabelEmbeddings is not implemented" + ) + else: + self.emb = None + + self.linear = nn.Linear( + embedding_dim, + 6 * embedding_dim, + has_bias=bias, + dtype=dtype, + device=device, + ) + if norm_type == "layer_norm": + self.norm = LayerNorm( + embedding_dim, + use_bias=False, + eps=1e-6, + devices=[device], + dtype=dtype, + keep_dtype=True, + elementwise_affine=False, + ) + elif norm_type == "fp32_layer_norm": + # self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) + raise NotImplementedError("FP32LayerNorm is not implemented") + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def __call__( + self, + x: TensorValue, + timestep: TensorValue | None = None, + class_labels: TensorValue | None = None, + hidden_dtype: DType | None = None, + emb: TensorValue | None = None, + ) -> tuple[TensorValue, TensorValue, TensorValue, TensorValue, TensorValue]: + """Apply adaptive layer normalization with gate values for attention and MLP. + + Args: + x: Input tensor. + timestep: Optional timestep tensor. + class_labels: Optional class label tensor. + hidden_dtype: Optional hidden data type. + emb: Optional embedding tensor for conditioning. + + Returns: + Tuple of (normalized tensor, gate_msa, shift_mlp, scale_mlp, gate_mlp). + """ + if self.emb is not None: + emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) + emb = self.linear(ops.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + ops.chunk(emb, 6, axis=1) + ) + x = self.norm(x) + x = x * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormContinuous(nn.Module): + r"""Adaptive normalization layer with a norm layer (layer_norm or rms_norm). + + Args: + embedding_dim (`int`): Embedding dimension to use during projection. + conditioning_embedding_dim (`int`): Dimension of the input condition. + elementwise_affine (`bool`, defaults to `True`): + Boolean flag to denote if affine transformation should be applied. + eps (`float`, defaults to 1e-5): Epsilon factor. + bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use. + norm_type (`str`, defaults to `"layer_norm"`): + Normalization layer to use. Values supported: "layer_norm", "rms_norm". + """ + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + # elementwise_affine=True, + eps: float = 1e-5, + bias: bool = True, + norm_type: str = "layer_norm", + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize adaptive layer normalization continuous module. + + Args: + embedding_dim: Embedding dimension to use during projection. + conditioning_embedding_dim: Dimension of the input condition. + eps: Epsilon factor for normalization. + bias: Whether to use bias in linear projection. + norm_type: Type of normalization to use ("layer_norm" or "rms_norm"). + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + self.silu = ops.silu + self.linear = nn.Linear( + conditioning_embedding_dim, + embedding_dim * 2, + has_bias=bias, + device=device, + dtype=dtype, + ) + if norm_type == "layer_norm": + self.norm = LayerNorm( + embedding_dim, + eps=eps, + devices=[device], + dtype=dtype, + keep_dtype=True, + elementwise_affine=False, + ) + elif norm_type == "rms_norm": + self.norm = RMSNorm( + embedding_dim, eps=eps, device=device, dtype=dtype + ) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def __call__( + self, x: TensorValue, conditioning_embedding: TensorValue + ) -> TensorValue: + """Apply adaptive layer normalization with conditioning. + + Args: + x: Input tensor. + conditioning_embedding: Conditioning embedding tensor. + + Returns: + Normalized and conditioned tensor. + """ + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).cast(x.dtype)) + scale, shift = ops.chunk(emb, 2, axis=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x diff --git a/max/python/max/pipelines/architectures/flux1/model.py b/max/python/max/pipelines/architectures/flux1/model.py new file mode 100644 index 00000000000..8b476f1ab72 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/model.py @@ -0,0 +1,77 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.driver import CPU, Accelerator, Device +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.max_model import MaxModel + +from .flux1 import FluxTransformer2DModel +from .model_config import FluxConfig +from .weight_adapters import convert_safetensor_state_dict + + +class Flux1Model(MaxModel): + config_name = FluxConfig.config_name + + def __init__( + self, + config: dict, + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + ) -> None: + super().__init__( + config, + encoding, + devices, + weights, + ) + self.config = FluxConfig.generate( + config, + encoding, + devices, + ) + self.load_model() + + def load_model(self) -> Model: + flux = FluxTransformer2DModel(self.config) + + if self.config.device.is_cpu(): + session = InferenceSession([CPU()]) + else: + session = InferenceSession([Accelerator()]) + state_dict = {key: value.data() for key, value in self.weights.items()} + state_dict = convert_safetensor_state_dict(state_dict) + flux.load_state_dict(state_dict) + with Graph( + "flux_transformer_2d_model", input_types=flux.input_types() + ) as graph: + outputs = flux( + *graph.inputs, + joint_attention_kwargs={}, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict=False, + controlnet_blocks_repeat=False, + ) + graph.output(*outputs) + compiled_graph = graph + self.session = session.load( + compiled_graph, weights_registry=flux.state_dict() + ) + + def __call__(self, *args, **kwargs): + return self.session.execute(*args, **kwargs) diff --git a/max/python/max/pipelines/architectures/flux1/model_config.py b/max/python/max/pipelines/architectures/flux1/model_config.py new file mode 100644 index 00000000000..c9292030000 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/model_config.py @@ -0,0 +1,59 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from typing import ClassVar + +from max.driver import Device +from max.dtype import DType +from max.graph import DeviceRef +from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding +from pydantic import Field + + +class FluxConfigBase(MAXModelConfigBase): + patch_size: int = 1 + in_channels: int = 64 + out_channels: int | None = None + num_layers: int = 19 + num_single_layers: int = 38 + attention_head_dim: int = 128 + num_attention_heads: int = 24 + joint_attention_dim: int = 4096 + pooled_projection_dim: int = 768 + guidance_embeds: bool = False + axes_dims_rope: tuple[int, int, int] = (16, 56, 56) + dtype: DType = DType.bfloat16 + device: DeviceRef = Field(default_factory=DeviceRef.GPU) + + +class FluxConfig(FluxConfigBase): + config_name: ClassVar[str] = "config.json" + + @staticmethod + def generate( + config_dict: dict, + encoding: SupportedEncoding, + devices: list[Device], + ) -> FluxConfigBase: + init_dict = { + key: value + for key, value in config_dict.items() + if key in FluxConfigBase.__annotations__ + } + init_dict.update( + { + "dtype": encoding.dtype, + "device": DeviceRef.from_device(devices[0]), + } + ) + return FluxConfigBase(**init_dict) diff --git a/max/python/max/pipelines/architectures/flux1/weight_adapters.py b/max/python/max/pipelines/architectures/flux1/weight_adapters.py new file mode 100644 index 00000000000..6ef149c0976 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/weight_adapters.py @@ -0,0 +1,30 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import re + +from max.graph.weights import WeightData + + +def convert_safetensor_state_dict( + state_dict: dict[str, WeightData], +) -> dict[str, WeightData]: + keys = list(state_dict.keys()) + for key in keys: + # Remap net.2 to net.1: Diffusers uses [GELU, Dropout, Linear], while MAX uses [GELU, Linear]. + if re.match( + r"transformer_blocks\.\d+\.(ff|ff_context)\.net\.2\.(weight|bias)", + key, + ): + state_dict[key.replace("net.2.", "net.1.")] = state_dict.pop(key) + return state_dict diff --git a/max/python/max/pipelines/architectures/t5/__init__.py b/max/python/max/pipelines/architectures/t5/__init__.py new file mode 100644 index 00000000000..ad108912d49 --- /dev/null +++ b/max/python/max/pipelines/architectures/t5/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .model import T5Model diff --git a/max/python/max/pipelines/architectures/t5/model.py b/max/python/max/pipelines/architectures/t5/model.py new file mode 100644 index 00000000000..f0c53c2fcdc --- /dev/null +++ b/max/python/max/pipelines/architectures/t5/model.py @@ -0,0 +1,64 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.driver import CPU, Accelerator, Device +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.max_model import MaxModel + +from .model_config import T5Config +from .t5 import T5EncoderModel + + +class T5Model(MaxModel): + config_name = T5Config.config_name + + def __init__( + self, + config: dict, + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + ) -> None: + super().__init__(config, encoding, devices, weights) + self.config = T5Config.generate( + config, + encoding, + devices, + ) + self.load_model() + + def load_model(self) -> Model: + t5 = T5EncoderModel(self.config) + + if self.config.device.is_cpu(): + session = InferenceSession([CPU()]) + else: + session = InferenceSession([Accelerator()]) + state_dict = {key: value.data() for key, value in self.weights.items()} + t5.load_state_dict(state_dict) + with Graph("t5_encoder_model", input_types=t5.input_types()) as graph: + outputs = t5( + input_ids=graph.inputs[0], + attention_mask=None, + ) + graph.output(outputs) + compiled_graph = graph + self.session = session.load( + compiled_graph, weights_registry=t5.state_dict() + ) + + def __call__(self, *args, **kwargs): + return self.session.execute(*args, **kwargs) diff --git a/max/python/max/pipelines/architectures/t5/model_config.py b/max/python/max/pipelines/architectures/t5/model_config.py new file mode 100644 index 00000000000..ab28ce4b2a8 --- /dev/null +++ b/max/python/max/pipelines/architectures/t5/model_config.py @@ -0,0 +1,69 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from typing import ClassVar + +from max.driver import Device +from max.dtype import DType +from max.graph import DeviceRef +from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding +from pydantic import Field + + +class T5ConfigBase(MAXModelConfigBase): + vocab_size: int = 32128 + d_model: int = 512 + d_kv: int = 64 + d_ff: int = 2048 + num_layers: int = 6 + num_decoder_layers: int | None = None + num_heads: int = 8 + relative_attention_num_buckets: int = 32 + relative_attention_max_distance: int = 128 + dropout_rate: float = 0.1 + layer_norm_epsilon: float = 1e-6 + initializer_factor: float = 1.0 + feed_forward_proj: str = "relu" + dense_act_fn: str | None = Field(default=None, exclude=True) + is_gated_act: bool = Field(default=False, exclude=True) + is_decoder: bool = Field(default=False, exclude=True) + is_encoder_decoder: bool = True + use_cache: bool = True + pad_token_id: int = 0 + eos_token_id: int = 1 + classifier_dropout: float = 0.0 + device: DeviceRef = Field(default_factory=DeviceRef.GPU) + dtype: DType = DType.bfloat16 + + +class T5Config(T5ConfigBase): + config_name: ClassVar[str] = "config.json" + + @staticmethod + def generate( + config_dict: dict, + encoding: SupportedEncoding, + devices: list[Device], + ) -> T5ConfigBase: + init_dict = { + key: value + for key, value in config_dict.items() + if key in T5ConfigBase.__annotations__ + } + init_dict.update( + { + "dtype": encoding.dtype, + "device": DeviceRef.from_device(devices[0]), + } + ) + return T5ConfigBase(**init_dict) diff --git a/max/python/max/pipelines/architectures/t5/t5.py b/max/python/max/pipelines/architectures/t5/t5.py new file mode 100644 index 00000000000..1cd8665a3bc --- /dev/null +++ b/max/python/max/pipelines/architectures/t5/t5.py @@ -0,0 +1,823 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import math + +import max.nn as nn +from max.dtype import DType +from max.graph import DeviceRef, TensorType, TensorValue, Weight, ops +from max.nn import Module + +from .model_config import T5Config + + +class T5LayerNorm(Module): + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.float32, + ): + """Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + + Args: + hidden_size: Hidden size. + eps: Epsilon. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + self.weight = Weight("weight", dtype, (hidden_size,), device=device) + self.variance_epsilon = eps + self.dtype = dtype + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Process hidden states through the T5 layer norm. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + hidden_states_f32 = ops.cast(hidden_states, DType.float32) + variance = ops.mean(ops.pow(hidden_states_f32, 2), axis=-1) + hidden_states = hidden_states * ops.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.dtype in [DType.float16, DType.bfloat16]: + hidden_states = ops.cast(hidden_states, self.dtype) + + return self.weight * hidden_states + + +class T5DenseActDense(Module): + def __init__( + self, + config: T5Config, + ): + """Construct a dense-activation-dense module. + + Args: + config: T5 configuration for feed-forward dimensions and dtype. + """ + super().__init__() + self.wi = nn.Linear( + config.d_model, + config.d_ff, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.wo = nn.Linear( + config.d_ff, + config.d_model, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.act_fn = ( + lambda x: 0.5 + * x + * ( + 1.0 + + ops.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * ops.pow(x, 3.0)) + ) + ) + ) + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Process hidden states through the dense-activation-dense block. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + hidden_states = self.wi(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(Module): + def __init__( + self, + config: T5Config, + ): + """Construct a dense-gated-activation-dense module. + + Args: + config: T5 configuration for feed-forward dimensions and dtype. + """ + super().__init__() + self.wi_0 = nn.Linear( + config.d_model, + config.d_ff, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.wi_1 = nn.Linear( + config.d_model, + config.d_ff, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.wo = nn.Linear( + config.d_ff, + config.d_model, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.act_fn = ( + lambda x: 0.5 + * x + * ( + 1.0 + + ops.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * ops.pow(x, 3.0)) + ) + ) + ) + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Process hidden states through the dense-gated-activation-dense block. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + hidden_gelu = self.act_fn(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(Module): + def __init__( + self, + config: T5Config, + ): + """Construct a feed-forward layer. + + Args: + config: T5 configuration for gating, dimensions, and dtype. + """ + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config) + else: + self.DenseReluDense = T5DenseActDense(config) + + self.layer_norm = T5LayerNorm( + config.d_model, + eps=config.layer_norm_epsilon, + device=config.device, + dtype=config.dtype, + ) + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Process hidden states through the feed-forward layer. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +class T5Attention(Module): + def __init__( + self, + config: T5Config, + has_relative_attention_bias: bool = False, + layer_idx: int | None = None, + ): + """Construct an attention layer. + + Args: + config: T5 configuration. + has_relative_attention_bias: Whether to use relative attention bias. + layer_idx: Index of the layer. + """ + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = ( + config.relative_attention_num_buckets + ) + self.relative_attention_max_distance = ( + config.relative_attention_max_distance + ) + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.device = config.device + self.dtype = config.dtype + + self.q = nn.Linear( + self.d_model, + self.inner_dim, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.k = nn.Linear( + self.d_model, + self.inner_dim, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.v = nn.Linear( + self.d_model, + self.inner_dim, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.o = nn.Linear( + self.inner_dim, + self.d_model, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, + self.n_heads, + device=config.device, + dtype=config.dtype, + ) + + def _relative_position_bucket( + self, + relative_position: TensorValue, + bidirectional: bool = True, + num_buckets: int = 32, + max_distance: int = 128, + ) -> TensorValue: + """Compute relative position bucket. + + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Args: + relative_position: Tensor with relative positions. + bidirectional: Whether the attention is bidirectional. + num_buckets: Number of buckets. + max_distance: Maximum distance for relative positions. + + Returns: + TensorValue: Relative position buckets. + """ + relative_buckets = ops.constant(0, DType.int32, self.device) + + if bidirectional: + num_buckets = num_buckets // 2 + is_positive = ops.greater(relative_position, 0) + relative_buckets = relative_buckets + ( + is_positive.cast(DType.int32) * num_buckets + ) + relative_position = ops.abs(relative_position) + else: + relative_position = -ops.min(relative_position, 0) + + max_exact = num_buckets // 2 + is_small = ops.greater(max_exact, relative_position) + + scale = (num_buckets - max_exact) / math.log(max_distance / max_exact) + rel_pos_float = relative_position.cast(DType.float32) + val_log = ops.log(rel_pos_float / float(max_exact)) + relative_position_if_large = max_exact + (val_log * scale).cast( + DType.int32 + ) + relative_position_if_large = ops.min( + relative_position_if_large, num_buckets - 1 + ) + return relative_buckets + ops.where( + is_small, relative_position, relative_position_if_large + ) + + def compute_bias(self, query_length: int, key_length: int) -> TensorValue: + """Compute relative attention bias. + + Args: + query_length: Length of the query sequence. + key_length: Length of the key sequence. + + Returns: + TensorValue: Relative attention bias tensor. + """ + context_position = ops.range( + 0, query_length, step=1, dtype=DType.int32, device=self.device + ) + context_position = ops.unsqueeze(context_position, 1) + + memory_position = ops.range( + 0, key_length, step=1, dtype=DType.int32, device=self.device + ) + memory_position = ops.unsqueeze(memory_position, 0) + + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) + values = ops.permute(values, (2, 0, 1)) + values = ops.unsqueeze(values, 0) + return values + + def __call__( + self, + hidden_states: TensorValue, + mask: TensorValue | None = None, + key_value_states: TensorValue | None = None, + position_bias: TensorValue | None = None, + past_key_values: TensorValue | None = None, + layer_head_mask: TensorValue | None = None, + query_length: int | None = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: TensorValue | None = None, + ) -> tuple[TensorValue, TensorValue]: + """Process hidden states through the attention layer. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + mask: Attention mask. + key_value_states: Key-value states for cross-attention. + position_bias: Position bias tensor. + past_key_values: Past key values for caching (not implemented). + layer_head_mask: Mask for attention heads. + query_length: Length of the query sequence. + use_cache: Whether to use cache (not implemented). + output_attentions: Whether to return attention weights. + cache_position: Cache position. + + Returns: + Tuple[TensorValue, TensorValue]: Output tensor and position bias. + """ + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + if is_cross_attention: + raise NotImplementedError( + "T5 CrossAttention is not implemented yet." + ) + if past_key_values is not None: + raise NotImplementedError( + "T5 auto regressive model is not implemented yet." + ) + + query = self.q(hidden_states) + key = self.k(hidden_states) + value = self.v(hidden_states) + + # Reshape to (batch, seq, heads, head_dim) + query = ops.reshape( + query, + (batch_size, seq_length, self.n_heads, self.key_value_proj_dim), + ) + key = ops.reshape( + key, (batch_size, seq_length, self.n_heads, self.key_value_proj_dim) + ) + value = ops.reshape( + value, + (batch_size, seq_length, self.n_heads, self.key_value_proj_dim), + ) + + # Transpose to (batch, heads, seq, head_dim) + query = ops.permute(query, (0, 2, 1, 3)) + key = ops.permute(key, (0, 2, 1, 3)) + value = ops.permute(value, (0, 2, 1, 3)) + + scores = ops.matmul(query, ops.permute(key, (0, 1, 3, 2))) + + if position_bias is None and self.has_relative_attention_bias: + position_bias = self.compute_bias(seq_length, seq_length) + + if position_bias is not None: + scores = scores + position_bias + + if mask is not None: + scores = scores + mask + + attn_weights = ops.softmax(ops.cast(scores, DType.float32), axis=-1) + attn_weights = ops.cast(attn_weights, self.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = ops.matmul(attn_weights, value) + attn_output = ops.permute(attn_output, (0, 2, 1, 3)) + attn_output = ops.reshape( + attn_output, (batch_size, seq_length, self.inner_dim) + ) + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(Module): + def __init__( + self, + config: T5Config, + has_relative_attention_bias: bool = False, + layer_idx: int | None = None, + ): + """Construct a self-attention layer. + + Args: + config: T5 configuration. + has_relative_attention_bias: Whether to use relative attention bias. + layer_idx: Index of the layer. + """ + super().__init__() + self.SelfAttention = T5Attention( + config, + has_relative_attention_bias=has_relative_attention_bias, + layer_idx=layer_idx, + ) + self.layer_norm = T5LayerNorm( + config.d_model, + eps=config.layer_norm_epsilon, + device=config.device, + dtype=config.dtype, + ) + + def __call__( + self, + hidden_states: TensorValue, + attention_mask: TensorValue | None = None, + position_bias: TensorValue | None = None, + layer_head_mask: TensorValue | None = None, + past_key_values: TensorValue | None = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: TensorValue | None = None, + ) -> TensorValue: + """Process hidden states through the self-attention layer. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + attention_mask: Attention mask. + position_bias: Position bias tensor. + layer_head_mask: Mask for attention heads. + past_key_values: Past key values for caching (not implemented). + use_cache: Whether to use cache (not implemented). + output_attentions: Whether to return attention weights. + cache_position: Cache position. + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + attention_output[0] + outputs = (hidden_states,) + attention_output[1:] + return outputs + + +class T5Block(Module): + def __init__( + self, + config: T5Config, + has_relative_attention_bias: bool = False, + layer_idx: int | None = None, + ): + """Construct a T5 block. + + Args: + config: T5 configuration. + has_relative_attention_bias: Whether to use relative attention bias. + layer_idx: Index of the layer. + """ + super().__init__() + layers = list() + self.is_decoder = config.is_decoder + if self.is_decoder: + raise NotImplementedError( + "T5 LayerCrossAttention is not implemented yet." + ) + + layers.append( + T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + layer_idx=layer_idx, + ) + ) + layers.append(T5LayerFF(config)) + self.layer = nn.LayerList(layers) + + def __call__( + self, + hidden_states: TensorValue, + attention_mask: TensorValue | None = None, + position_bias: TensorValue | None = None, + encoder_hidden_states: TensorValue | None = None, + encoder_attention_mask: TensorValue | None = None, + encoder_decoder_position_bias: TensorValue | None = None, + cross_attn_layer_head_mask: TensorValue | None = None, + layer_head_mask: TensorValue | None = None, + past_key_values: TensorValue | None = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: TensorValue | None = None, + ) -> tuple[TensorValue, TensorValue]: + """Process hidden states through the T5 block. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + attention_mask: Attention mask. + position_bias: Position bias tensor. + encoder_hidden_states: Encoder hidden states (not implemented). + encoder_attention_mask: Encoder attention mask (not implemented). + encoder_decoder_position_bias: Encoder-decoder position bias (not implemented). + cross_attn_layer_head_mask: Cross attention layer head mask (not implemented). + layer_head_mask: Mask for attention heads. + past_key_values: Past key values for caching (not implemented). + use_cache: Whether to use cache (not implemented). + output_attentions: Whether to return attention weights. + cache_position: Cache position. + + Returns: + Tuple[TensorValue, TensorValue]: Output tensor and position bias. + """ + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] + + if hidden_states.dtype == DType.float16: + clamp_value = DType.finfo(hidden_states.dtype).max - 1000 + hidden_states = nn.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + do_cross_attention = ( + self.is_decoder and encoder_hidden_states is not None + ) + if do_cross_attention: + raise NotImplementedError( + "T5 CrossAttention is not implemented yet." + ) + + hidden_states = self.layer[-1](hidden_states) + if hidden_states.dtype == DType.float16: + clamp_value = DType.finfo(hidden_states.dtype).max - 1000 + hidden_states = nn.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + outputs = (hidden_states,) + return outputs + attention_outputs + + +class T5Stack(Module): + def __init__( + self, + config: T5Config, + embed_tokens: nn.Embedding | None = None, + ): + """Construct a T5 stack. + + Args: + config: T5 configuration. + embed_tokens: Embedding module. + """ + super().__init__() + self.config = config + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.LayerList( + [ + T5Block( + config, + has_relative_attention_bias=bool(i == 0), + layer_idx=i, + ) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm( + config.d_model, + eps=config.layer_norm_epsilon, + device=config.device, + dtype=config.dtype, + ) + self.dropout = config.dropout_rate + self.device = config.device + self.dtype = config.dtype + + def __call__( + self, + input_ids: TensorValue | None = None, + attention_mask: TensorValue | None = None, + inputs_embeds: TensorValue | None = None, + encoder_hidden_states: TensorValue | None = None, + encoder_attention_mask: TensorValue | None = None, + encoder_decoder_position_bias: TensorValue | None = None, + cross_attn_layer_head_mask: TensorValue | None = None, + layer_head_mask: TensorValue | None = None, + past_key_values: TensorValue | None = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: TensorValue | None = None, + ) -> TensorValue: + """Process input through the T5 stack. + + Args: + input_ids: Input IDs tensor of shape (batch_size, seq_length). + attention_mask: Attention mask tensor of shape (batch_size, seq_length). + inputs_embeds: Input embeddings tensor of shape (batch_size, seq_length, hidden_size). + encoder_hidden_states: Encoder hidden states (not implemented). + encoder_attention_mask: Encoder attention mask (not implemented). + encoder_decoder_position_bias: Encoder-decoder position bias (not implemented). + cross_attn_layer_head_mask: Cross attention layer head mask (not implemented). + layer_head_mask: Mask for attention heads. + past_key_values: Past key values for caching (not implemented). + use_cache: Whether to use cache (not implemented). + output_attentions: Whether to return attention weights. + cache_position: Cache position. + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + if input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids) + elif inputs_embeds is None: + raise ValueError( + "You have to specify either input_ids or inputs_embeds" + ) + + if self.is_decoder or use_cache: + raise NotImplementedError("T5 decoder is not implemented yet.") + + hidden_states = inputs_embeds + + if attention_mask is not None: + causal_mask = ( + 1.0 - ops.cast(attention_mask, hidden_states.dtype) + ) * DType.finfo(hidden_states.dtype).min + causal_mask = ops.unsqueeze(causal_mask, 1) + causal_mask = ops.unsqueeze(causal_mask, 1) + else: + causal_mask = None + encoder_extended_attention_mask = None + + position_bias = None + for layer_module in self.block: + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + position_bias = layer_outputs[1] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class T5EncoderModel(Module): + def __init__( + self, + config: T5Config, + ): + """Construct a T5 encoder model. + + Args: + config: T5 configuration for vocabulary size, layer counts, and + device/dtype settings. + """ + super().__init__() + act_info = config.feed_forward_proj.split("-") + config.dense_act_fn = act_info[-1] + config.is_gated_act = act_info[0] == "gated" + + self.shared = nn.Embedding( + config.vocab_size, + config.d_model, + device=config.device, + dtype=config.dtype, + ) + + encoder_config = config + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + + self.encoder = T5Stack(encoder_config, self.shared) + self.device = config.device + self.dtype = config.dtype + + def input_types(self) -> tuple[TensorType, ...]: + """Get input types for the model. + + Returns: + tuple[TensorType, ...]: Input types. + """ + return ( + TensorType( + DType.int64, + shape=["batch_size", "sequence_length"], + device=self.device, + ), + ) + + def __call__( + self, + input_ids: TensorValue | None = None, + attention_mask: TensorValue | None = None, + ) -> TensorValue: + """Process input through the T5 encoder model. + + Args: + input_ids: Input IDs tensor of shape (batch_size, seq_length). + attention_mask: Attention mask tensor of shape (batch_size, seq_length). + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + return self.encoder(input_ids=input_ids, attention_mask=attention_mask) diff --git a/max/python/max/pipelines/lib/interfaces/max_model.py b/max/python/max/pipelines/lib/interfaces/max_model.py new file mode 100644 index 00000000000..3f323d81ad8 --- /dev/null +++ b/max/python/max/pipelines/lib/interfaces/max_model.py @@ -0,0 +1,45 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from max.driver import Device +from max.engine import Model +from max.graph.weights import Weights + +if TYPE_CHECKING: + from max.pipelines.lib import SupportedEncoding + + +class MaxModel(ABC): + """Base interface for pipeline models with weight-backed execution.""" + + def __init__( + self, + config: dict, + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + ) -> None: + self.config = config + self.encoding = encoding + self.devices = devices + self.weights = weights + + @abstractmethod + def load_model(self) -> Model: + """Load and return a runtime model instance.""" + ... From 3a9e862df66919819c382db37ebdd30377cef4b7 Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Fri, 16 Jan 2026 09:44:23 +0000 Subject: [PATCH 05/18] feat: add pipeline definition for flux1 --- max/python/max/config/__init__.py | 22 +- max/python/max/experimental/BUILD.bazel | 1 + max/python/max/experimental/compile_utils.py | 97 ++ .../architectures/flux1/pipeline_flux.py | 788 ++++++++++++++++ max/python/max/pipelines/lib/config.py | 20 +- .../lib/diffusion_schedulers/__init__.py | 16 + .../scheduling_flow_match_euler_discrete.py | 852 ++++++++++++++++++ max/python/max/pipelines/lib/hf_utils.py | 49 +- .../max/pipelines/lib/image_processor.py | 226 +++++ .../max/pipelines/lib/interfaces/__init__.py | 2 + .../lib/interfaces/diffusion_pipeline.py | 177 ++++ max/python/max/pipelines/lib/model_config.py | 6 +- 12 files changed, 2234 insertions(+), 22 deletions(-) create mode 100644 max/python/max/experimental/compile_utils.py create mode 100644 max/python/max/pipelines/architectures/flux1/pipeline_flux.py create mode 100644 max/python/max/pipelines/lib/diffusion_schedulers/__init__.py create mode 100644 max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py create mode 100644 max/python/max/pipelines/lib/image_processor.py create mode 100644 max/python/max/pipelines/lib/interfaces/diffusion_pipeline.py diff --git a/max/python/max/config/__init__.py b/max/python/max/config/__init__.py index 4b10a5e89d1..440822a3478 100644 --- a/max/python/max/config/__init__.py +++ b/max/python/max/config/__init__.py @@ -16,7 +16,9 @@ import argparse import enum +import json import logging +import os import types from abc import abstractmethod from collections.abc import Mapping @@ -506,6 +508,7 @@ def _extract_max_config_data( config_dict: The loaded YAML configuration dictionary. config_class: The config class we're extracting data for. section_name: Optional specific section name to look for. + config_file_path: Path to the config file for resolving inheritance. Returns: Configuration data for the specific config class. @@ -854,9 +857,9 @@ def _add_field_as_argument( ): # For enums, use the string value as default but we'll need to convert back arg_kwargs = { - "default": field_value.value - if field_value - else field_obj.default + "default": ( + field_value.value if field_value else field_obj.default + ) } else: arg_kwargs = {"default": field_value} @@ -1071,6 +1074,19 @@ def parse_args( # type: ignore[override] # noqa: ANN202 return MAXConfigArgumentParser(parser, self) +def load_config(config_path: str | os.PathLike) -> dict: + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + try: + with open(config_path, encoding="utf-8") as f: + config_dict = json.loads(f.read()) + except Exception as e: + raise ValueError( + f"Failed to load configuration from {config_path}: {e}" + ) from e + return config_dict + + all = [ "MAXBaseModel", "ConfigFileModel", diff --git a/max/python/max/experimental/BUILD.bazel b/max/python/max/experimental/BUILD.bazel index 9c95184007c..946a73ca810 100644 --- a/max/python/max/experimental/BUILD.bazel +++ b/max/python/max/experimental/BUILD.bazel @@ -9,6 +9,7 @@ modular_py_library( "_passes.py", "_tensor_repr.py", "functional.py", + "compile_utils.py", "random.py", "realization_context.py", "support.py", diff --git a/max/python/max/experimental/compile_utils.py b/max/python/max/experimental/compile_utils.py new file mode 100644 index 00000000000..13c10b93d83 --- /dev/null +++ b/max/python/max/experimental/compile_utils.py @@ -0,0 +1,97 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from collections.abc import Callable, Iterable +from typing import Any + +from max.driver import CPU, Accelerator +from max.engine import InferenceSession +from max.graph import Graph, TensorType +from max.nn.module_v3 import Module + + +class CompileWrapper: + def __init__( + self, + compile_target: Callable | Module, + input_types: Iterable[TensorType] | None = None, + ) -> None: + """Initialize the CompileWrapper. + + Args: + compile_target: The function or module to be compiled. + input_types: A list of input types (TensorTypes) required for compilation. + + Raises: + ValueError: If input_types is not provided. + """ + if input_types is None: + raise ValueError( + f"input_types must be provided for compilation of {compile_target.__name__}." + ) + + self.is_module = False + if isinstance(compile_target, Module): + self.is_module = True + self.session = compile_target.compile(input_types) + return + + with Graph(compile_target.__name__, input_types=input_types) as graph: + output = compile_target(*graph.inputs) + graph.output(output) + compiled_graph = graph + + if any(input_type.device.is_gpu() for input_type in input_types): + device = Accelerator() + else: + device = CPU() + session = InferenceSession([device]) + loaded_session = session.load(compiled_graph) + self.session = loaded_session + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Execute the compiled session with the given arguments. + + Args: + *args: Positional arguments to pass to the session. + **kwargs: Keyword arguments to pass to the session. + + Returns: + The result of the session execution. + """ + if self.is_module: + return self.session(*args, **kwargs) + return self.session.execute(*args, **kwargs) + + +def max_compile( + compile_target: Callable | Module | None = None, + input_types: Iterable[TensorType] | None = None, +) -> Callable[[Callable | Module], CompileWrapper] | CompileWrapper: + """Decorator or function to compile a target with specified input types. + + Args: + compile_target: The function or module to compile. If None, returns a decorator. + input_types: The input types for the compilation. + + Returns: + A CompileWrapper instance if compile_target is provided, otherwise a decorator. + """ + if compile_target is None: + + def decorator(f: Callable | Module) -> CompileWrapper: + return CompileWrapper(f, input_types) + + return decorator + + return CompileWrapper(compile_target, input_types) diff --git a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py new file mode 100644 index 00000000000..deacff48a87 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py @@ -0,0 +1,788 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import inspect +import os +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import numpy as np +import PIL.Image +from max.driver import Tensor +from max.dtype import DType +from max.experimental import Tensor as Tensor_v3 +from max.experimental import functional as F +from max.experimental import random +from max.graph import DeviceRef +from max.pipelines.lib.diffusion_schedulers import ( + FlowMatchEulerDiscreteScheduler, +) +from max.pipelines.lib.image_processor import ( + PipelineImageInput, + VaeImageProcessor, +) +from max.pipelines.lib.interfaces.diffusion_pipeline import ( + DiffusionPipeline, +) +from tqdm import tqdm +from transformers import ( + CLIPTokenizer, + T5TokenizerFast, +) + +from ..autoencoder_kl import AutoencoderKLModel +from ..clip import ClipModel +from ..t5 import T5Model +from .model import Flux1Model + + +def retrieve_timesteps( + scheduler: Any, + num_inference_steps: int | None = None, + device: str | DeviceRef | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs: Any, +) -> tuple[np.ndarray, int]: + r"""Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. + + Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `DeviceRef`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + **kwargs (`Any`, *optional*): + Additional arguments to pass to the scheduler's `set_timesteps` method. + + Returns: + `tuple[Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = int(timesteps.shape[0]) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = int(timesteps.shape[0]) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def calculate_shift( + image_seq_len: int, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +) -> float: + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +@dataclass +class FluxPipelineOutput: + """Output class for Flux image generation pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray` or `Tensor`) + List of denoised PIL images of length `batch_size` or numpy array or Max tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Max tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: list[PIL.Image.Image] | np.ndarray | Tensor + + +class FluxPipeline(DiffusionPipeline): + config_name = "model_index.json" + + components = { + "scheduler": FlowMatchEulerDiscreteScheduler, + "vae": AutoencoderKLModel, + "text_encoder": ClipModel, + "tokenizer": CLIPTokenizer, + "text_encoder_2": T5Model, + "tokenizer_2": T5TokenizerFast, + "transformer": Flux1Model, + } + + def init_remaining_components(self) -> None: + image_processor_class = self.components.get( + "image_processor", VaeImageProcessor + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) + if getattr(self, "vae", None) + else 8 + ) + image_processor = image_processor_class( + vae_scale_factor=self.vae_scale_factor * 2 + ) + self.image_processor = image_processor + + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + device: DeviceRef | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: Tensor | None = None, + pooled_prompt_embeds: Tensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ) -> tuple[Tensor, Tensor, Tensor]: + r"""Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`DeviceRef`): + Max device + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + prompt_embeds (`Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + if lora_scale is not None and isinstance(self, FluxPipeline): + self._lora_scale = lora_scale + + if self.text_encoder is not None and hasattr( + self.text_encoder, "set_lora_scale" + ): + self.text_encoder.set_lora_scale(lora_scale) + if self.text_encoder_2 is not None and hasattr( + self.text_encoder_2, "set_lora_scale" + ): + self.text_encoder_2.set_lora_scale(lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=min( + max_sequence_length, self.tokenizer.model_max_length + ), + truncation=True, + return_length=False, + return_overflowing_tokens=False, + ) + text_input_ids = Tensor_v3.constant( + text_inputs.input_ids, device=device, dtype=DType.int64 + ) + + text_encoder_outputs = self.text_encoder(text_input_ids) + prompt_embeds = text_encoder_outputs[0] + pooled_prompt_embeds = text_encoder_outputs[1] + + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + if self.text_encoder_2 is not None: + text_inputs_2 = self.tokenizer_2( + prompt_2, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + ) + text_input_ids_2 = Tensor_v3.constant( + text_inputs_2.input_ids, device=device, dtype=DType.int64 + ) + + prompt_embeds_2 = self.text_encoder_2(text_input_ids_2)[0] + else: + prompt_embeds_2 = None + + if prompt_embeds_2 is not None: + prompt_embeds = prompt_embeds_2 + + text_ids = Tensor_v3.zeros( + (prompt_embeds.shape[1], 3), + device=device, + dtype=prompt_embeds.dtype, + ) + + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = Tensor_v3.from_dlpack( + prompt_embeds + ) # V2 Tensor to V3 Tensor + pooled_prompt_embeds = Tensor_v3.from_dlpack( + pooled_prompt_embeds + ) # V2 Tensor to V3 Tensor + + prompt_embeds = F.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.reshape( + (bs_embed * num_images_per_prompt, seq_len, -1) + ) + + pooled_prompt_embeds = F.tile( + pooled_prompt_embeds, (1, num_images_per_prompt) + ) + pooled_prompt_embeds = pooled_prompt_embeds.reshape( + (bs_embed * num_images_per_prompt, -1) + ) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + @staticmethod + def _prepare_latent_image_ids( + batch_size: int, + height: int, + width: int, + device: DeviceRef, + dtype: DType, + ) -> Tensor_v3: + latent_image_ids = np.stack( + [ + np.zeros((height, width)), + np.broadcast_to(np.arange(height)[:, None], (height, width)), + np.broadcast_to(np.arange(width)[None, :], (height, width)), + ], + axis=-1, + ) + + ( + latent_image_id_height, + latent_image_id_width, + latent_image_id_channels, + ) = latent_image_ids.shape + + latent_image_ids = np.reshape( + latent_image_ids, + ( + latent_image_id_height * latent_image_id_width, + latent_image_id_channels, + ), + ) + latent_image_ids = ( + Tensor_v3.from_dlpack(latent_image_ids).to(device).cast(dtype) + ) + + return latent_image_ids + + @staticmethod + def _pack_latents( + latents: Tensor_v3, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + ) -> Tensor_v3: + latents = F.reshape( + latents, + (batch_size, num_channels_latents, height // 2, 2, width // 2, 2), + ) + latents = F.permute(latents, (0, 2, 4, 1, 3, 5)) + latents = F.reshape( + latents, + ( + batch_size, + (height // 2) * (width // 2), + num_channels_latents * 4, + ), + ) + + return latents + + @staticmethod + def _unpack_latents( + latents: Tensor_v3, + height: int, + width: int, + vae_scale_factor: int, + ) -> Tensor_v3: + # TODO: should compile this function for speed up. + batch_size, _, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (height // (vae_scale_factor * 2)) + width = 2 * (width // (vae_scale_factor * 2)) + + latents = F.reshape( + latents, + (batch_size.dim, height // 2, width // 2, channels.dim // 4, 2, 2), + ) + latents = F.permute(latents, (0, 3, 1, 4, 2, 5)) + + latents = F.reshape( + latents, (batch_size.dim, channels.dim // (2 * 2), height, width) + ) + + return latents + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: DType, + device: DeviceRef, + latents: Tensor_v3 | None = None, + ) -> tuple[Tensor_v3, Tensor_v3]: + """Prepare latents for the Flux pipeline. + + Args: + batch_size: The number of images to generate. + num_channels_latents: The number of latent channels. + height: The height of the generated image. + width: The width of the generated image. + dtype: The data type for the latents. + device: The device to run on. + latents: Pre-generated latents. + + Returns: + Tuple of latents and latent image ids. + """ + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids( + batch_size, height // 2, width // 2, device, dtype + ) + return latents.to(device).cast(dtype), latent_image_ids + + # NOTE: Max random generation uses different seed with torch.randn. + # So, we currently leave torch randn as optional for + # functionality comparison with the original diffusers pipeline. + if os.environ.get("USE_TORCH_RANDN", "0") == "1": + import torch + + seed = int(os.environ.get("SEED", 42)) + generator = torch.Generator(device="cuda").manual_seed(seed) + latents = torch.randn( + shape, + generator=generator, + device="cuda", + dtype=dtype.to_torch(), + ) + latents = Tensor_v3.from_dlpack(latents) + else: + latents = random.normal(shape, device=device, dtype=dtype) + latents = self._pack_latents( + latents, batch_size, num_channels_latents, height, width + ) + + latent_image_ids = self._prepare_latent_image_ids( + batch_size, height // 2, width // 2, device, dtype + ) + + return latents, latent_image_ids + + def __call__( + self, + prompt: str | list[str] | None = None, + prompt_2: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + true_cfg_scale: float = 1.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 3.5, + num_images_per_prompt: int | None = 1, + latents: Tensor | None = None, + prompt_embeds: Tensor | None = None, + pooled_prompt_embeds: Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[Tensor] | None = None, + negative_ip_adapter_image: PipelineImageInput | None = None, + negative_ip_adapter_image_embeds: list[Tensor] | None = None, + negative_prompt_embeds: Tensor | None = None, + negative_pooled_prompt_embeds: Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + max_sequence_length: int = 512, + ): + r"""Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + latents (`Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device() + + lora_scale = ( + self._joint_attention_kwargs.get("scale", None) + if self._joint_attention_kwargs is not None + else None + ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None + and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + latents, + ) + + # 5. Prepare timesteps + sigmas = ( + np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if sigmas is None + else sigmas + ) + if ( + hasattr(self.scheduler, "use_flow_sigmas") + and self.scheduler.use_flow_sigmas + ): + sigmas = None + image_seq_len = latents.shape[1].dim + mu = calculate_shift( + image_seq_len, + self.scheduler.base_image_seq_len, + self.scheduler.max_image_seq_len, + self.scheduler.base_shift, + self.scheduler.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + self._num_timesteps = timesteps.shape[0] + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = Tensor_v3.full( + [latents.shape[0].dim], + guidance_scale, + device=device, + dtype=prompt_embeds.dtype, + ) + else: + guidance = Tensor_v3.zeros( + [latents.shape[0].dim], + device=device, + dtype=prompt_embeds.dtype, + ) + + if ( + ip_adapter_image is not None + or ip_adapter_image_embeds is not None + or negative_ip_adapter_image is not None + or negative_ip_adapter_image_embeds is not None + ): + raise NotImplementedError( + "IP adapter is not supported for Max yet." + ) + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + batch_size = latents.shape[0].dim + for i in tqdm(range(self._num_timesteps), desc="Denoising"): + if self._interrupt: + continue + + t = timesteps[i] + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = ( + image_embeds + ) + + # NOTE: Convert timesteps to a Max Tensor before denoising loop, + # as in the original implementation, results in a significant slow down. + # As a workaround, we keep timesteps as a numpy array and convert it + # to a Max Tensor here. This might require a more efficient way to handle this. + # Converting to a Max module V3 Tensor also results in a significant slow down. + timestep = np.full((batch_size,), t) / 1000.0 + timestep = Tensor.from_dlpack(timestep).to(prompt_embeds.device) + + noise_pred = self.transformer( + latents, + prompt_embeds, + pooled_prompt_embeds, + timestep, + latent_image_ids, + text_ids, + guidance, + )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = ( + negative_image_embeds + ) + + neg_noise_pred = self.transformer( + latents, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + timestep, + latent_image_ids, + negative_text_ids, + guidance, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * ( + noise_pred - neg_noise_pred + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False + )[0] + + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = Tensor_v3.from_dlpack(latents) # V2 Tensor to V3 Tensor + latents = self._unpack_latents( + latents, height, width, self.vae_scale_factor + ) + latents = ( + latents / self.vae.config.scaling_factor + ) + self.vae.config.shift_factor + image = self.vae.decode(latents)[0] + + image = Tensor_v3.from_dlpack(image) # V2 Tensor to V3 Tensor + image = self.image_processor.postprocess( + image, output_type=output_type + ) + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/max/python/max/pipelines/lib/config.py b/max/python/max/pipelines/lib/config.py index d3150c90728..e1d0a3e1410 100644 --- a/max/python/max/pipelines/lib/config.py +++ b/max/python/max/pipelines/lib/config.py @@ -28,6 +28,7 @@ from max.driver import DeviceSpec, load_devices from max.engine import InferenceSession from max.graph.quantization import QuantizationEncoding +from max.interfaces import PipelineTask from max.serve.queue.zmq_queue import generate_zmq_ipc_path from pydantic import ( Field, @@ -910,6 +911,11 @@ def _validate_and_resolve_remaining_pipeline_config( # memory estimations. arch.pipeline_model.finalize_pipeline_config(self) + if arch.task == PipelineTask.IMAGE_GENERATION: + # diffusion pipeline does not use KV cache, + # so we can skip profile run. + return + MemoryEstimator.estimate_memory_footprint( self, arch.pipeline_model, @@ -1140,12 +1146,14 @@ def log_basic_config(self) -> None: pipeline_class = get_pipeline_for_task(task, self) # Get reserved memory info from KVCache config - kv_config = self.model._kv_cache - if kv_config._available_cache_memory is None: - raise ValueError( - "KVCache config is not available after config resolution." - ) - memory_str = to_human_readable_bytes(kv_config._available_cache_memory) + memory_str = None + if task != PipelineTask.IMAGE_GENERATION: + kv_config = self.model._kv_cache + if kv_config._available_cache_memory is None: + raise ValueError( + "KVCache config is not available after config resolution." + ) + memory_str = to_human_readable_bytes(kv_config._available_cache_memory) devices_str = ", ".join( f"{d.device_type}[{d.id}]" for d in self.model.device_specs diff --git a/max/python/max/pipelines/lib/diffusion_schedulers/__init__.py b/max/python/max/pipelines/lib/diffusion_schedulers/__init__.py new file mode 100644 index 00000000000..df278d92447 --- /dev/null +++ b/max/python/max/pipelines/lib/diffusion_schedulers/__init__.py @@ -0,0 +1,16 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) diff --git a/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py new file mode 100644 index 00000000000..1eb39c90b33 --- /dev/null +++ b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py @@ -0,0 +1,852 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import logging +import math +from dataclasses import dataclass + +import numpy as np +from max.driver import CPU, Accelerator, Device +from max.dtype import DType +from max.engine import InferenceSession +from max.experimental import Tensor, random +from max.graph import DeviceRef, Graph, TensorType + +try: + import scipy.stats + + is_scipy_available = True +except ImportError: + is_scipy_available = False + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class FlowMatchEulerDiscreteSchedulerOutput: + """Output class for the scheduler's `step` function output. + + Args: + prev_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: Tensor + + +class FlowMatchEulerDiscreteScheduler: + """Euler scheduler. + + Native Modular implementation (ported from diffusers). + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + use_dynamic_shifting (`bool`, defaults to False): + Whether to apply timestep shifting on-the-fly based on the image resolution. + base_shift (`float`, defaults to 0.5): + Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent + with desired output. + max_shift (`float`, defaults to 1.15): + Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be + more exaggerated or stylized. + base_image_seq_len (`int`, defaults to 256): + The base image sequence length. + max_image_seq_len (`int`, defaults to 4096): + The maximum image sequence length. + invert_sigmas (`bool`, defaults to False): + Whether to invert the sigmas. + shift_terminal (`float`, defaults to None): + The end value of the shifted timestep schedule. + use_karras_sigmas (`bool`, defaults to False): + Whether to use Karras sigmas for step sizes in the noise schedule during sampling. + use_exponential_sigmas (`bool`, defaults to False): + Whether to use exponential sigmas for step sizes in the noise schedule during sampling. + use_beta_sigmas (`bool`, defaults to False): + Whether to use beta sigmas for step sizes in the noise schedule during sampling. + time_shift_type (`str`, defaults to "exponential"): + The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". + stochastic_sampling (`bool`, defaults to False): + Whether to use stochastic sampling. + """ + + config_name = "scheduler_config.json" + + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: float | None = 0.5, + max_shift: float | None = 1.15, + base_image_seq_len: int | None = 256, + max_image_seq_len: int | None = 4096, + invert_sigmas: bool = False, + shift_terminal: float | None = None, + use_karras_sigmas: bool | None = False, + use_exponential_sigmas: bool | None = False, + use_beta_sigmas: bool | None = False, + time_shift_type: str = "exponential", + stochastic_sampling: bool = False, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.float32, + **kwargs, + ): + """Initialize the scheduler. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + use_dynamic_shifting (`bool`, defaults to False): + Whether to apply timestep shifting on-the-fly based on the image resolution. + base_shift (`float`, defaults to 0.5): + Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent + with desired output. + max_shift (`float`, defaults to 1.15): + Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be + more exaggerated or stylized. + base_image_seq_len (`int`, defaults to 256): + The base image sequence length. + max_image_seq_len (`int`, defaults to 4096): + The maximum image sequence length. + invert_sigmas (`bool`, defaults to False): + Whether to invert the sigmas. + shift_terminal (`float`, defaults to None): + The end value of the shifted timestep schedule. + use_karras_sigmas (`bool`, defaults to False): + Whether to use Karras sigmas for step sizes in the noise schedule during sampling. + use_exponential_sigmas (`bool`, defaults to False): + Whether to use exponential sigmas for step sizes in the noise schedule during sampling. + use_beta_sigmas (`bool`, defaults to False): + Whether to use beta sigmas for step sizes in the noise schedule during sampling. + time_shift_type (`str`, defaults to "exponential"): + The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". + stochastic_sampling (`bool`, defaults to False): + Whether to use stochastic sampling. + device (`DeviceRef`, defaults to `DeviceRef.CPU()`): + The device to use. + dtype (`DType`, defaults to `DType.float32`): + The dtype to use. + """ + self.num_train_timesteps = num_train_timesteps + self._shift = shift + self.use_dynamic_shifting = use_dynamic_shifting + self.base_shift = base_shift + self.max_shift = max_shift + self.base_image_seq_len = base_image_seq_len + self.max_image_seq_len = max_image_seq_len + self.invert_sigmas = invert_sigmas + self.shift_terminal = shift_terminal + self.use_karras_sigmas = use_karras_sigmas + self.use_exponential_sigmas = use_exponential_sigmas + self.use_beta_sigmas = use_beta_sigmas + self.time_shift_type = time_shift_type + self.stochastic_sampling = stochastic_sampling + self.device = device + self.dtype = dtype + + if self.use_beta_sigmas and not is_scipy_available: + raise ImportError( + "Make sure to install scipy if you want to use beta sigmas." + ) + if ( + sum( + [ + self.use_beta_sigmas, + self.use_exponential_sigmas, + self.use_karras_sigmas, + ] + ) + > 1 + ): + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + if time_shift_type not in {"exponential", "linear"}: + raise ValueError( + "`time_shift_type` must either be 'exponential' or 'linear'." + ) + + timesteps = np.linspace( + 1, num_train_timesteps, num_train_timesteps, dtype=np.float32 + )[::-1].copy() + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self._shift = shift + + self.sigmas = sigmas + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + self.load_model() + + @property + def shift(self) -> float: + """The value used for shifting.""" + return self._shift + + @property + def step_index(self) -> int: + """The index counter for current timestep. It will increase 1 after each scheduler step.""" + return self._step_index + + @property + def begin_index(self) -> int: + """The index for the first timestep. It should be set from pipeline with `set_begin_index` method.""" + return self._begin_index + + def set_begin_index(self, begin_index: int = 0) -> None: + """Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`, defaults to `0`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_shift(self, shift: float) -> None: + """Set the shift value.""" + self._shift = shift + + def scale_noise( + self, + sample: Tensor, + timestep: float | Tensor, + noise: Tensor | None = None, + ) -> Tensor: + """Forward process in flow-matching. + + Args: + sample (`Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + noise (`Tensor`, *optional*): + The noise tensor. + + Returns: + `Tensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device).cast(sample.dtype) + + if sample.device.type == "mps": + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device).cast( + DType.float32 + ) + timestep = timestep.to(sample.device).cast(DType.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) for t in timestep + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma: float) -> float: + """Converts sigma to timestep.""" + return sigma * self.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: Tensor) -> Tensor: + """Apply time shifting to the timesteps. + + Args: + mu (`float`): + The mu parameter for time shifting. + sigma (`float`): + The sigma parameter for time shifting. + t (`Tensor`): + The timesteps to shift. + + Returns: + `Tensor`: + The shifted timesteps. + """ + if self.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + + def stretch_shift_to_terminal(self, t: Tensor) -> Tensor: + r"""Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal`. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | Device | None = None, + sigmas: list[float] | None = None, + mu: float | None = None, + timesteps: list[float] | None = None, + ) -> None: + """Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `Device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. + mu (`float`, *optional*): + Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep + shifting. + timesteps (`List[float]`, *optional*): + Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed + automatically. + """ + if self.use_dynamic_shifting and mu is None: + raise ValueError( + "`mu` must be passed when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is not None and timesteps is not None: + if len(sigmas) != len(timesteps): + raise ValueError( + "`sigmas` and `timesteps` should have the same length" + ) + + if num_inference_steps is not None: + if (sigmas is not None and len(sigmas) != num_inference_steps) or ( + timesteps is not None and len(timesteps) != num_inference_steps + ): + raise ValueError( + "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided" + ) + else: + num_inference_steps = ( + len(sigmas) if sigmas is not None else len(timesteps) + ) + + self.num_inference_steps = num_inference_steps + + # 1. Prepare default sigmas + is_timesteps_provided = timesteps is not None + + if is_timesteps_provided: + timesteps = np.array(timesteps).astype(np.float32) + + if sigmas is None: + if timesteps is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), + self._sigma_to_t(self.sigma_min), + num_inference_steps, + ) + sigmas = timesteps / self.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) + + # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of + # "exponential" or "linear" type is applied + if self.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value + if self.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules + if self.use_karras_sigmas: + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + elif self.use_exponential_sigmas: + sigmas = self._convert_to_exponential( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + elif self.use_beta_sigmas: + sigmas = self._convert_to_beta( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + + if not is_timesteps_provided: + timesteps = sigmas * self.num_train_timesteps + + # 5. Append the terminal sigma value. + # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the + # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi + if self.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.num_train_timesteps + sigmas = np.concatenate([sigmas, np.ones((1,), dtype=sigmas.dtype)]) + else: + sigmas = np.concatenate( + [ + sigmas, + np.zeros((1,), dtype=sigmas.dtype), + ] + ) + + # 6. Convert sigmas and timesteps to tensors and move to specified device + sigmas = ( + Tensor.from_dlpack(sigmas).to(device=device).cast(DType.float32) + ) + + self.timesteps = timesteps + self.sigmas = sigmas + self._step_index = None + self._begin_index = None + + def index_for_timestep( + self, timestep: Tensor, schedule_timesteps: Tensor | None = None + ) -> int: + """Returns the index for a given timestep. + + Args: + timestep (`Tensor`): + The timestep to find the index for. + schedule_timesteps (`Tensor`, *optional*): + The schedule timesteps to search in. If `None`, defaults to `self.timesteps`. + + Returns: + `int`: + The index of the timestep. + """ + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep: Tensor) -> None: + """Initialize the step index based on the given timestep. + + Args: + timestep (`Tensor`): + The current timestep. + """ + if self.begin_index is None: + if isinstance(timestep, Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def _step( + self, + model_output: Tensor, + timestep: float | Tensor, + sample: Tensor, + sigmas: Tensor | None = None, + step_index: Tensor | None = None, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + per_token_timesteps: Tensor | None = None, + return_dict: bool = True, + ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple: + """Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`Tensor`): + The direct output from learned diffusion model. + timestep (`float` or `Tensor`): + The current discrete timestep in the diffusion chain. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + sigmas (`Tensor`, *optional*): + The sigmas tensor. + step_index (`Tensor`, *optional*): + The step index. + s_churn (`float`): + Churn parameter. + s_tmin (`float`): + Min churn timestep. + s_tmax (`float`): + Max churn timestep. + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + per_token_timesteps (`Tensor`, *optional*): + The timesteps for each token in the sample. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.cast(DType.float32) + + if per_token_timesteps is not None: + per_token_sigmas = per_token_timesteps / self.num_train_timesteps + + sigmas = sigmas[:, None, None] + lower_mask = sigmas < per_token_sigmas[None] - 1e-6 + lower_sigmas = lower_mask * sigmas + lower_sigmas, _ = lower_sigmas.max(axis=0) + + current_sigma = per_token_sigmas[..., None] + next_sigma = lower_sigmas[..., None] + dt = current_sigma - next_sigma + else: + sigma = sigmas[step_index] + sigma_next = sigmas[step_index + 1] + + current_sigma = sigma + next_sigma = sigma_next + dt = sigma_next - sigma + + if self.stochastic_sampling: + x0 = sample - current_sigma * model_output + noise = random.normal(sample) + prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise + else: + prev_sample = sample + dt * model_output + + # upon completion increase step index by one + self._step_index += 1 + if per_token_timesteps is None: + # Cast sample back to model compatible dtype + prev_sample = prev_sample.cast(model_output.dtype) + + if not return_dict: + return prev_sample + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def _convert_to_karras( + self, in_sigmas: Tensor, num_inference_steps: int + ) -> Tensor: + """Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `Tensor`: + The converted sigma values following the Karras noise schedule. + """ + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self, "sigma_min"): + sigma_min = self.sigma_min + else: + sigma_min = None + + if hasattr(self, "sigma_max"): + sigma_max = self.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_exponential( + self, in_sigmas: Tensor, num_inference_steps: int + ) -> Tensor: + """Construct an exponential noise schedule. + + Args: + in_sigmas (`Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `Tensor`: + The converted sigma values following an exponential schedule. + """ + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self, "sigma_min"): + sigma_min = self.sigma_min + else: + sigma_min = None + + if hasattr(self, "sigma_max"): + sigma_max = self.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp( + np.linspace( + math.log(sigma_max), math.log(sigma_min), num_inference_steps + ) + ) + return sigmas + + def _convert_to_beta( + self, + in_sigmas: Tensor, + num_inference_steps: int, + alpha: float = 0.6, + beta: float = 0.6, + ) -> Tensor: + """Construct a beta noise schedule as proposed in [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `Tensor`: + The converted sigma values following a beta distribution schedule. + """ + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self, "sigma_min"): + sigma_min = self.sigma_min + else: + sigma_min = None + + if hasattr(self, "sigma_max"): + sigma_max = self.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def _time_shift_exponential( + self, mu: float, sigma: float, t: Tensor + ) -> Tensor: + """Apply exponential time shifting. + + Args: + mu (`float`): + The mu parameter. + sigma (`float`): + The sigma parameter. + t (`Tensor`): + The timesteps. + + Returns: + `Tensor`: + The shifted timesteps. + """ + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def _time_shift_linear(self, mu: float, sigma: float, t: Tensor) -> Tensor: + """Apply linear time shifting. + + Args: + mu (`float`): + The mu parameter. + sigma (`float`): + The sigma parameter. + t (`Tensor`): + The timesteps. + + Returns: + `Tensor`: + The shifted timesteps. + """ + return mu / (mu + (1 / t - 1) ** sigma) + + def __len__(self) -> int: + """Returns the number of train timesteps.""" + return self.num_train_timesteps + + def step_input_types(self) -> tuple[TensorType, ...]: + """Return the input types for the step function.""" + return ( + TensorType( + self.dtype, + shape=["batch_size", "image_seq_len", "channel"], + device=self.device, + ), + TensorType( + DType.float32, + shape=[], + device=self.device, + ), + TensorType( + self.dtype, + shape=["batch_size", "image_seq_len", "channel"], + device=self.device, + ), + TensorType( + DType.float32, + shape=["num_inference_steps"], + device=self.device, + ), + TensorType( + DType.int64, + shape=[], + device=DeviceRef.CPU(), + ), + ) + + def load_model(self) -> None: + """Load the model.""" + if self.device.is_cpu(): + session = InferenceSession([CPU()]) + else: + session = InferenceSession([Accelerator()]) + + self.set_begin_index(0) + with Graph( + "scheduler_step", input_types=self.step_input_types() + ) as graph: + outputs = self._step( + *graph.inputs, + return_dict=False, + ) + graph.output(outputs) + compiled_graph = graph + self.session = session.load(compiled_graph) + + def step( + self, + model_output: Tensor, + timestep: float | Tensor, + sample: Tensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + per_token_timesteps: Tensor | None = None, + return_dict: bool = True, + ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple: + """Predict the sample from the previous timestep by reversing the SDE. + + Args: + model_output (`Tensor`): + The direct output from learned diffusion model. + timestep (`float` or `Tensor`): + The current discrete timestep in the diffusion chain. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + Churn parameter. + s_tmin (`float`): + Min churn timestep. + s_tmax (`float`): + Max churn timestep. + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + per_token_timesteps (`Tensor`, *optional*): + The timesteps for each token in the sample. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + if self.step_index is None: + self._init_step_index(timestep) + schedule_output = self.session.execute( + model_output, + timestep, + sample, + self.sigmas, + self.step_index, + )[0] + self._step_index += 1 + + if not return_dict: + return (schedule_output,) + return FlowMatchEulerDiscreteSchedulerOutput( + prev_sample=schedule_output + ) diff --git a/max/python/max/pipelines/lib/hf_utils.py b/max/python/max/pipelines/lib/hf_utils.py index 8d8c6a7f80f..0a80d2ca461 100644 --- a/max/python/max/pipelines/lib/hf_utils.py +++ b/max/python/max/pipelines/lib/hf_utils.py @@ -304,7 +304,7 @@ def _repo_exists_with_retry(repo_id: str, revision: str) -> bool: ) time.sleep(delay_in_seconds) - assert False, ( # noqa: B011 + raise AssertionError( "This should never be reached due to the raise in the last attempt" ) @@ -372,20 +372,18 @@ def info(self) -> huggingface_hub.ModelInfo: @cached_property def weight_files(self) -> dict[WeightsFormat, list[str]]: - safetensor_search_pattern = "*.safetensors" - gguf_search_pattern = "*.gguf" - pytorch_search_pattern = "*.bin" + safetensor_search_pattern = "**/*.safetensors" + gguf_search_pattern = "**/*.gguf" weight_files = {} if self.repo_type == RepoType.local: safetensor_paths = glob.glob( - os.path.join(self.repo_id, safetensor_search_pattern) + os.path.join(self.repo_id, safetensor_search_pattern), + recursive=True, ) gguf_paths = glob.glob( - os.path.join(self.repo_id, gguf_search_pattern) - ) - pytorch_paths = glob.glob( - os.path.join(self.repo_id, pytorch_search_pattern) + os.path.join(self.repo_id, gguf_search_pattern), + recursive=True, ) elif self.repo_type == RepoType.online: fs = huggingface_hub.HfFileSystem() @@ -396,9 +394,6 @@ def weight_files(self) -> dict[WeightsFormat, list[str]]: gguf_paths = cast( list[str], fs.glob(f"{self.repo_id}/{gguf_search_pattern}") ) - pytorch_paths = cast( - list[str], fs.glob(f"{self.repo_id}/{pytorch_search_pattern}") - ) else: raise ValueError(f"Unsupported repo type: {self.repo_type}") @@ -626,3 +621,33 @@ def generate_local_model_path(repo_id: str, revision: str) -> str: if not path.is_dir(): raise FileNotFoundError(f"Model path does not exist: {path}") return str(path) + + +def get_model_index_path_for_diffusers( + huggingface_repo: HuggingFaceRepo, +) -> str | None: + model_index_path: str | None = None + + if huggingface_repo.repo_type == RepoType.local: + local_index = Path(huggingface_repo.repo_id) / "model_index.json" + if local_index.exists(): + model_index_path = str(local_index) + else: + raise ValueError( + f"Failed to find model_index.json in {huggingface_repo.repo_id}." + ) + else: + try: + if huggingface_hub.file_exists( + huggingface_repo.repo_id, + "model_index.json", + revision=huggingface_repo.revision, + ): + model_index_path = huggingface_hub.hf_hub_download( + huggingface_repo.repo_id, + "model_index.json", + revision=huggingface_repo.revision, + ) + except Exception: + model_index_path = None + return model_index_path diff --git a/max/python/max/pipelines/lib/image_processor.py b/max/python/max/pipelines/lib/image_processor.py new file mode 100644 index 00000000000..c6897c095bd --- /dev/null +++ b/max/python/max/pipelines/lib/image_processor.py @@ -0,0 +1,226 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import logging + +import numpy as np +import PIL.Image +from max.driver import Tensor as DTensor +from max.dtype import DType +from max.experimental import Tensor +from max.experimental import functional as F +from max.experimental.compile_utils import max_compile +from max.graph import DeviceRef, TensorType, TensorValue, ops +from PIL import Image + +logger = logging.getLogger(__name__) + + +PipelineImageInput = ( + PIL.Image.Image + | np.ndarray + | Tensor + | list[PIL.Image.Image] + | list[np.ndarray] + | list[Tensor] +) + + +class VaeImageProcessor: + config_name = "config.json" + + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + vae_latent_channels: int = 4, + resample: str = "lanczos", + reducing_gap: int | None = None, + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_rgb: bool = False, + do_convert_grayscale: bool = False, + device: DeviceRef = DeviceRef.GPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize the VaeImageProcessor. + + Args: + do_resize (bool, optional): Whether to resize images. Defaults to True. + vae_scale_factor (int, optional): The VAE scale factor. Defaults to 8. + vae_latent_channels (int, optional): The number of latent channels for the VAE. Defaults to 4. + resample (str, optional): The resampling mode for resizing. Defaults to "lanczos". + reducing_gap (int, optional): A reduction gap parameter for resampling. Defaults to None. + do_normalize (bool, optional): Whether to normalize images to [-1, 1]. Defaults to True. + do_binarize (bool, optional): Whether to binarize images. Defaults to False. + do_convert_rgb (bool, optional): Whether to convert images to RGB. Defaults to False. + do_convert_grayscale (bool, optional): Whether to convert images to grayscale. Defaults to False. + device (DeviceRef, optional): The device to use for the image processor. Defaults to DeviceRef.GPU(). + dtype (DType, optional): The data type to use for the image processor. Defaults to DType.bfloat16. + + Raises: + ValueError: If both do_convert_rgb and do_convert_grayscale are set to True. + """ + super().__init__() + if do_convert_rgb and do_convert_grayscale: + raise ValueError( + "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`," + " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.", + " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", + ) + + self.do_normalize = do_normalize + self.device = device + self.dtype = dtype + self._denormalize_conditionally = max_compile( + self._denormalize_conditionally, + input_types=self._denormalize_conditionally_input_types(), + ) + + @staticmethod + def denormalize(images: np.ndarray | Tensor) -> np.ndarray | Tensor: + r"""Denormalize an image array to [0,1]. + + Args: + images (`np.ndarray` or `Tensor`): + The image array to denormalize. + + Returns: + `np.ndarray` or `Tensor`: + The denormalized image array. + """ + if isinstance(images, (Tensor, TensorValue)): + images = images * 0.5 + 0.5 + images = F.min( + images, + Tensor.constant(1.0, dtype=images.dtype, device=images.device), + ) + images = F.max( + images, + Tensor.constant(0.0, dtype=images.dtype, device=images.device), + ) + return images + return np.clip(images * 0.5 + 0.5, 0, 1) + + def _denormalize_conditionally( + self, + images: np.ndarray | Tensor, + ) -> np.ndarray: + r"""Denormalize a batch of images based on a condition list. + + Args: + images (`np.ndarray` or `Tensor`): + The input image tensor. + """ + images = self.denormalize(images) if self.do_normalize else images + images = ops.cast(images, DType.float32) + return images + + @staticmethod + def max_to_numpy(images: Tensor) -> np.ndarray: + r"""Convert a Max tensor to a NumPy image. + + Args: + images (`Tensor`): + The Max tensor to convert to NumPy format. + + Returns: + `np.ndarray`: + A NumPy array representation of the images. + """ + images = DTensor.to_numpy(images) + images = np.transpose(images, (0, 2, 3, 1)) + return images + + @staticmethod + def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]: + r"""Convert a numpy image or a batch of images to a PIL image. + + Args: + images (`np.ndarray`): + The image array to convert to PIL format. + + Returns: + `list[PIL.Image.Image]`: + A list of PIL images. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [ + Image.fromarray(image.squeeze(), mode="L") for image in images + ] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + def _denormalize_conditionally_input_types(self) -> list[TensorType]: + return [ + TensorType( + shape=("batch_size", "num_channels", "height", "width"), + device=self.device, + dtype=self.dtype, + ), + ] + + def postprocess( + self, + image: Tensor, + output_type: str = "pil", + do_denormalize: list[bool] | None = None, + ) -> PIL.Image.Image | np.ndarray | Tensor: + """Postprocess the image output from tensor to `output_type`. + + Args: + image (`Tensor`): + The image input, should be a Max tensor with shape `B x C x H x W`. + output_type (`str`, *optional*, defaults to `pil`): + The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. + do_denormalize (`list[bool]`, *optional*, defaults to `None`): + Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the + `VaeImageProcessor` config. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `Tensor`: + The postprocessed image. + """ + if not isinstance(image, Tensor) and not isinstance(image, TensorValue): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. We only support Max tensor" + ) + if output_type not in ["latent", "max", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `max`, `latent`" + ) + logger.warning(deprecation_message) + output_type = "np" + + if output_type == "latent": + return image + + image = self._denormalize_conditionally(image) + + if output_type == "max": + return image[0] + + image = self.max_to_numpy(image[0]) + + if output_type == "np": + return image + + if output_type == "pil": + return self.numpy_to_pil(image) diff --git a/max/python/max/pipelines/lib/interfaces/__init__.py b/max/python/max/pipelines/lib/interfaces/__init__.py index db7ab9885c5..a6592f356d3 100644 --- a/max/python/max/pipelines/lib/interfaces/__init__.py +++ b/max/python/max/pipelines/lib/interfaces/__init__.py @@ -12,6 +12,7 @@ # ===----------------------------------------------------------------------=== # """Interfaces for MAX pipelines.""" +from .diffusion_pipeline import DiffusionPipeline from .generate import GenerateMixin from .kv_cache import KVCacheMixin, get_paged_manager from .pipeline_model import ( @@ -23,6 +24,7 @@ __all__ = [ "AlwaysSignalBuffersMixin", + "DiffusionPipeline", "GenerateMixin", "KVCacheMixin", "ModelInputs", diff --git a/max/python/max/pipelines/lib/interfaces/diffusion_pipeline.py b/max/python/max/pipelines/lib/interfaces/diffusion_pipeline.py new file mode 100644 index 00000000000..64835ee3c71 --- /dev/null +++ b/max/python/max/pipelines/lib/interfaces/diffusion_pipeline.py @@ -0,0 +1,177 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Pipeline utilities for MAX-optimized pipelines.""" + +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from max.config import load_config +from max.driver import load_devices +from max.graph import DeviceRef +from max.graph.weights import load_weights +from max.pipelines.lib.interfaces.max_model import MaxModel +from tqdm import tqdm + +if TYPE_CHECKING: + from ..config import PipelineConfig + from ..diffusion_schedulers import FlowMatchEulerDiscreteScheduler + + +class DiffusionPipeline(ABC): + config_name: str | None = None + """ + The name of the config file of the pipeline. + + It can be found in the downloaded path or HuggingFace hub. + It's usually "model_index.json" or "config.json" for Diffusion models. + """ + + components: ( + dict[str, type[MaxModel] | type[FlowMatchEulerDiscreteScheduler]] | None + ) = None + """The components of the pipeline. + + It can be found in the downloaded path or HuggingFace hub. + It's usually contains text_encoder, tokenizer, transformer, vae, etc. + """ + + def __init__( + self, + pipeline_config: PipelineConfig, + cached_folder: str, + **kwargs: Any, + ) -> DiffusionPipeline: + """Load a pipeline from a pretrained model. + + Args: + pipeline_config: Pipeline configuration for model and runtime setup. + cached_folder: Local path to the downloaded model snapshot. + **kwargs: Additional pipeline-specific arguments. + """ + self.pipeline_config = pipeline_config + self.devices = load_devices(pipeline_config.model.device_specs) + + # Load sub models + loaded_sub_models = self.load_sub_models(cached_folder) + for name, model in loaded_sub_models.items(): + setattr(self, name, model) + + self.init_remaining_components() + + @abstractmethod + def init_remaining_components(self) -> None: + pass + + def load_sub_models( + self, + pretrained_model_name_or_path: str | os.PathLike, + ) -> dict: + """Load sub-models for the pipeline. + + Args: + pretrained_model_name_or_path: Path to pretrained model. + + Returns: + Dictionary containing the loaded sub-models. + """ + loaded_sub_models = {} + if self.components is None: + raise ValueError( + f"`components` for {self.__class__.__name__} pipeline is not set. " + "Please set proper components based on its sub-directories in the downloaded path." + ) + for name, component_class in tqdm( + self.components.items(), desc="Loading sub models" + ): + component_path = os.path.join(pretrained_model_name_or_path, name) + if "tokenizer" in name: + # NOTE: Currently, we are using tokenizers from transformers. + # TODO(minkyu): Check if we can use Tokenizer in Max, + # and remove this conditional path. + loaded_sub_models[name] = component_class.from_pretrained( + component_path + ) + continue + + if ( + not hasattr(component_class, "config_name") + or component_class.config_name is None + ): + raise ValueError( + f"`config_name` for {component_class.__name__} is not set. " + "Please set proper config file name in the downloaded path." + ) + config = load_config( + f"{component_path}/{component_class.config_name}" + ) + if issubclass(component_class, MaxModel): + weight_paths = [ + Path(pretrained_model_name_or_path) / weight_path + for weight_path in self.pipeline_config.model.weight_path + if weight_path.split("/")[0] == name + ] + loaded_sub_models[name] = component_class( + config=config, + encoding=self.pipeline_config.model.quantization_encoding, + devices=self.devices, + weights=load_weights(weight_paths), + ) + else: + loaded_sub_models[name] = component_class( + **config, + device=DeviceRef.from_device(self.devices[0]), + dtype=self.pipeline_config.model.quantization_encoding.dtype, + ) + + return loaded_sub_models + + def finalize_pipeline_config(self) -> None: + return + + def _execution_device(self) -> DeviceRef: + r"""Returns the device on which the pipeline's models will be executed. + + This property checks pipeline components to determine the execution device. + It supports MAX models (with DeviceRef device attribute). + Similar structure to diffusers' _execution_device but returns DeviceRef instead of DeviceRef. + + Returns: + DeviceRef: The execution device (GPU if available, otherwise CPU). + """ + # Check MAX models - prioritize GPU + # Similar to diffusers' _execution_device but for MAX models (not torch.nn.Module) + sub_models = {k: getattr(self, k) for k in self.components} + for name, model in sub_models.items(): + exclude_from_cpu_offload = getattr( + self, "_exclude_from_cpu_offload", set() + ) + if name in exclude_from_cpu_offload: + continue + + if hasattr(model, "device") and isinstance(model.device, DeviceRef): + return model.device + + if hasattr(self, "device"): + try: + device = self.device + if isinstance(device, DeviceRef): + return device + except Exception: + pass + + return DeviceRef.CPU() diff --git a/max/python/max/pipelines/lib/model_config.py b/max/python/max/pipelines/lib/model_config.py index 0b37899b012..71e2691f33b 100644 --- a/max/python/max/pipelines/lib/model_config.py +++ b/max/python/max/pipelines/lib/model_config.py @@ -509,7 +509,6 @@ def validate_and_resolve_quantization_encoding_weight_path( are consistent. Args: - weight_path: The path to the weight file. default_encoding: The default encoding to use if no encoding is provided. """ @@ -582,6 +581,7 @@ def _validate_and_resolve_dtype_casting( Note: We currently only support float32 to bfloat16 weight type casting. Args: + from_encoding: The source encoding to cast from. to_encoding: The desired encoding to cast to. Raises: @@ -806,6 +806,9 @@ def _resolve_weight_path( encoding=self._applied_dtype_cast_from ) + if not weight_files: + weight_files = self.huggingface_weight_repo.weight_files + if default_weight_files := weight_files.get( default_weights_format, [] ): @@ -945,6 +948,7 @@ def _local_weight_path(self, relative_path: Path) -> str | None: # NOTE(bduke): do this even for online repositories, because upstream # code originating from `huggingface_hub.hf_hub_download` returns # absolute paths for cached files. + relative_path = Path(relative_path) if relative_path.exists() and relative_path.is_file(): return str(relative_path.resolve()) From 02e9070006bb66d6113485fec49f8aefa60ea458 Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Sat, 17 Jan 2026 01:22:43 +0000 Subject: [PATCH 06/18] remove use_torch_randn path --- .../architectures/flux1/pipeline_flux.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py index deacff48a87..ef493ab5769 100644 --- a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py +++ b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py @@ -12,7 +12,6 @@ # ===----------------------------------------------------------------------=== # import inspect -import os from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -410,23 +409,7 @@ def prepare_latents( ) return latents.to(device).cast(dtype), latent_image_ids - # NOTE: Max random generation uses different seed with torch.randn. - # So, we currently leave torch randn as optional for - # functionality comparison with the original diffusers pipeline. - if os.environ.get("USE_TORCH_RANDN", "0") == "1": - import torch - - seed = int(os.environ.get("SEED", 42)) - generator = torch.Generator(device="cuda").manual_seed(seed) - latents = torch.randn( - shape, - generator=generator, - device="cuda", - dtype=dtype.to_torch(), - ) - latents = Tensor_v3.from_dlpack(latents) - else: - latents = random.normal(shape, device=device, dtype=dtype) + latents = random.normal(shape, device=device, dtype=dtype) latents = self._pack_latents( latents, batch_size, num_channels_latents, height, width ) From fb749bffad7138c796cca9843e1605b059cf740d Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Sat, 17 Jan 2026 10:31:44 +0000 Subject: [PATCH 07/18] fix: convert noise prediction tensors to V3 tensor for eager operations of cfg --- max/python/max/pipelines/architectures/flux1/pipeline_flux.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py index ef493ab5769..85298caf20e 100644 --- a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py +++ b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py @@ -720,6 +720,9 @@ def __call__( negative_text_ids, guidance, )[0] + # TODO: negative prompt path is very slow, need to optimize. + noise_pred = Tensor_v3.from_dlpack(noise_pred) + neg_noise_pred = Tensor_v3.from_dlpack(neg_noise_pred) noise_pred = neg_noise_pred + true_cfg_scale * ( noise_pred - neg_noise_pred ) From 013d105dd201414c97fa8e898eb4bd4fc67cb7ff Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Sat, 17 Jan 2026 01:17:26 +0000 Subject: [PATCH 08/18] fix: use set_seed api and remove torch flag --- max/examples/diffusion/offline_generation.py | 10 ++-------- max/python/max/entrypoints/pipelines.py | 6 ++---- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/max/examples/diffusion/offline_generation.py b/max/examples/diffusion/offline_generation.py index 56bfce230c0..12b5cebd735 100644 --- a/max/examples/diffusion/offline_generation.py +++ b/max/examples/diffusion/offline_generation.py @@ -12,10 +12,10 @@ # ===----------------------------------------------------------------------=== # import argparse -import os from pathlib import Path from max.entrypoints.diffusion import DiffusionPipeline +from max.experimental.realization_context import set_seed from max.pipelines import PipelineConfig @@ -24,17 +24,11 @@ def main() -> None: parser.add_argument( "--model-path", type=str, default="black-forest-labs/FLUX.1-dev" ) - parser.add_argument("--use-torch-randn", action="store_true") parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() model_path = args.model_path - if args.use_torch_randn: - # NOTE: Use torch randn for latent initialization. - # Currently, It's not possible to set seed for Max random generation, - # so, use torch randn to test different seeds. - os.environ["USE_TORCH_RANDN"] = "1" - os.environ["SEED"] = str(args.seed) + set_seed(args.seed) pipeline_config = PipelineConfig(model_path=model_path) pipe = DiffusionPipeline(pipeline_config) diff --git a/max/python/max/entrypoints/pipelines.py b/max/python/max/entrypoints/pipelines.py index 7b315b44aea..a99c8c6fb18 100644 --- a/max/python/max/entrypoints/pipelines.py +++ b/max/python/max/entrypoints/pipelines.py @@ -468,12 +468,10 @@ def diffusion_generate( ) -> None: """Generate images using a diffusion pipeline.""" from max.entrypoints.cli.generate import generate_image + from max.experimental.realization_context import set_seed from max.pipelines import PipelineConfig - if use_torch_randn: - os.environ["USE_TORCH_RANDN"] = "1" - os.environ["SEED"] = str(seed) - + set_seed(seed) pipeline_config = PipelineConfig(**config_kwargs) pipeline_config.log_basic_config() From e4fda6aed6b72e4e88f202b79ef53e3d251e5a27 Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Sat, 17 Jan 2026 10:26:15 +0000 Subject: [PATCH 09/18] chore: allow negative prompt --- max/python/max/entrypoints/diffusion.py | 4 ++++ max/python/max/entrypoints/pipelines_diffusion.py | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/max/python/max/entrypoints/diffusion.py b/max/python/max/entrypoints/diffusion.py index 10e9c15719b..ec85979310c 100644 --- a/max/python/max/entrypoints/diffusion.py +++ b/max/python/max/entrypoints/diffusion.py @@ -36,6 +36,8 @@ def __init__(self, pipeline_config: PipelineConfig) -> None: def __call__( self, prompt: str, + negative_prompt: str | None = None, + true_cfg_scale: float = 1.0, height: int = 1024, width: int = 1024, num_inference_steps: int = 50, @@ -47,6 +49,8 @@ def __call__( # e.g. T2I, I2I, T2V, I2V, V2V. inputs = ImageGenerationInputs( prompt=prompt, + negative_prompt=negative_prompt, + true_cfg_scale=true_cfg_scale, height=height, width=width, num_inference_steps=num_inference_steps, diff --git a/max/python/max/entrypoints/pipelines_diffusion.py b/max/python/max/entrypoints/pipelines_diffusion.py index 2d94c5bb423..52863bf27f4 100644 --- a/max/python/max/entrypoints/pipelines_diffusion.py +++ b/max/python/max/entrypoints/pipelines_diffusion.py @@ -25,5 +25,3 @@ def main() -> None: if __name__ == "__main__": main() - - From dcec569cd964237e9e6a49a8a2dfa6879c18133b Mon Sep 17 00:00:00 2001 From: jingulee Date: Fri, 16 Jan 2026 08:17:46 +0000 Subject: [PATCH 10/18] add: images api --- max/examples/diffusion/offline_generation.py | 11 +- max/python/max/entrypoints/cli/__init__.py | 3 +- .../max/entrypoints/cli/serve/__init__.py | 3 +- .../cli/serve/serve_diffusion_api.py | 263 ++++++++++++ max/python/max/entrypoints/diffusion.py | 392 +++++++++++++++++- max/python/max/entrypoints/pipelines.py | 221 ++++++++-- max/python/max/interfaces/__init__.py | 10 + .../interfaces/pipeline_variants/__init__.py | 10 + .../pipeline_variants/image_generation.py | 276 +++++++++++- .../lib/interfaces/configuration_utils.py | 244 +++++++++++ 10 files changed, 1353 insertions(+), 80 deletions(-) create mode 100644 max/python/max/entrypoints/cli/serve/serve_diffusion_api.py create mode 100644 max/python/max/pipelines/lib/interfaces/configuration_utils.py diff --git a/max/examples/diffusion/offline_generation.py b/max/examples/diffusion/offline_generation.py index 12b5cebd735..cae6c25a368 100644 --- a/max/examples/diffusion/offline_generation.py +++ b/max/examples/diffusion/offline_generation.py @@ -14,7 +14,7 @@ import argparse from pathlib import Path -from max.entrypoints.diffusion import DiffusionPipeline +from max.entrypoints.diffusion import ImageGenerator from max.experimental.realization_context import set_seed from max.pipelines import PipelineConfig @@ -30,21 +30,20 @@ def main() -> None: model_path = args.model_path set_seed(args.seed) pipeline_config = PipelineConfig(model_path=model_path) - pipe = DiffusionPipeline(pipeline_config) + pipe = ImageGenerator(pipeline_config) prompt = "A cat holding a sign that says hello world" print(f"Prompt: {prompt}") - result = pipe( - prompt=prompt, + # Generate images using the new API + images = pipe.generate( + prompt, height=1024, width=1024, num_inference_steps=50, guidance_scale=3.5, ) - images = result.images - output_path = Path("output.png") output_path.parent.mkdir(parents=True, exist_ok=True) images[0].save(output_path) diff --git a/max/python/max/entrypoints/cli/__init__.py b/max/python/max/entrypoints/cli/__init__.py index 0aa58bb5569..33d030adf22 100644 --- a/max/python/max/entrypoints/cli/__init__.py +++ b/max/python/max/entrypoints/cli/__init__.py @@ -30,7 +30,7 @@ from .generate import generate_text_for_pipeline, stream_text_to_console from .list import list_pipelines_to_console, list_pipelines_to_json from .metrics import TextGenerationMetrics -from .serve import serve_api_server_and_model_worker +from .serve import serve_api_server_and_model_worker, serve_diffusion_api_server __all__ = [ "DevicesOptionType", @@ -48,6 +48,7 @@ "pipeline_encode", "sampling_params_options", "serve_api_server_and_model_worker", + "serve_diffusion_api_server", "stream_text_to_console", "validate_field_type", ] diff --git a/max/python/max/entrypoints/cli/serve/__init__.py b/max/python/max/entrypoints/cli/serve/__init__.py index 7b2da021d95..58ae44d4633 100644 --- a/max/python/max/entrypoints/cli/serve/__init__.py +++ b/max/python/max/entrypoints/cli/serve/__init__.py @@ -12,5 +12,6 @@ # ===----------------------------------------------------------------------=== # from .serve_api_and_model_worker import serve_api_server_and_model_worker +from .serve_diffusion_api import serve_diffusion_api_server -__all__ = ["serve_api_server_and_model_worker"] +__all__ = ["serve_api_server_and_model_worker", "serve_diffusion_api_server"] diff --git a/max/python/max/entrypoints/cli/serve/serve_diffusion_api.py b/max/python/max/entrypoints/cli/serve/serve_diffusion_api.py new file mode 100644 index 00000000000..34cfadaad54 --- /dev/null +++ b/max/python/max/entrypoints/cli/serve/serve_diffusion_api.py @@ -0,0 +1,263 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""API server for diffusion image generation. + +This module provides an OpenAI-compatible API server for image generation +using diffusion models (e.g., FLUX.1). + +The server implements the /v1/images/generations endpoint following the +OpenAI API specification. +""" + +from __future__ import annotations + +import logging +import os +import signal +import time +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +import uvloop +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse +from max.entrypoints.diffusion import ImageGenerator +from max.interfaces import ImageGenerationRequest +from max.pipelines import PipelineConfig +from max.profiler import Tracer +from max.serve.api_server import validate_port_is_free +from max.serve.config import Settings +from uvicorn import Config, Server + +logger = logging.getLogger("max.serve.diffusion") + + +@asynccontextmanager +async def lifespan( + app: FastAPI, + settings: Settings, + pipeline_config: PipelineConfig, +) -> AsyncGenerator[None]: + """Manage the lifecycle of the diffusion server.""" + logger.info("Starting diffusion image generation server...") + + # Initialize the diffusion pipeline + try: + pipeline = ImageGenerator(pipeline_config) + app.state.pipeline = pipeline + app.state.pipeline_config = pipeline_config + app.state.settings = settings + except Exception: + logger.exception("Failed to initialize diffusion pipeline") + raise + + logger.info( + f"\n\n{'*' * 80}\n\n" + f"{'Image generation server ready on http://' + settings.host + ':' + str(settings.port) + ' (Press CTRL+C to quit)'.center(80)}\n\n" + f"{'*' * 80}\n" + ) + + yield + + logger.info("Shutting down diffusion server...") + + +def create_diffusion_app( + settings: Settings, + pipeline_config: PipelineConfig, +) -> FastAPI: + """Create the FastAPI application for diffusion serving.""" + + @asynccontextmanager + async def lifespan_wrap(app: FastAPI) -> AsyncGenerator[None, None]: + try: + async with lifespan(app, settings, pipeline_config): + yield + except Exception: + logger.exception("Server exception, shutting down...") + os.kill(os.getpid(), signal.SIGINT) + os.kill(os.getpid(), signal.SIGINT) + + app = FastAPI(title="MAX Serve - Image Generation", lifespan=lifespan_wrap) + + # Health check + @app.get("/health") + async def health() -> JSONResponse: + return JSONResponse({"status": "ok"}) + + @app.get("/v1/health") + async def v1_health() -> JSONResponse: + return JSONResponse({"status": "ok"}) + + # Version endpoint + @app.get("/version") + async def version() -> JSONResponse: + from importlib.metadata import PackageNotFoundError, version + + try: + package_version = version("max") + return JSONResponse({"version": package_version}) + except PackageNotFoundError: + return JSONResponse({"version": "unknown"}) + + # OpenAI-compatible image generation endpoint + @app.post("/v1/images/generations") + async def create_image(request: Request) -> JSONResponse: + """Generate images from a text prompt (OpenAI-compatible). + + Request body follows the OpenAI /v1/images/generations schema: + - prompt (required): A text description of the desired image(s) + - model: The model to use for image generation + - n: Number of images to generate (1-10) + - quality: Image quality ('standard', 'hd', 'high', 'medium', 'low') + - response_format: 'url' or 'b64_json' + - size: Image size (e.g., '1024x1024') + - style: Image style ('vivid', 'natural') + - user: End-user identifier + - background: Background transparency ('transparent', 'opaque', 'auto') + - output_format: Output format ('png', 'jpeg', 'webp') + - num_inference_steps: Number of denoising steps (extension) + - guidance_scale: Classifier-free guidance scale (extension) + - seed: Random seed for reproducibility (extension) + + Returns: + OpenAI-compatible ImagesResponse with created timestamp and + image data (b64_json or url). + """ + try: + # Parse request body + body = await request.json() + + # Validate required field + if "prompt" not in body: + raise ValueError("'prompt' is a required field") + + # Get pipeline from app state + pipeline: ImageGenerator = request.app.state.pipeline + + # Build internal request from OpenAI-compatible fields + internal_request = ImageGenerationRequest( + prompt=body["prompt"], + model=body.get("model"), + n=body.get("n", 1), + quality=body.get("quality", "standard"), + response_format=body.get("response_format", "b64_json"), + size=body.get("size", "1024x1024"), + style=body.get("style"), + user=body.get("user"), + background=body.get("background"), + moderation=body.get("moderation"), + output_compression=body.get("output_compression"), + output_format=body.get("output_format", "png"), + partial_images=body.get("partial_images"), + stream=body.get("stream"), + # Extension parameters for diffusion models + num_inference_steps=body.get("num_inference_steps", 50), + guidance_scale=body.get("guidance_scale", 3.5), + seed=body.get("seed"), + ) + + logger.debug( + "Processing image generation request: prompt=%r, size=%s, n=%d", + internal_request.prompt[:50] if len(internal_request.prompt) > 50 else internal_request.prompt, + internal_request.size, + internal_request.n or 1, + ) + + # Generate images + response = pipeline.generate(internal_request) + + # Return OpenAI-compatible response + return JSONResponse(response.to_dict()) + + except ValueError as e: + logger.warning("Invalid request: %s", str(e)) + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + logger.exception("Image generation failed") + raise HTTPException( + status_code=500, detail="Image generation failed" + ) from e + + # Model info endpoint + @app.get("/v1/models") + async def list_models(request: Request) -> JSONResponse: + """List available models.""" + pipeline: ImageGenerator = request.app.state.pipeline + return JSONResponse( + { + "object": "list", + "data": [ + { + "id": pipeline.model_name, + "object": "model", + "created": int(time.time()), + "owned_by": "modular", + } + ], + } + ) + + @app.get("/v1/models/{model_id}") + async def get_model(model_id: str, request: Request) -> JSONResponse: + """Get model information.""" + pipeline: ImageGenerator = request.app.state.pipeline + + # Check if the model_id matches + if model_id == pipeline.model_name or model_id in pipeline.model_name: + return JSONResponse( + { + "id": pipeline.model_name, + "object": "model", + "created": int(time.time()), + "owned_by": "modular", + } + ) + + raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found") + + return app + + +def serve_diffusion_api_server( + settings: Settings, + pipeline_config: PipelineConfig, +) -> None: + """Start the diffusion API server. + + Args: + settings: Server settings (port, host, etc.). + pipeline_config: Configuration for the diffusion pipeline. + """ + # Create the FastAPI app + app = create_diffusion_app(settings, pipeline_config) + + # Configure uvicorn + config = Config( + app=app, + log_config=None, + loop="uvloop", + host=settings.host, + port=settings.port, + timeout_graceful_shutdown=5, + ) + + # Validate port before loading models + validate_port_is_free(settings.port) + + server = Server(config) + + with Tracer("diffusion_api_server"): + uvloop.run(server.serve()) diff --git a/max/python/max/entrypoints/diffusion.py b/max/python/max/entrypoints/diffusion.py index ec85979310c..d166d2de96a 100644 --- a/max/python/max/entrypoints/diffusion.py +++ b/max/python/max/entrypoints/diffusion.py @@ -11,51 +11,397 @@ # limitations under the License. # ===----------------------------------------------------------------------=== # +"""High-level interface for image generation using diffusion models. + +This module provides both programmatic and OpenAI-compatible API access +to diffusion-based image generation pipelines. + +Example (Direct API): + ```python + from max.entrypoints.diffusion import ImageGenerator + from max.pipelines import PipelineConfig + + config = PipelineConfig(model="black-forest-labs/FLUX.1-schnell") + generator = ImageGenerator(config) + + images = generator.generate("A beautiful sunset over mountains") + images[0].save("output.png") + ``` + +Example (OpenAI-compatible API): + ```python + from max.entrypoints.diffusion import ImageGenerator + from max.interfaces import ImageGenerationRequest + + generator = ImageGenerator(config) + request = ImageGenerationRequest( + prompt="A beautiful sunset", + size="1024x1024", + n=1, + ) + response = generator.create(request) + # response.data[0].b64_json contains the base64-encoded image + ``` +""" + +from __future__ import annotations + +import queue +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from threading import Event, Thread +from typing import TYPE_CHECKING + +import tqdm +from PIL.Image import Image + from max.interfaces import ( ImageGenerationInputs, ImageGenerationOutput, + ImageGenerationRequest, + ImageGenerationResponse, PipelineTask, + RequestID, ) from max.pipelines.lib import PIPELINE_REGISTRY, PipelineConfig +if TYPE_CHECKING: + from max.pipelines.core import ImageGenerationPipeline + + +# ============================================================================ +# Internal Request/Response Types +# ============================================================================ + + +@dataclass +class _ImageRequest: + """Internal request object for the image generation queue.""" + + id: RequestID + prompts: Sequence[str] + height: int + width: int + num_inference_steps: int + guidance_scale: float + num_images_per_prompt: int + use_tqdm: bool + + +@dataclass +class _ImageResponse: + """Internal response object from the image generation queue.""" + + images: list[Image] + + +@dataclass +class _ThreadControl: + """Thread synchronization primitives.""" + + ready: Event = field(default_factory=Event) + cancel: Event = field(default_factory=Event) + -class DiffusionPipeline: - """Entrypoint for image-generation diffusion pipelines.""" +# ============================================================================ +# Main ImageGenerator Class +# ============================================================================ + + +class ImageGenerator: + """High-level interface for generating images using diffusion models. + + This class provides a thread-safe interface for image generation with + support for both direct API calls and OpenAI-compatible request/response. + + The generator runs a background worker thread that processes requests + from a queue, allowing for concurrent request handling. + + Attributes: + model_name: The name/path of the loaded model. + pipeline_config: The configuration for the pipeline. + """ + + # Thread control and communication + _thread_control: _ThreadControl + _worker_thread: Thread + _request_queue: queue.Queue[_ImageRequest] + _pending_requests: dict[RequestID, queue.Queue[_ImageResponse]] + + # Configuration + pipeline_config: PipelineConfig + model_name: str def __init__(self, pipeline_config: PipelineConfig) -> None: - # NOTE: Currently, this entrypoint is implemented minimally - # for offline image generation. - # It will be developed further to support serving as well. + """Initialize the image generator. + + Args: + pipeline_config: Configuration specifying the model and parameters. + """ self.pipeline_config = pipeline_config - _, model_factory = PIPELINE_REGISTRY.retrieve_factory( - pipeline_config, - task=PipelineTask.IMAGE_GENERATION, + self.model_name = pipeline_config.model_config.model_path + + # Initialize thread control and queues + self._thread_control = _ThreadControl() + self._request_queue = queue.Queue() + self._pending_requests = {} + + # Start background worker + self._worker_thread = Thread( + target=_run_worker, + args=( + self._thread_control, + self.pipeline_config, + self._request_queue, + self._pending_requests, + ), + daemon=True, ) - self.pipeline = model_factory() + self._worker_thread.start() + + # Wait for worker to be ready + self._thread_control.ready.wait() - def __call__( + def __del__(self) -> None: + """Clean up resources.""" + self._thread_control.cancel.set() + if self._worker_thread.is_alive(): + self._worker_thread.join(timeout=5.0) + + # ======================================================================== + # Public API: Direct Generation + # ======================================================================== + + def generate( self, - prompt: str, - negative_prompt: str | None = None, - true_cfg_scale: float = 1.0, + prompts: str | Sequence[str], + *, height: int = 1024, width: int = 1024, num_inference_steps: int = 50, guidance_scale: float = 3.5, num_images_per_prompt: int = 1, - ) -> ImageGenerationOutput: - """Generate images from a prompt with the configured pipeline.""" - # TODO: consider all possible diffusion tasks, - # e.g. T2I, I2I, T2V, I2V, V2V. - inputs = ImageGenerationInputs( - prompt=prompt, - negative_prompt=negative_prompt, - true_cfg_scale=true_cfg_scale, + use_tqdm: bool = True, + ) -> list[Image]: + """Generate images from text prompts. + + This method is thread-safe and can be called from multiple threads. + + Args: + prompts: Single prompt string or sequence of prompts. + height: Image height in pixels. + width: Image width in pixels. + num_inference_steps: Number of denoising steps. + guidance_scale: Classifier-free guidance scale. + num_images_per_prompt: Number of images per prompt. + use_tqdm: Show progress bar. + + Returns: + List of generated PIL Images. + + Example: + ```python + images = generator.generate( + "A cat sitting on a couch", + height=1024, + width=1024, + num_inference_steps=30, + ) + images[0].save("cat.png") + ``` + """ + # Normalize prompts to sequence + if isinstance(prompts, str): + prompts = [prompts] + + # Create internal request + request = _ImageRequest( + id=RequestID(), + prompts=prompts, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_images_per_prompt, + use_tqdm=use_tqdm, + ) + + # Submit request and wait for response + return self._submit_and_wait(request) + + # ======================================================================== + # Public API: OpenAI-Compatible + # ======================================================================== + + def create( + self, + request: ImageGenerationRequest, + ) -> ImageGenerationResponse: + """Generate images using OpenAI-compatible request format. + + Args: + request: OpenAI-compatible image generation request. + + Returns: + OpenAI-compatible response with base64-encoded images. + + Example: + ```python + request = ImageGenerationRequest( + prompt="A beautiful landscape", + size="1024x1024", + n=2, + response_format="b64_json", + ) + response = generator.create(request) + print(f"Generated {len(response.data)} images") + ``` + """ + # Parse dimensions from size string + width, height = request.get_dimensions() + + # Generate images + images = self.generate( + prompts=request.prompt, + height=height, + width=width, + num_inference_steps=request.num_inference_steps, + guidance_scale=request.guidance_scale, + num_images_per_prompt=request.n or 1, + use_tqdm=False, + ) + + # Convert to OpenAI response format + output = ImageGenerationOutput(images=images) + return ImageGenerationResponse.from_pipeline_output( + output=output, + response_format=request.response_format, + output_format=request.get_output_format(), + prompt=request.prompt, ) - pipeline_output: ImageGenerationOutput = self.pipeline.execute(inputs) - return pipeline_output + + # ======================================================================== + # Internal Methods + # ======================================================================== + + def _submit_and_wait(self, request: _ImageRequest) -> list[Image]: + """Submit a request to the queue and wait for response.""" + response_queue: queue.Queue[_ImageResponse] = queue.Queue() + self._pending_requests[request.id] = response_queue + + try: + self._request_queue.put_nowait(request) + response = response_queue.get() + return response.images + finally: + self._pending_requests.pop(request.id, None) + + # ======================================================================== + # Class Methods + # ======================================================================== + + @classmethod + def from_model(cls, model: str, **kwargs) -> ImageGenerator: + """Create an ImageGenerator from a model identifier. + + Args: + model: Model identifier (e.g., "black-forest-labs/FLUX.1-schnell"). + **kwargs: Additional PipelineConfig arguments. + + Returns: + Configured ImageGenerator instance. + + Example: + ```python + generator = ImageGenerator.from_model( + "black-forest-labs/FLUX.1-schnell" + ) + ``` + """ + config = PipelineConfig(model=model, **kwargs) + return cls(config) + + +# ============================================================================ +# Legacy Alias (for backward compatibility) +# ============================================================================ + +DiffusionPipeline = ImageGenerator + + +# ============================================================================ +# Background Worker +# ============================================================================ + + +def _run_worker( + thread_control: _ThreadControl, + pipeline_config: PipelineConfig, + request_queue: queue.Queue[_ImageRequest], + pending_requests: Mapping[RequestID, queue.Queue[_ImageResponse]], +) -> None: + """Background worker that processes image generation requests. + + This function runs in a separate thread and continuously processes + requests from the queue until cancellation is signaled. + """ + # Load the pipeline + _, model_factory = PIPELINE_REGISTRY.retrieve_factory( + pipeline_config, + task=PipelineTask.IMAGE_GENERATION, + ) + pipeline: ImageGenerationPipeline = model_factory() + + # Signal that we're ready + thread_control.ready.set() + + # Main processing loop + while not thread_control.cancel.is_set(): + try: + request = request_queue.get(timeout=0.3) + except queue.Empty: + continue + + # Process the request + images = _process_request(pipeline, request) + + # Send response + if response_queue := pending_requests.get(request.id): + response_queue.put(_ImageResponse(images=images)) + + +def _process_request( + pipeline: ImageGenerationPipeline, + request: _ImageRequest, +) -> list[Image]: + """Process a single image generation request. + + Args: + pipeline: The image generation pipeline. + request: The request to process. + + Returns: + List of generated images. + """ + all_images: list[Image] = [] + + # Create iterator with optional progress bar + prompt_iter = request.prompts + if request.use_tqdm: + prompt_iter = tqdm.tqdm(prompt_iter, desc="Generating images") + + # Generate images for each prompt + for prompt in prompt_iter: + inputs = ImageGenerationInputs( + prompt=prompt, + height=request.height, + width=request.width, + num_inference_steps=request.num_inference_steps, + guidance_scale=request.guidance_scale, + num_images_per_prompt=request.num_images_per_prompt, + ) + + output: ImageGenerationOutput = pipeline.execute(inputs) + all_images.extend(output.images) + + return all_images diff --git a/max/python/max/entrypoints/pipelines.py b/max/python/max/entrypoints/pipelines.py index a99c8c6fb18..62d66752816 100644 --- a/max/python/max/entrypoints/pipelines.py +++ b/max/python/max/entrypoints/pipelines.py @@ -385,52 +385,63 @@ def cli_pipeline( ) -@main.group(name="diffusion", cls=ModelGroup) -def diffusion_group() -> None: - """Commands for diffusion-based image/video generation pipelines.""" +# ============================================================================ +# Images Group (OpenAI-compatible /v1/images/* endpoints) +# ============================================================================ -@diffusion_group.command(name="generate", cls=WithLazyPipelineOptions) +@main.group(name="images", cls=ModelGroup) +def images_group() -> None: + """Commands for image generation (OpenAI-compatible /v1/images/* API).""" + + +@images_group.command(name="generate", cls=WithLazyPipelineOptions) @click.option( "--prompt", type=str, - default="A cat holding a sign that says hello world", - help="The text prompt to use for image generation.", + required=True, + help="A text description of the desired image(s).", ) @click.option( - "--height", - type=click.IntRange(min=64), - default=1024, + "--n", + type=click.IntRange(min=1, max=10), + default=1, show_default=True, - help="Generated image height in pixels.", + help="The number of images to generate (1-10).", ) @click.option( - "--width", - type=click.IntRange(min=64), - default=1024, + "--size", + type=str, + default="1024x1024", show_default=True, - help="Generated image width in pixels.", + help="The size of generated images (e.g., '1024x1024', '1792x1024', '1024x1792').", ) @click.option( - "--num-inference-steps", - type=click.IntRange(min=1), - default=50, + "--quality", + type=click.Choice(["auto", "standard", "hd", "high", "medium", "low"]), + default="auto", show_default=True, - help="Number of denoising steps to run.", + help="The quality of the image.", ) @click.option( - "--guidance-scale", - type=float, - default=3.5, + "--response-format", + type=click.Choice(["b64_json", "url"]), + default="b64_json", show_default=True, - help="Classifier-free guidance scale.", + help="The format in which generated images are returned.", ) @click.option( - "--num-images-per-prompt", - type=click.IntRange(min=1), - default=1, + "--output-format", + type=click.Choice(["png", "jpeg", "webp"]), + default="png", show_default=True, - help="Number of images to generate for a single prompt.", + help="The output image format.", +) +@click.option( + "--style", + type=click.Choice(["vivid", "natural"]), + default=None, + help="The style of generated images.", ) @click.option( "--output", @@ -450,49 +461,169 @@ def diffusion_group() -> None: @click.option( "--seed", type=int, - default=42, + default=None, + help="Random seed for reproducibility.", +) +@click.option( + "--num-inference-steps", + type=click.IntRange(min=1), + default=50, show_default=True, - help="Random seed for torch-based latent initialization.", + help="Number of denoising steps (diffusion model parameter).", ) -def diffusion_generate( +@click.option( + "--guidance-scale", + type=float, + default=3.5, + show_default=True, + help="Classifier-free guidance scale (diffusion model parameter).", +) +def images_generate( prompt: str, - height: int, - width: int, - num_inference_steps: int, - guidance_scale: float, - num_images_per_prompt: int, + n: int, + size: str, + quality: str, + response_format: str, + output_format: str, + style: str | None, output: Path, use_torch_randn: bool, - seed: int, + seed: int | None, + num_inference_steps: int, + guidance_scale: float, **config_kwargs: Any, ) -> None: - """Generate images using a diffusion pipeline.""" - from max.entrypoints.cli.generate import generate_image + """Generate images from a text prompt. + + This command follows the OpenAI /v1/images/generations API schema. + + Example: + max images generate --model black-forest-labs/FLUX.1-schnell \\ + --prompt "A beautiful sunset over mountains" \\ + --size 1024x1024 --n 1 --output sunset.png + """ + from max.entrypoints.diffusion import ImageGenerator from max.experimental.realization_context import set_seed + from max.interfaces import ImageGenerationRequest from max.pipelines import PipelineConfig + # Set random seed if provided set_seed(seed) + + # Create pipeline config and generator pipeline_config = PipelineConfig(**config_kwargs) pipeline_config.log_basic_config() try: - generate_image( - pipeline_config=pipeline_config, + generator = ImageGenerator(pipeline_config) + + # Create OpenAI-compatible request + request = ImageGenerationRequest( prompt=prompt, - height=height, - width=width, + n=n, + size=size, + quality=quality, + response_format=response_format, + output_format=output_format, + style=style, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + seed=seed, + ) + + logger.info(f"Generating {n} image(s) with prompt: {prompt[:50]}...") + + # Generate images directly (not using OpenAI response format for CLI) + images = generator.generate( + prompts=prompt, + height=request.get_dimensions()[1], + width=request.get_dimensions()[0], num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - output=output, + num_images_per_prompt=n, + use_tqdm=True, ) + + # Save images + output.parent.mkdir(parents=True, exist_ok=True) + if len(images) == 1: + images[0].save(output) + logger.info(f"Image saved to: {output}") + else: + for i, img in enumerate(images): + numbered_output = output.with_stem(f"{output.stem}_{i}") + img.save(numbered_output) + logger.info(f"Image saved to: {numbered_output}") + except Exception as exc: logger.exception( - "Diffusion generation failed for model %s with prompt %r", + "Image generation failed for model %s with prompt %r", pipeline_config.model.model_path, prompt, ) - raise click.ClickException("Diffusion generation failed.") from exc + raise click.ClickException("Image generation failed.") from exc + + +@images_group.command(name="serve", cls=WithLazyPipelineOptions) +@common_server_options +def images_serve( + profile_serve: bool, + sim_failure: int, + port: int, + headless: bool, + log_prefix: str | None, + **config_kwargs: Any, +) -> None: + """Start an OpenAI-compatible image generation server. + + This command launches a server with the following endpoints: + - POST /v1/images/generations (create image) + - GET /v1/models (list models) + - GET /health (health check) + + Example: + max images serve --model black-forest-labs/FLUX.1-schnell --port 8000 + + Then use curl to generate images: + curl http://localhost:8000/v1/images/generations \\ + -H "Content-Type: application/json" \\ + -d '{"prompt": "A cat", "size": "1024x1024"}' + """ + from max.entrypoints.cli import serve_diffusion_api_server + from max.pipelines import PipelineConfig + from max.serve.config import Settings + from max.serve.telemetry.common import configure_logging + + # Initialize Settings + setting_kwargs: dict[str, Any] = {} + if port is not None: + setting_kwargs["MAX_SERVE_PORT"] = port + if log_prefix is not None: + setting_kwargs["MAX_SERVE_LOG_PREFIX"] = log_prefix + if headless is not None: + setting_kwargs["MAX_SERVE_HEADLESS"] = headless + + settings = Settings(**setting_kwargs) + + # Initialize pipeline config + pipeline_config = PipelineConfig(**config_kwargs) + pipeline_config.log_basic_config() + + # Configure logging + configure_logging(settings) + + if headless: + logger.error("Headless mode is not supported for image serving yet") + raise click.ClickException("Headless mode not supported") + + serve_diffusion_api_server( + settings=settings, + pipeline_config=pipeline_config, + ) + + +# Legacy alias for backward compatibility +diffusion_group = images_group @main.command(name="encode", cls=WithLazyPipelineOptions) diff --git a/max/python/max/interfaces/__init__.py b/max/python/max/interfaces/__init__.py index 30910f28851..ad0e6844d78 100644 --- a/max/python/max/interfaces/__init__.py +++ b/max/python/max/interfaces/__init__.py @@ -49,9 +49,14 @@ EmbeddingsGenerationInputs, EmbeddingsGenerationOutput, ImageContentPart, + ImageData, ImageGenerationInputs, ImageGenerationOutput, + ImageGenerationRequest, + ImageGenerationResponse, + ImageGenerationUsage, ImageMetadata, + InputTokensDetails, TextContentPart, TextGenerationContext, TextGenerationContextType, @@ -111,9 +116,14 @@ def create_text_pipeline() -> Pipeline[TextGenerationInputs, TextGenerationOutpu "EmbeddingsGenerationOutput", "GenerationStatus", "ImageContentPart", + "ImageData", "ImageGenerationInputs", "ImageGenerationOutput", + "ImageGenerationRequest", + "ImageGenerationResponse", + "ImageGenerationUsage", "ImageMetadata", + "InputTokensDetails", "LoRAOperation", "LoRARequest", "LoRAResponse", diff --git a/max/python/max/interfaces/pipeline_variants/__init__.py b/max/python/max/interfaces/pipeline_variants/__init__.py index 073a7ad7a76..0a4aa1db496 100644 --- a/max/python/max/interfaces/pipeline_variants/__init__.py +++ b/max/python/max/interfaces/pipeline_variants/__init__.py @@ -25,8 +25,13 @@ EmbeddingsGenerationOutput, ) from .image_generation import ( + ImageData, ImageGenerationInputs, ImageGenerationOutput, + ImageGenerationRequest, + ImageGenerationResponse, + ImageGenerationUsage, + InputTokensDetails, ) from .text_generation import ( BatchType, @@ -58,9 +63,14 @@ "EmbeddingsGenerationInputs", "EmbeddingsGenerationOutput", "ImageContentPart", + "ImageData", "ImageGenerationInputs", "ImageGenerationOutput", + "ImageGenerationRequest", + "ImageGenerationResponse", + "ImageGenerationUsage", "ImageMetadata", + "InputTokensDetails", "TextContentPart", "TextGenerationContext", "TextGenerationContextType", diff --git a/max/python/max/interfaces/pipeline_variants/image_generation.py b/max/python/max/interfaces/pipeline_variants/image_generation.py index d761298a7d6..44ea84fe80d 100644 --- a/max/python/max/interfaces/pipeline_variants/image_generation.py +++ b/max/python/max/interfaces/pipeline_variants/image_generation.py @@ -11,19 +11,38 @@ # limitations under the License. # ===----------------------------------------------------------------------=== # -from dataclasses import dataclass +"""OpenAI-compatible image generation request/response models. + +This module provides dataclasses that map to the OpenAI /v1/images/generations +API schema for seamless integration with OpenAI-compatible clients. +""" + +from __future__ import annotations + +import base64 +import io +import time +from dataclasses import dataclass, field +from typing import Any, Literal from max.interfaces.pipeline import PipelineInputs from PIL.Image import Image +# Default image generation parameters +DEFAULT_SIZE = "1024x1024" +DEFAULT_NUM_IMAGES = 1 +DEFAULT_RESPONSE_FORMAT = "b64_json" +DEFAULT_QUALITY = "standard" +DEFAULT_OUTPUT_FORMAT = "png" +DEFAULT_NUM_INFERENCE_STEPS = 50 +DEFAULT_GUIDANCE_SCALE = 3.5 + + @dataclass(eq=True) class ImageGenerationInputs(PipelineInputs): """Inputs for image-generation pipelines.""" - # NOTE: Current implementation only considers offline generation without - # request scheduling. `ImageGenerationContext` should be used once - # request scheduling is implemented. prompt: str negative_prompt: str | None true_cfg_scale: float @@ -40,3 +59,252 @@ class ImageGenerationOutput: images: list[Image] """List of generated images.""" + + +@dataclass +class ImageGenerationRequest: + """OpenAI-compatible image generation request. + + This maps to the OpenAI /v1/images/generations API schema. + See: https://platform.openai.com/docs/api-reference/images/create + """ + + # Required field + prompt: str + """A text description of the desired image(s). Required.""" + + # OpenAI standard fields + model: str | None = None + """The model to use for image generation.""" + + n: int | None = DEFAULT_NUM_IMAGES + """The number of images to generate. Must be between 1 and 10.""" + + quality: str | None = DEFAULT_QUALITY + """The quality of the image (e.g., 'standard', 'hd', 'high', 'medium', 'low').""" + + response_format: Literal["url", "b64_json"] | None = DEFAULT_RESPONSE_FORMAT + """The format in which generated images are returned.""" + + size: str | None = DEFAULT_SIZE + """The size of the generated images (e.g., '1024x1024').""" + + style: str | None = None + """The style of the generated images (e.g., 'vivid', 'natural').""" + + user: str | None = None + """A unique identifier representing your end-user.""" + + # Extended fields for GPT image models + background: str | None = None + """Background transparency ('transparent', 'opaque', 'auto').""" + + moderation: str | None = None + """Content-moderation level ('low', 'auto').""" + + output_compression: int | None = None + """Compression level (0-100%).""" + + output_format: str | None = DEFAULT_OUTPUT_FORMAT + """Output format ('png', 'jpeg', 'webp').""" + + partial_images: int | None = None + """Number of partial images for streaming (0-3).""" + + stream: bool | None = None + """Generate in streaming mode.""" + + # Extended parameters for diffusion models (not in OpenAI spec) + num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS + """Number of denoising steps. Extension for diffusion models.""" + + guidance_scale: float = DEFAULT_GUIDANCE_SCALE + """Classifier-free guidance scale. Extension for diffusion models.""" + + seed: int | None = None + """Random seed for reproducibility.""" + + def to_pipeline_inputs(self) -> ImageGenerationInputs: + """Convert OpenAI request to pipeline-native inputs.""" + width, height = self.get_dimensions() + return ImageGenerationInputs( + prompt=self.prompt, + height=height, + width=width, + num_inference_steps=self.num_inference_steps, + guidance_scale=self.guidance_scale, + num_images_per_prompt=self.n or 1, + ) + + def get_dimensions(self) -> tuple[int, int]: + """Parse size string and return (width, height) tuple.""" + size = self.size or DEFAULT_SIZE + + # Handle 'auto' as default + if size == "auto": + size = DEFAULT_SIZE + + # Parse WIDTHxHEIGHT format + try: + parts = size.lower().split("x") + if len(parts) == 2: + width, height = int(parts[0]), int(parts[1]) + if width >= 64 and height >= 64: + return width, height + except (ValueError, IndexError): + pass + + raise ValueError( + f"Invalid size '{size}'. Use format 'WIDTHxHEIGHT' (e.g., '1024x1024')." + ) + + def get_output_format(self) -> str: + """Get the output image format.""" + return self.output_format or DEFAULT_OUTPUT_FORMAT + + +@dataclass +class ImageData: + """Individual image data in the response.""" + + b64_json: str | None = None + """The base64-encoded image data.""" + + url: str | None = None + """The URL of the generated image (valid for 60 minutes).""" + + revised_prompt: str | None = None + """The prompt that was used to generate the image, if revised.""" + + +@dataclass +class InputTokensDetails: + """Details about input token usage.""" + + text_tokens: int = 0 + """Number of text tokens in the input.""" + + image_tokens: int = 0 + """Number of image tokens in the input.""" + + +@dataclass +class ImageGenerationUsage: + """Token usage statistics for image generation.""" + + total_tokens: int = 0 + """Total number of tokens used.""" + + input_tokens: int = 0 + """Number of input tokens.""" + + output_tokens: int = 0 + """Number of output tokens.""" + + input_tokens_details: InputTokensDetails | None = None + """Detailed breakdown of input tokens.""" + + +@dataclass +class ImageGenerationResponse: + """OpenAI-compatible image generation response. + + This maps to the OpenAI ImagesResponse schema. + See: https://platform.openai.com/docs/api-reference/images/object + """ + + created: int + """Unix timestamp when the response was created.""" + + data: list[ImageData] = field(default_factory=list) + """List of generated image data.""" + + usage: ImageGenerationUsage | None = None + """Token usage statistics.""" + + @classmethod + def from_pipeline_output( + cls, + output: ImageGenerationOutput, + response_format: Literal["url", "b64_json"] | None = "b64_json", + output_format: str = "png", + prompt: str | None = None, + ) -> ImageGenerationResponse: + """Convert pipeline output to OpenAI-compatible response. + + Args: + output: The raw pipeline output containing PIL images. + response_format: The desired response format ('url' or 'b64_json'). + output_format: The image format ('png', 'jpeg', 'webp'). + prompt: The original prompt, included as revised_prompt if provided. + + Returns: + An OpenAI-compatible ImageGenerationResponse. + """ + data: list[ImageData] = [] + fmt = response_format or "b64_json" + + # Map output_format to PIL format + pil_format_map = { + "png": "PNG", + "jpeg": "JPEG", + "webp": "WEBP", + } + pil_format = pil_format_map.get(output_format.lower(), "PNG") + + for image in output.images: + image_data = ImageData(revised_prompt=prompt) + + if fmt == "b64_json": + buffer = io.BytesIO() + image.save(buffer, format=pil_format) + buffer.seek(0) + image_data.b64_json = base64.b64encode(buffer.read()).decode( + "utf-8" + ) + elif fmt == "url": + raise ValueError( + "response_format='url' requires external storage. " + "Use 'b64_json' for local image generation." + ) + + data.append(image_data) + + return cls( + created=int(time.time()), + data=data, + usage=None, + ) + + def to_dict(self) -> dict[str, Any]: + """Convert response to dictionary for JSON serialization.""" + result: dict[str, Any] = { + "created": self.created, + "data": [ + { + k: v + for k, v in { + "b64_json": img.b64_json, + "url": img.url, + "revised_prompt": img.revised_prompt, + }.items() + if v is not None + } + for img in self.data + ], + } + + if self.usage is not None: + usage_dict: dict[str, Any] = { + "total_tokens": self.usage.total_tokens, + "input_tokens": self.usage.input_tokens, + "output_tokens": self.usage.output_tokens, + } + if self.usage.input_tokens_details is not None: + usage_dict["input_tokens_details"] = { + "text_tokens": self.usage.input_tokens_details.text_tokens, + "image_tokens": self.usage.input_tokens_details.image_tokens, + } + result["usage"] = usage_dict + + return result diff --git a/max/python/max/pipelines/lib/interfaces/configuration_utils.py b/max/python/max/pipelines/lib/interfaces/configuration_utils.py new file mode 100644 index 00000000000..5c0c5d0e439 --- /dev/null +++ b/max/python/max/pipelines/lib/interfaces/configuration_utils.py @@ -0,0 +1,244 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import functools +import inspect +import json +import os +from collections import OrderedDict +from collections.abc import Callable +from typing import Any, NoReturn + + +class ConfigDict(OrderedDict): + def __init__(self, *args, **kwargs): + """Initialize ConfigDict.""" + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) + + def setdefault(self, *args, **kwargs) -> NoReturn: + """Set default value.""" + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) + + def pop(self, *args, **kwargs) -> NoReturn: + """Pop item.""" + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance." + ) + + def update(self, *args, **kwargs) -> NoReturn: + """Update dictionary.""" + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." + ) + + def __setattr__(self, name: str, value: Any): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception( + f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." + ) + super().__setattr__(name, value) + + def __setitem__(self, name: str, value: Any): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception( + f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." + ) + super().__setitem__(name, value) + + +class ConfigMixin: + config_name = None + + @classmethod + def load_config( + cls, + pretrained_model_name_or_path: str, + **kwargs, + ) -> dict: + """Load configuration from a pretrained model directory. + + Args: + pretrained_model_name_or_path: Path to pretrained model or model identifier. + **kwargs: Additional arguments. + + Returns: + Dictionary containing the configuration. + """ + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + subfolder = kwargs.pop("subfolder", None) + + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if subfolder is not None and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, cls.config_name + ) + ): + config_file = os.path.join( + pretrained_model_name_or_path, subfolder, cls.config_name + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, cls.config_name) + ): + # Load from a pretrained checkpoint + config_file = os.path.join( + pretrained_model_name_or_path, cls.config_name + ) + else: + raise OSError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + raise ValueError( + f"The provided pretrained_model_name_or_path '{pretrained_model_name_or_path}'" + " is neither a valid local path nor downloaded properly from Hugging Face Hub." + ) + + config_dict = cls._dict_from_json_file(config_file) + return config_dict + + @property + def config(self) -> ConfigDict: + """Returns the config of the class as a dictionary. + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + @classmethod + def _dict_from_json_file(cls, json_file: str | os.PathLike) -> dict: + with open(json_file, encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + @staticmethod + def _get_init_keys(input_class: Any) -> set: + if hasattr(input_class, "components"): + return set(input_class.components.keys()) + return set( + dict(inspect.signature(input_class.__init__).parameters).keys() + ) + + @classmethod + def extract_init_dict(cls, config_dict: dict) -> dict: + """Extract init dictionary from config dictionary. + + Args: + config_dict: Configuration dictionary. + + Returns: + Dictionary containing the init parameters. + """ + expected_keys = cls._get_init_keys(cls) + + init_dict = { + k: config_dict[k] for k in config_dict if k in expected_keys + } + return init_dict + + def register_to_config(self, **kwargs) -> None: + """Register arguments to the config. + + Args: + **kwargs: Arguments to register. + """ + if self.config_name is None: + raise NotImplementedError( + f"Make sure that {self.__class__} has defined a class name `config_name`" + ) + # Special case for `kwargs` used in deprecation warning added to schedulers + # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, + # or solve in a more general way. + kwargs.pop("kwargs", None) + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + internal_dict = {**self._internal_dict, **kwargs} + + self._internal_dict = ConfigDict(internal_dict) + + +def register_to_config(init: Callable) -> Callable: + """Register arguments to the config. + + Args: + init: Initialization function of a class. + + Returns: + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable. + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self: Any, *args, **kwargs) -> None: + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = { + k: v for k, v in kwargs.items() if k.startswith("_") + } + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default + for i, (name, p) in enumerate(signature.parameters.items()) + if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys(), strict=False): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = list( + set(new_kwargs.keys()) - set(init_kwargs) + ) + + new_kwargs = {**config_init_kwargs, **new_kwargs} + self.register_to_config(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init From 1e0a9af846e38e1e3fc026ce4436e3386e009d99 Mon Sep 17 00:00:00 2001 From: jingulee Date: Fri, 16 Jan 2026 08:18:44 +0000 Subject: [PATCH 11/18] update: flux.1 examples --- max/examples/diffusion/README.md | 219 +++++++++++++++++ max/examples/diffusion/client_example.py | 240 +++++++++++++++++++ max/examples/diffusion/openai_api_example.py | 127 ++++++++++ 3 files changed, 586 insertions(+) create mode 100644 max/examples/diffusion/README.md create mode 100644 max/examples/diffusion/client_example.py create mode 100644 max/examples/diffusion/openai_api_example.py diff --git a/max/examples/diffusion/README.md b/max/examples/diffusion/README.md new file mode 100644 index 00000000000..a3990f06223 --- /dev/null +++ b/max/examples/diffusion/README.md @@ -0,0 +1,219 @@ +# Diffusion Image Generation Examples + +This directory contains examples for using the MAX diffusion pipeline for image generation. + +## Overview + +The MAX diffusion pipeline supports: +- **Offline generation**: Direct Python API for generating images +- **OpenAI-compatible API**: Server with `/v1/images/generations` endpoint +- **Multiple models**: FLUX.1-dev, FLUX.1-schnell, and other diffusion models + +## Examples + +### 1. Offline Generation (`offline_generation.py`) + +Basic example using the `ImageGenerator` directly: + +```bash +python offline_generation.py +``` + +```python +from max.entrypoints.diffusion import ImageGenerator +from max.pipelines import PipelineConfig + +config = PipelineConfig(model_path="black-forest-labs/FLUX.1-schnell") +generator = ImageGenerator(config) + +# Generate returns a list of PIL Images +images = generator.generate( + "A cat holding a sign that says hello world", + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=3.5, +) +images[0].save("output.png") +``` + +### 2. OpenAI API Example (`openai_api_example.py`) + +Using the OpenAI-compatible `ImageGenerationRequest`: + +```bash +python openai_api_example.py +``` + +```python +from max.entrypoints.diffusion import ImageGenerator +from max.interfaces import ImageGenerationRequest +from max.pipelines import PipelineConfig + +config = PipelineConfig(model_path="black-forest-labs/FLUX.1-schnell") +generator = ImageGenerator(config) + +# Use OpenAI-compatible request format +request = ImageGenerationRequest( + prompt="A futuristic city skyline at sunset", + size="1024x1024", + n=1, + response_format="b64_json", + num_inference_steps=30, + guidance_scale=3.5, +) + +# create() returns an OpenAI-compatible response +response = generator.create(request) +# response.data[0].b64_json contains the base64-encoded image +``` + +### 3. Client Example (`client_example.py`) + +Connecting to the OpenAI-compatible server: + +```bash +# Start the server +max images serve --model black-forest-labs/FLUX.1-schnell --port 8000 + +# Run the client +python client_example.py +``` + +## CLI Commands + +### Generate Images + +```bash +# Basic generation +max images generate \ + --model black-forest-labs/FLUX.1-schnell \ + --prompt "A beautiful sunset over mountains" \ + --size 1024x1024 \ + --output output.png + +# With custom parameters +max images generate \ + --model black-forest-labs/FLUX.1-dev \ + --prompt "A cyberpunk city" \ + --size 1792x1024 \ + --num-inference-steps 50 \ + --guidance-scale 7.5 \ + --seed 42 \ + --output landscape.png +``` + +### Start Server + +```bash +# Start OpenAI-compatible API server +max images serve \ + --model black-forest-labs/FLUX.1-schnell \ + --port 8000 +``` + +## API Reference + +### Request Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `prompt` | string | Required | Text description of the desired image | +| `model` | string | null | Model to use for generation | +| `n` | integer | 1 | Number of images to generate (1-10) | +| `size` | string | "1024x1024" | Image size (e.g., "1024x1024", "1792x1024") | +| `quality` | string | "standard" | Image quality | +| `response_format` | string | "b64_json" | Response format ("url" or "b64_json") | +| `output_format` | string | "png" | Output format ("png", "jpeg", "webp") | +| `num_inference_steps` | integer | 50 | Number of denoising steps | +| `guidance_scale` | float | 3.5 | Classifier-free guidance scale | +| `seed` | integer | null | Random seed for reproducibility | + +### Response Format + +```json +{ + "created": 1713833628, + "data": [ + { + "b64_json": "iVBORw0KGgo...", + "revised_prompt": "A beautiful sunset over mountains" + } + ] +} +``` + +## curl Examples + +### Generate Image + +```bash +curl http://localhost:8000/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A beautiful sunset over mountains", + "size": "1024x1024", + "n": 1, + "response_format": "b64_json", + "num_inference_steps": 30, + "guidance_scale": 3.5 + }' +``` + +### Generate and Save + +```bash +curl -s http://localhost:8000/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{"prompt": "A cat", "size": "512x512"}' \ + | jq -r '.data[0].b64_json' | base64 -d > output.png +``` + +### List Models + +```bash +curl http://localhost:8000/v1/models +``` + +### Health Check + +```bash +curl http://localhost:8000/health +``` + +## Using with OpenAI Python Client + +```python +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="not-needed", +) + +response = client.images.generate( + model="black-forest-labs/FLUX.1-schnell", + prompt="A majestic dragon flying over a castle", + size="1024x1024", + n=1, + response_format="b64_json", +) + +# Save the image +import base64 +image_bytes = base64.b64decode(response.data[0].b64_json) +with open("dragon.png", "wb") as f: + f.write(image_bytes) +``` + +## Supported Models + +- `black-forest-labs/FLUX.1-dev` - High quality, slower +- `black-forest-labs/FLUX.1-schnell` - Fast generation + +## Environment Variables + +| Variable | Description | +|----------|-------------| +| `USE_TORCH_RANDN` | Set to "1" to use torch-based random latents | +| `SEED` | Random seed for reproducibility | diff --git a/max/examples/diffusion/client_example.py b/max/examples/diffusion/client_example.py new file mode 100644 index 00000000000..22fa7cd6429 --- /dev/null +++ b/max/examples/diffusion/client_example.py @@ -0,0 +1,240 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Example: Client code for connecting to the diffusion API server. + +This example demonstrates how to connect to the OpenAI-compatible +image generation server using various client methods. + +Prerequisites: + 1. Start the server: + max images serve --model black-forest-labs/FLUX.1-schnell --port 8000 + + 2. Run this client: + python client_example.py + +Dependencies: + pip install requests openai httpx +""" + +import base64 +from pathlib import Path + + +def example_with_requests() -> None: + """Example using the requests library.""" + import requests + + base_url = "http://localhost:8000" + + # Check server health + response = requests.get(f"{base_url}/health") + print(f"Server health: {response.json()}") + + # List available models + response = requests.get(f"{base_url}/v1/models") + models = response.json() + print(f"Available models: {[m['id'] for m in models['data']]}") + + # Generate an image + request_data = { + "prompt": "A beautiful sunset over the ocean with palm trees", + "size": "1024x1024", + "n": 1, + "response_format": "b64_json", + # Diffusion-specific parameters + "num_inference_steps": 30, + "guidance_scale": 3.5, + } + + print(f"\nGenerating image: {request_data['prompt']}") + response = requests.post( + f"{base_url}/v1/images/generations", + json=request_data, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code == 200: + result = response.json() + print(f"Created at: {result['created']}") + + # Save the image + output_dir = Path("outputs/client") + output_dir.mkdir(parents=True, exist_ok=True) + + for i, img_data in enumerate(result["data"]): + if "b64_json" in img_data: + image_bytes = base64.b64decode(img_data["b64_json"]) + output_path = output_dir / f"requests_output_{i}.png" + with open(output_path, "wb") as f: + f.write(image_bytes) + print(f"Image saved to: {output_path}") + else: + print(f"Error: {response.status_code} - {response.text}") + + +def example_with_openai_client() -> None: + """Example using the official OpenAI Python client. + + Note: The OpenAI client can connect to any OpenAI-compatible server. + """ + try: + from openai import OpenAI + except ImportError: + print("Please install openai: pip install openai") + return + + # Point to local server + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="not-needed", # API key not required for local server + ) + + # List models + models = client.models.list() + print(f"Available models: {[m.id for m in models.data]}") + + # Generate an image + print("\nGenerating image with OpenAI client...") + + # Note: The OpenAI client's images.generate() may not support + # all diffusion-specific parameters. Use raw HTTP for full control. + response = client.images.generate( + model="black-forest-labs/FLUX.1-schnell", + prompt="A majestic dragon flying over a medieval castle", + size="1024x1024", + n=1, + response_format="b64_json", + ) + + print(f"Created at: {response.created}") + + # Save the image + output_dir = Path("outputs/client") + output_dir.mkdir(parents=True, exist_ok=True) + + for i, img_data in enumerate(response.data): + if img_data.b64_json: + image_bytes = base64.b64decode(img_data.b64_json) + output_path = output_dir / f"openai_client_output_{i}.png" + with open(output_path, "wb") as f: + f.write(image_bytes) + print(f"Image saved to: {output_path}") + + +def example_with_httpx_async() -> None: + """Example using httpx for async requests.""" + import asyncio + + try: + import httpx + except ImportError: + print("Please install httpx: pip install httpx") + return + + async def generate_image(): + async with httpx.AsyncClient(timeout=300.0) as client: + base_url = "http://localhost:8000" + + # Generate image + request_data = { + "prompt": "A cyberpunk street scene at night with neon lights", + "size": "1024x1024", + "n": 1, + "response_format": "b64_json", + "num_inference_steps": 30, + "guidance_scale": 3.5, + } + + print(f"Generating: {request_data['prompt']}") + response = await client.post( + f"{base_url}/v1/images/generations", + json=request_data, + ) + + if response.status_code == 200: + result = response.json() + + output_dir = Path("outputs/client") + output_dir.mkdir(parents=True, exist_ok=True) + + for i, img_data in enumerate(result["data"]): + if "b64_json" in img_data: + image_bytes = base64.b64decode(img_data["b64_json"]) + output_path = output_dir / f"httpx_output_{i}.png" + with open(output_path, "wb") as f: + f.write(image_bytes) + print(f"Image saved to: {output_path}") + else: + print(f"Error: {response.status_code}") + + asyncio.run(generate_image()) + + +def example_curl_commands() -> None: + """Print curl commands for reference.""" + print("\n" + "=" * 60) + print("CURL COMMAND EXAMPLES") + print("=" * 60) + + print("\n1. Check server health:") + print(" curl http://localhost:8000/health") + + print("\n2. List available models:") + print(" curl http://localhost:8000/v1/models") + + print("\n3. Generate an image:") + print(""" curl http://localhost:8000/v1/images/generations \\ + -H "Content-Type: application/json" \\ + -d '{ + "prompt": "A beautiful sunset over mountains", + "size": "1024x1024", + "n": 1, + "response_format": "b64_json", + "num_inference_steps": 30, + "guidance_scale": 3.5 + }'""") + + print("\n4. Generate and save image (with jq):") + print(""" curl -s http://localhost:8000/v1/images/generations \\ + -H "Content-Type: application/json" \\ + -d '{"prompt": "A cat", "size": "512x512"}' \\ + | jq -r '.data[0].b64_json' | base64 -d > output.png""") + + +if __name__ == "__main__": + print("=" * 60) + print("Diffusion API Client Examples") + print("=" * 60) + print("\nMake sure the server is running:") + print(" max images serve --model black-forest-labs/FLUX.1-schnell --port 8000") + print() + + # Print curl examples first + example_curl_commands() + + # Uncomment to run the examples: + # print("\n" + "=" * 60) + # print("Running requests example...") + # print("=" * 60) + # example_with_requests() + + # print("\n" + "=" * 60) + # print("Running OpenAI client example...") + # print("=" * 60) + # example_with_openai_client() + + # print("\n" + "=" * 60) + # print("Running httpx async example...") + # print("=" * 60) + # example_with_httpx_async() diff --git a/max/examples/diffusion/openai_api_example.py b/max/examples/diffusion/openai_api_example.py new file mode 100644 index 00000000000..18dc1121539 --- /dev/null +++ b/max/examples/diffusion/openai_api_example.py @@ -0,0 +1,127 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Example: Using OpenAI-compatible API for image generation. + +This example demonstrates how to use the ImageGenerationRequest and +ImageGenerationResponse classes for OpenAI-compatible image generation. + +Usage: + python openai_api_example.py +""" + +import base64 +import os +from pathlib import Path + +from max.entrypoints.diffusion import ImageGenerator +from max.interfaces import ImageGenerationRequest +from max.pipelines import PipelineConfig + + +def main() -> None: + # Configure random seed for reproducibility + seed = 42 + os.environ["USE_TORCH_RANDN"] = "1" + os.environ["SEED"] = str(seed) + + # Initialize the generator + model_path = "black-forest-labs/FLUX.1-schnell" + pipeline_config = PipelineConfig(model_path=model_path) + generator = ImageGenerator(pipeline_config) + + print(f"Model loaded: {generator.model_name}") + + # Create an OpenAI-compatible request + request = ImageGenerationRequest( + prompt="A futuristic city skyline at sunset with flying cars", + size="1024x1024", + n=1, + quality="standard", + response_format="b64_json", + output_format="png", + # Diffusion-specific parameters + num_inference_steps=30, + guidance_scale=3.5, + seed=seed, + ) + + print(f"Generating image with prompt: {request.prompt}") + print(f"Size: {request.size}, Steps: {request.num_inference_steps}") + + # Generate using OpenAI-compatible API (create method) + response = generator.create(request) + + print(f"Response created at: {response.created}") + print(f"Number of images: {len(response.data)}") + + # Save the generated image + output_dir = Path("outputs") + output_dir.mkdir(parents=True, exist_ok=True) + + for i, image_data in enumerate(response.data): + if image_data.b64_json: + # Decode base64 and save + image_bytes = base64.b64decode(image_data.b64_json) + output_path = output_dir / f"openai_api_output_{i}.png" + with open(output_path, "wb") as f: + f.write(image_bytes) + print(f"Image saved to: {output_path}") + + if image_data.revised_prompt: + print(f"Revised prompt: {image_data.revised_prompt}") + + +def batch_generation_example() -> None: + """Example: Generate multiple images with different prompts.""" + os.environ["USE_TORCH_RANDN"] = "1" + os.environ["SEED"] = "42" + + model_path = "black-forest-labs/FLUX.1-schnell" + pipeline_config = PipelineConfig(model_path=model_path) + generator = ImageGenerator(pipeline_config) + + prompts = [ + "A serene mountain landscape with a lake", + "A cute robot playing with a kitten", + "An abstract painting of emotions", + ] + + output_dir = Path("outputs/batch") + output_dir.mkdir(parents=True, exist_ok=True) + + for idx, prompt in enumerate(prompts): + request = ImageGenerationRequest( + prompt=prompt, + size="512x512", # Smaller for faster generation + n=1, + response_format="b64_json", + num_inference_steps=20, + guidance_scale=3.5, + ) + + print(f"\n[{idx + 1}/{len(prompts)}] Generating: {prompt[:50]}...") + response = generator.create(request) + + if response.data and response.data[0].b64_json: + image_bytes = base64.b64decode(response.data[0].b64_json) + output_path = output_dir / f"batch_{idx}.png" + with open(output_path, "wb") as f: + f.write(image_bytes) + print(f"Saved to: {output_path}") + + +if __name__ == "__main__": + main() + # Uncomment to run batch generation: + # batch_generation_example() From d462d262cc1cbf86b4cc43e310f49094363b84bd Mon Sep 17 00:00:00 2001 From: jingulee Date: Fri, 16 Jan 2026 09:14:58 +0000 Subject: [PATCH 12/18] fix: import path --- max/python/max/entrypoints/diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/max/python/max/entrypoints/diffusion.py b/max/python/max/entrypoints/diffusion.py index d166d2de96a..e383c6e8e0d 100644 --- a/max/python/max/entrypoints/diffusion.py +++ b/max/python/max/entrypoints/diffusion.py @@ -66,7 +66,7 @@ from max.pipelines.lib import PIPELINE_REGISTRY, PipelineConfig if TYPE_CHECKING: - from max.pipelines.core import ImageGenerationPipeline + from max.pipelines.lib.pipeline_variants.image_generation import ImageGenerationPipeline # ============================================================================ From 7c9b8633c53dd1ca0f9ae71c93c8865af964d310 Mon Sep 17 00:00:00 2001 From: jingulee Date: Mon, 19 Jan 2026 07:51:55 +0000 Subject: [PATCH 13/18] update: offline image generate --- max/examples/diffusion/README.md | 17 +- max/examples/diffusion/client_example.py | 66 ++++---- max/examples/diffusion/offline_generation.py | 7 +- max/examples/diffusion/openai_api_example.py | 57 ++----- max/python/max/entrypoints/diffusion.py | 41 +---- max/python/max/entrypoints/pipelines.py | 35 ++-- .../max/entrypoints/pipelines_diffusion.py | 2 +- max/python/max/interfaces/__init__.py | 4 + .../interfaces/pipeline_variants/__init__.py | 4 + .../pipeline_variants/image_generation.py | 63 +++++++- max/python/max/interfaces/task.py | 3 + .../max/pipelines/architectures/flux1/arch.py | 1 - .../architectures/flux1/pipeline_flux.py | 149 +++++------------- max/python/max/pipelines/lib/config.py | 138 ++++++++++++---- max/python/max/serve/api_server.py | 58 ++++--- max/python/max/serve/pipelines/diffusion.py | 67 ++++++++ max/python/max/serve/router/openai_routes.py | 99 ++++++++++++ max/python/max/serve/scheduler/__init__.py | 29 ++++ .../scheduler/image_generation_scheduler.py | 123 +++++++++++++++ 19 files changed, 658 insertions(+), 305 deletions(-) create mode 100644 max/python/max/serve/pipelines/diffusion.py create mode 100644 max/python/max/serve/scheduler/image_generation_scheduler.py diff --git a/max/examples/diffusion/README.md b/max/examples/diffusion/README.md index a3990f06223..96ae9f8c999 100644 --- a/max/examples/diffusion/README.md +++ b/max/examples/diffusion/README.md @@ -7,7 +7,7 @@ This directory contains examples for using the MAX diffusion pipeline for image The MAX diffusion pipeline supports: - **Offline generation**: Direct Python API for generating images - **OpenAI-compatible API**: Server with `/v1/images/generations` endpoint -- **Multiple models**: FLUX.1-dev, FLUX.1-schnell, and other diffusion models +- **Multiple models**: FLUX.1-dev, and other diffusion models ## Examples @@ -23,7 +23,7 @@ python offline_generation.py from max.entrypoints.diffusion import ImageGenerator from max.pipelines import PipelineConfig -config = PipelineConfig(model_path="black-forest-labs/FLUX.1-schnell") +config = PipelineConfig(model_path="black-forest-labs/FLUX.1-dev") generator = ImageGenerator(config) # Generate returns a list of PIL Images @@ -50,7 +50,7 @@ from max.entrypoints.diffusion import ImageGenerator from max.interfaces import ImageGenerationRequest from max.pipelines import PipelineConfig -config = PipelineConfig(model_path="black-forest-labs/FLUX.1-schnell") +config = PipelineConfig(model_path="black-forest-labs/FLUX.1-dev") generator = ImageGenerator(config) # Use OpenAI-compatible request format @@ -74,7 +74,7 @@ Connecting to the OpenAI-compatible server: ```bash # Start the server -max images serve --model black-forest-labs/FLUX.1-schnell --port 8000 +max images serve --model black-forest-labs/FLUX.1-dev --port 8000 # Run the client python client_example.py @@ -87,7 +87,7 @@ python client_example.py ```bash # Basic generation max images generate \ - --model black-forest-labs/FLUX.1-schnell \ + --model black-forest-labs/FLUX.1-dev \ --prompt "A beautiful sunset over mountains" \ --size 1024x1024 \ --output output.png @@ -108,7 +108,7 @@ max images generate \ ```bash # Start OpenAI-compatible API server max images serve \ - --model black-forest-labs/FLUX.1-schnell \ + --model black-forest-labs/FLUX.1-dev \ --port 8000 ``` @@ -192,7 +192,7 @@ client = OpenAI( ) response = client.images.generate( - model="black-forest-labs/FLUX.1-schnell", + model="black-forest-labs/FLUX.1-dev", prompt="A majestic dragon flying over a castle", size="1024x1024", n=1, @@ -208,8 +208,7 @@ with open("dragon.png", "wb") as f: ## Supported Models -- `black-forest-labs/FLUX.1-dev` - High quality, slower -- `black-forest-labs/FLUX.1-schnell` - Fast generation +- `black-forest-labs/FLUX.1-dev` - Flux 1 Dev ## Environment Variables diff --git a/max/examples/diffusion/client_example.py b/max/examples/diffusion/client_example.py index 22fa7cd6429..8b9345d0ba3 100644 --- a/max/examples/diffusion/client_example.py +++ b/max/examples/diffusion/client_example.py @@ -18,7 +18,7 @@ Prerequisites: 1. Start the server: - max images serve --model black-forest-labs/FLUX.1-schnell --port 8000 + max images serve --model black-forest-labs/FLUX.1-dev --port 8000 2. Run this client: python client_example.py @@ -30,12 +30,20 @@ import base64 from pathlib import Path +from argparse import ArgumentParser + +parser = ArgumentParser() +parser.add_argument("--port", type=int, default=8000) +args = parser.parse_args() + +PORT = args.port + def example_with_requests() -> None: """Example using the requests library.""" import requests - base_url = "http://localhost:8000" + base_url = f"http://localhost:{PORT}" # Check server health response = requests.get(f"{base_url}/health") @@ -96,7 +104,7 @@ def example_with_openai_client() -> None: # Point to local server client = OpenAI( - base_url="http://localhost:8000/v1", + base_url=f"http://localhost:{PORT}/v1", api_key="not-needed", # API key not required for local server ) @@ -110,7 +118,7 @@ def example_with_openai_client() -> None: # Note: The OpenAI client's images.generate() may not support # all diffusion-specific parameters. Use raw HTTP for full control. response = client.images.generate( - model="black-forest-labs/FLUX.1-schnell", + model="black-forest-labs/FLUX.1-dev", prompt="A majestic dragon flying over a medieval castle", size="1024x1024", n=1, @@ -144,7 +152,7 @@ def example_with_httpx_async() -> None: async def generate_image(): async with httpx.AsyncClient(timeout=300.0) as client: - base_url = "http://localhost:8000" + base_url = f"http://localhost:{PORT}" # Generate image request_data = { @@ -188,13 +196,14 @@ def example_curl_commands() -> None: print("=" * 60) print("\n1. Check server health:") - print(" curl http://localhost:8000/health") + print(f" curl http://localhost:{PORT}/health") print("\n2. List available models:") - print(" curl http://localhost:8000/v1/models") + print(f" curl http://localhost:{PORT}/v1/models") print("\n3. Generate an image:") - print(""" curl http://localhost:8000/v1/images/generations \\ + print( + f""" curl http://localhost:{PORT}/v1/images/generations \\ -H "Content-Type: application/json" \\ -d '{ "prompt": "A beautiful sunset over mountains", @@ -203,13 +212,16 @@ def example_curl_commands() -> None: "response_format": "b64_json", "num_inference_steps": 30, "guidance_scale": 3.5 - }'""") + }'""" + ) print("\n4. Generate and save image (with jq):") - print(""" curl -s http://localhost:8000/v1/images/generations \\ + print( + f""" curl -s http://localhost:{PORT}/v1/images/generations \\ -H "Content-Type: application/json" \\ -d '{"prompt": "A cat", "size": "512x512"}' \\ - | jq -r '.data[0].b64_json' | base64 -d > output.png""") + | jq -r '.data[0].b64_json' | base64 -d > output.png""" + ) if __name__ == "__main__": @@ -217,24 +229,24 @@ def example_curl_commands() -> None: print("Diffusion API Client Examples") print("=" * 60) print("\nMake sure the server is running:") - print(" max images serve --model black-forest-labs/FLUX.1-schnell --port 8000") + print(f" max images serve --model black-forest-labs/FLUX.1-dev --port {PORT}") print() # Print curl examples first - example_curl_commands() + # example_curl_commands() # Uncomment to run the examples: - # print("\n" + "=" * 60) - # print("Running requests example...") - # print("=" * 60) - # example_with_requests() - - # print("\n" + "=" * 60) - # print("Running OpenAI client example...") - # print("=" * 60) - # example_with_openai_client() - - # print("\n" + "=" * 60) - # print("Running httpx async example...") - # print("=" * 60) - # example_with_httpx_async() + print("\n" + "=" * 60) + print("Running requests example...") + print("=" * 60) + example_with_requests() + + print("\n" + "=" * 60) + print("Running OpenAI client example...") + print("=" * 60) + example_with_openai_client() + + print("\n" + "=" * 60) + print("Running httpx async example...") + print("=" * 60) + example_with_httpx_async() diff --git a/max/examples/diffusion/offline_generation.py b/max/examples/diffusion/offline_generation.py index cae6c25a368..20749c2a244 100644 --- a/max/examples/diffusion/offline_generation.py +++ b/max/examples/diffusion/offline_generation.py @@ -21,9 +21,7 @@ def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument( - "--model-path", type=str, default="black-forest-labs/FLUX.1-dev" - ) + parser.add_argument("--model-path", type=str, default="black-forest-labs/FLUX.1-dev") parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() @@ -34,13 +32,14 @@ def main() -> None: prompt = "A cat holding a sign that says hello world" print(f"Prompt: {prompt}") + print(f"Seed: {args.seed}") # Generate images using the new API images = pipe.generate( prompt, height=1024, width=1024, - num_inference_steps=50, + num_inference_steps=28, guidance_scale=3.5, ) diff --git a/max/examples/diffusion/openai_api_example.py b/max/examples/diffusion/openai_api_example.py index 18dc1121539..db7c4e9d514 100644 --- a/max/examples/diffusion/openai_api_example.py +++ b/max/examples/diffusion/openai_api_example.py @@ -17,11 +17,12 @@ ImageGenerationResponse classes for OpenAI-compatible image generation. Usage: - python openai_api_example.py + python openai_api_example.py --seed 42 --prompt "A futuristic city skyline at sunset with flying cars" --model-path "black-forest-labs/FLUX.1-dev" """ import base64 import os +import argparse from pathlib import Path from max.entrypoints.diffusion import ImageGenerator @@ -31,16 +32,21 @@ def main() -> None: # Configure random seed for reproducibility - seed = 42 - os.environ["USE_TORCH_RANDN"] = "1" + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--prompt", type=str, default="A futuristic city skyline at sunset with flying cars") + parser.add_argument("--model-path", type=str, default="black-forest-labs/FLUX.1-dev") + args = parser.parse_args() + seed = args.seed os.environ["SEED"] = str(seed) + model_path = args.model_path # Initialize the generator - model_path = "black-forest-labs/FLUX.1-schnell" pipeline_config = PipelineConfig(model_path=model_path) generator = ImageGenerator(pipeline_config) print(f"Model loaded: {generator.model_name}") + print(f"Seed: {os.getenv('SEED', 'not set')}") # Create an OpenAI-compatible request request = ImageGenerationRequest( @@ -51,7 +57,7 @@ def main() -> None: response_format="b64_json", output_format="png", # Diffusion-specific parameters - num_inference_steps=30, + num_inference_steps=28, guidance_scale=3.5, seed=seed, ) @@ -82,46 +88,5 @@ def main() -> None: print(f"Revised prompt: {image_data.revised_prompt}") -def batch_generation_example() -> None: - """Example: Generate multiple images with different prompts.""" - os.environ["USE_TORCH_RANDN"] = "1" - os.environ["SEED"] = "42" - - model_path = "black-forest-labs/FLUX.1-schnell" - pipeline_config = PipelineConfig(model_path=model_path) - generator = ImageGenerator(pipeline_config) - - prompts = [ - "A serene mountain landscape with a lake", - "A cute robot playing with a kitten", - "An abstract painting of emotions", - ] - - output_dir = Path("outputs/batch") - output_dir.mkdir(parents=True, exist_ok=True) - - for idx, prompt in enumerate(prompts): - request = ImageGenerationRequest( - prompt=prompt, - size="512x512", # Smaller for faster generation - n=1, - response_format="b64_json", - num_inference_steps=20, - guidance_scale=3.5, - ) - - print(f"\n[{idx + 1}/{len(prompts)}] Generating: {prompt[:50]}...") - response = generator.create(request) - - if response.data and response.data[0].b64_json: - image_bytes = base64.b64decode(response.data[0].b64_json) - output_path = output_dir / f"batch_{idx}.png" - with open(output_path, "wb") as f: - f.write(image_bytes) - print(f"Saved to: {output_path}") - - if __name__ == "__main__": main() - # Uncomment to run batch generation: - # batch_generation_example() diff --git a/max/python/max/entrypoints/diffusion.py b/max/python/max/entrypoints/diffusion.py index e383c6e8e0d..c8219d74812 100644 --- a/max/python/max/entrypoints/diffusion.py +++ b/max/python/max/entrypoints/diffusion.py @@ -66,12 +66,9 @@ from max.pipelines.lib import PIPELINE_REGISTRY, PipelineConfig if TYPE_CHECKING: - from max.pipelines.lib.pipeline_variants.image_generation import ImageGenerationPipeline - - -# ============================================================================ -# Internal Request/Response Types -# ============================================================================ + from max.pipelines.lib.pipeline_variants.image_generation import ( + ImageGenerationPipeline, + ) @dataclass @@ -103,11 +100,6 @@ class _ThreadControl: cancel: Event = field(default_factory=Event) -# ============================================================================ -# Main ImageGenerator Class -# ============================================================================ - - class ImageGenerator: """High-level interface for generating images using diffusion models. @@ -139,7 +131,7 @@ def __init__(self, pipeline_config: PipelineConfig) -> None: pipeline_config: Configuration specifying the model and parameters. """ self.pipeline_config = pipeline_config - self.model_name = pipeline_config.model_config.model_path + self.model_name = pipeline_config.model.model_path # Initialize thread control and queues self._thread_control = _ThreadControl() @@ -168,10 +160,6 @@ def __del__(self) -> None: if self._worker_thread.is_alive(): self._worker_thread.join(timeout=5.0) - # ======================================================================== - # Public API: Direct Generation - # ======================================================================== - def generate( self, prompts: str | Sequence[str], @@ -229,10 +217,6 @@ def generate( # Submit request and wait for response return self._submit_and_wait(request) - # ======================================================================== - # Public API: OpenAI-Compatible - # ======================================================================== - def create( self, request: ImageGenerationRequest, @@ -280,10 +264,6 @@ def create( prompt=request.prompt, ) - # ======================================================================== - # Internal Methods - # ======================================================================== - def _submit_and_wait(self, request: _ImageRequest) -> list[Image]: """Submit a request to the queue and wait for response.""" response_queue: queue.Queue[_ImageResponse] = queue.Queue() @@ -296,10 +276,6 @@ def _submit_and_wait(self, request: _ImageRequest) -> list[Image]: finally: self._pending_requests.pop(request.id, None) - # ======================================================================== - # Class Methods - # ======================================================================== - @classmethod def from_model(cls, model: str, **kwargs) -> ImageGenerator: """Create an ImageGenerator from a model identifier. @@ -322,18 +298,9 @@ def from_model(cls, model: str, **kwargs) -> ImageGenerator: return cls(config) -# ============================================================================ -# Legacy Alias (for backward compatibility) -# ============================================================================ - DiffusionPipeline = ImageGenerator -# ============================================================================ -# Background Worker -# ============================================================================ - - def _run_worker( thread_control: _ThreadControl, pipeline_config: PipelineConfig, diff --git a/max/python/max/entrypoints/pipelines.py b/max/python/max/entrypoints/pipelines.py index 62d66752816..1b7c11a25cb 100644 --- a/max/python/max/entrypoints/pipelines.py +++ b/max/python/max/entrypoints/pipelines.py @@ -19,7 +19,7 @@ import sys from collections.abc import Callable, Sequence from pathlib import Path -from typing import Any, TypeVar +from typing import Any, Literal, TypeVar import click from click import shell_completion @@ -396,6 +396,12 @@ def images_group() -> None: @images_group.command(name="generate", cls=WithLazyPipelineOptions) +@click.option( + "--model", + type=str, + required=True, + help="Specify the repository ID of a Hugging Face model to use for image generation (e.g., 'black-forest-labs/FLUX.1-dev').", +) @click.option( "--prompt", type=str, @@ -450,14 +456,6 @@ def images_group() -> None: show_default=True, help="Output image path (numbered if multiple images are generated).", ) -@click.option( - "--use-torch-randn/--no-use-torch-randn", - default=False, - show_default=True, - help=( - "Use torch-based random latents (set USE_TORCH_RANDN and SEED env vars)." - ), -) @click.option( "--seed", type=int, @@ -479,15 +477,15 @@ def images_group() -> None: help="Classifier-free guidance scale (diffusion model parameter).", ) def images_generate( + model: str, prompt: str, n: int, size: str, quality: str, - response_format: str, + response_format: Literal["url", "b64_json"], output_format: str, style: str | None, output: Path, - use_torch_randn: bool, seed: int | None, num_inference_steps: int, guidance_scale: float, @@ -498,20 +496,25 @@ def images_generate( This command follows the OpenAI /v1/images/generations API schema. Example: - max images generate --model black-forest-labs/FLUX.1-schnell \\ + max images generate --model black-forest-labs/FLUX.1-dev \\ --prompt "A beautiful sunset over mountains" \\ --size 1024x1024 --n 1 --output sunset.png """ from max.entrypoints.diffusion import ImageGenerator from max.experimental.realization_context import set_seed from max.interfaces import ImageGenerationRequest - from max.pipelines import PipelineConfig + from max.pipelines.lib.config import ImageGenerationConfig # Set random seed if provided set_seed(seed) - # Create pipeline config and generator - pipeline_config = PipelineConfig(**config_kwargs) + """ + TODO: + - This configuration is dummy for now. Just for logging purpose. + - Modifications are required to enable the use of pipeline config. + """ + config_kwargs["model"] = model + pipeline_config = ImageGenerationConfig(**config_kwargs) pipeline_config.log_basic_config() try: @@ -582,7 +585,7 @@ def images_serve( - GET /health (health check) Example: - max images serve --model black-forest-labs/FLUX.1-schnell --port 8000 + max images serve --model black-forest-labs/FLUX.1-dev --port 8000 Then use curl to generate images: curl http://localhost:8000/v1/images/generations \\ diff --git a/max/python/max/entrypoints/pipelines_diffusion.py b/max/python/max/entrypoints/pipelines_diffusion.py index 52863bf27f4..1fec7400445 100644 --- a/max/python/max/entrypoints/pipelines_diffusion.py +++ b/max/python/max/entrypoints/pipelines_diffusion.py @@ -19,7 +19,7 @@ def main() -> None: pipelines_cli.main( prog_name="pipelines", - args=["diffusion", *sys.argv[1:]], + args=["images", *sys.argv[1:]], ) diff --git a/max/python/max/interfaces/__init__.py b/max/python/max/interfaces/__init__.py index ad0e6844d78..41c32dcf7ce 100644 --- a/max/python/max/interfaces/__init__.py +++ b/max/python/max/interfaces/__init__.py @@ -50,6 +50,8 @@ EmbeddingsGenerationOutput, ImageContentPart, ImageData, + ImageGenerationContext, + ImageGenerationContextType, ImageGenerationInputs, ImageGenerationOutput, ImageGenerationRequest, @@ -117,6 +119,8 @@ def create_text_pipeline() -> Pipeline[TextGenerationInputs, TextGenerationOutpu "GenerationStatus", "ImageContentPart", "ImageData", + "ImageGenerationContext", + "ImageGenerationContextType", "ImageGenerationInputs", "ImageGenerationOutput", "ImageGenerationRequest", diff --git a/max/python/max/interfaces/pipeline_variants/__init__.py b/max/python/max/interfaces/pipeline_variants/__init__.py index 0a4aa1db496..22913eeb74f 100644 --- a/max/python/max/interfaces/pipeline_variants/__init__.py +++ b/max/python/max/interfaces/pipeline_variants/__init__.py @@ -26,6 +26,8 @@ ) from .image_generation import ( ImageData, + ImageGenerationContext, + ImageGenerationContextType, ImageGenerationInputs, ImageGenerationOutput, ImageGenerationRequest, @@ -64,6 +66,8 @@ "EmbeddingsGenerationOutput", "ImageContentPart", "ImageData", + "ImageGenerationContext", + "ImageGenerationContextType", "ImageGenerationInputs", "ImageGenerationOutput", "ImageGenerationRequest", diff --git a/max/python/max/interfaces/pipeline_variants/image_generation.py b/max/python/max/interfaces/pipeline_variants/image_generation.py index 44ea84fe80d..2e3cd739701 100644 --- a/max/python/max/interfaces/pipeline_variants/image_generation.py +++ b/max/python/max/interfaces/pipeline_variants/image_generation.py @@ -23,12 +23,64 @@ import io import time from dataclasses import dataclass, field -from typing import Any, Literal +from typing import Any, Literal, Protocol, TypeVar, runtime_checkable +from max.interfaces.context import BaseContext from max.interfaces.pipeline import PipelineInputs +from max.interfaces.request import RequestID from PIL.Image import Image +@runtime_checkable +class ImageGenerationContext(BaseContext, Protocol): + """Protocol defining the interface for image generation contexts. + + An ``ImageGenerationContext`` represents model inputs for image generation + pipelines, managing the state and parameters needed for generating images + from text prompts using diffusion models. + """ + + @property + def request_id(self) -> RequestID: + """Unique identifier for this request.""" + ... + + @property + def prompt(self) -> str: + """The text prompt for image generation.""" + ... + + @property + def height(self) -> int: + """The height of the generated image in pixels.""" + ... + + @property + def width(self) -> int: + """The width of the generated image in pixels.""" + ... + + @property + def num_inference_steps(self) -> int: + """Number of denoising steps.""" + ... + + @property + def guidance_scale(self) -> float: + """Classifier-free guidance scale.""" + ... + + @property + def num_images_per_prompt(self) -> int: + """Number of images to generate per prompt.""" + ... + + +ImageGenerationContextType = TypeVar( + "ImageGenerationContextType", bound=ImageGenerationContext +) + + # Default image generation parameters DEFAULT_SIZE = "1024x1024" DEFAULT_NUM_IMAGES = 1 @@ -60,6 +112,15 @@ class ImageGenerationOutput: images: list[Image] """List of generated images.""" + @property + def is_done(self) -> bool: + """Indicates whether image generation is complete. + + Returns: + bool: Always True, as image generation is a single-step operation. + """ + return True + @dataclass class ImageGenerationRequest: diff --git a/max/python/max/interfaces/task.py b/max/python/max/interfaces/task.py index 0422ebf657c..243d86ca443 100644 --- a/max/python/max/interfaces/task.py +++ b/max/python/max/interfaces/task.py @@ -74,6 +74,7 @@ def output_type( from .pipeline_variants import ( AudioGenerationOutput, EmbeddingsGenerationOutput, + ImageGenerationOutput, TextGenerationOutput, ) from .scheduler import SchedulerResult @@ -87,6 +88,8 @@ def output_type( return dict[RequestID, SchedulerResult[EmbeddingsGenerationOutput]] elif self == PipelineTask.AUDIO_GENERATION: return dict[RequestID, SchedulerResult[AudioGenerationOutput]] + elif self == PipelineTask.IMAGE_GENERATION: + return dict[RequestID, SchedulerResult[ImageGenerationOutput]] else: raise ValueError( f"PipelineTask ({self}) does not have an output_type defined." diff --git a/max/python/max/pipelines/architectures/flux1/arch.py b/max/python/max/pipelines/architectures/flux1/arch.py index aea17022398..7dac24af036 100644 --- a/max/python/max/pipelines/architectures/flux1/arch.py +++ b/max/python/max/pipelines/architectures/flux1/arch.py @@ -29,7 +29,6 @@ supported_encodings={SupportedEncoding.bfloat16: []}, example_repo_ids=[ "black-forest-labs/FLUX.1-dev", - "black-forest-labs/FLUX.1-schnell", ], pipeline_model=FluxPipeline, tokenizer=TextTokenizer, diff --git a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py index 85298caf20e..d0e830dbdba 100644 --- a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py +++ b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py @@ -22,7 +22,7 @@ from max.dtype import DType from max.experimental import Tensor as Tensor_v3 from max.experimental import functional as F -from max.experimental import random +from max.experimental.random import normal from max.graph import DeviceRef from max.pipelines.lib.diffusion_schedulers import ( FlowMatchEulerDiscreteScheduler, @@ -34,6 +34,7 @@ from max.pipelines.lib.interfaces.diffusion_pipeline import ( DiffusionPipeline, ) +from max.experimental.realization_context import set_seed from tqdm import tqdm from transformers import ( CLIPTokenizer, @@ -80,13 +81,9 @@ def retrieve_timesteps( second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: - raise ValueError( - "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" - ) + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: - accepts_timesteps = "timesteps" in set( - inspect.signature(scheduler.set_timesteps).parameters.keys() - ) + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -96,9 +93,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps num_inference_steps = int(timesteps.shape[0]) elif sigmas is not None: - accept_sigmas = "sigmas" in set( - inspect.signature(scheduler.set_timesteps).parameters.keys() - ) + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -155,17 +150,9 @@ class FluxPipeline(DiffusionPipeline): } def init_remaining_components(self) -> None: - image_processor_class = self.components.get( - "image_processor", VaeImageProcessor - ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) - if getattr(self, "vae", None) - else 8 - ) - image_processor = image_processor_class( - vae_scale_factor=self.vae_scale_factor * 2 - ) + image_processor_class = self.components.get("image_processor", VaeImageProcessor) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + image_processor = image_processor_class(vae_scale_factor=self.vae_scale_factor * 2) self.image_processor = image_processor def encode_prompt( @@ -204,13 +191,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, FluxPipeline): self._lora_scale = lora_scale - if self.text_encoder is not None and hasattr( - self.text_encoder, "set_lora_scale" - ): + if self.text_encoder is not None and hasattr(self.text_encoder, "set_lora_scale"): self.text_encoder.set_lora_scale(lora_scale) - if self.text_encoder_2 is not None and hasattr( - self.text_encoder_2, "set_lora_scale" - ): + if self.text_encoder_2 is not None and hasattr(self.text_encoder_2, "set_lora_scale"): self.text_encoder_2.set_lora_scale(lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -219,16 +202,12 @@ def encode_prompt( text_inputs = self.tokenizer( prompt, padding="max_length", - max_length=min( - max_sequence_length, self.tokenizer.model_max_length - ), + max_length=min(max_sequence_length, self.tokenizer.model_max_length), truncation=True, return_length=False, return_overflowing_tokens=False, ) - text_input_ids = Tensor_v3.constant( - text_inputs.input_ids, device=device, dtype=DType.int64 - ) + text_input_ids = Tensor_v3.constant(text_inputs.input_ids, device=device, dtype=DType.int64) text_encoder_outputs = self.text_encoder(text_input_ids) prompt_embeds = text_encoder_outputs[0] @@ -246,9 +225,7 @@ def encode_prompt( return_length=False, return_overflowing_tokens=False, ) - text_input_ids_2 = Tensor_v3.constant( - text_inputs_2.input_ids, device=device, dtype=DType.int64 - ) + text_input_ids_2 = Tensor_v3.constant(text_inputs_2.input_ids, device=device, dtype=DType.int64) prompt_embeds_2 = self.text_encoder_2(text_input_ids_2)[0] else: @@ -264,24 +241,14 @@ def encode_prompt( ) bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = Tensor_v3.from_dlpack( - prompt_embeds - ) # V2 Tensor to V3 Tensor - pooled_prompt_embeds = Tensor_v3.from_dlpack( - pooled_prompt_embeds - ) # V2 Tensor to V3 Tensor + prompt_embeds = Tensor_v3.from_dlpack(prompt_embeds) # V2 Tensor to V3 Tensor + pooled_prompt_embeds = Tensor_v3.from_dlpack(pooled_prompt_embeds) # V2 Tensor to V3 Tensor prompt_embeds = F.tile(prompt_embeds, (1, num_images_per_prompt, 1)) - prompt_embeds = prompt_embeds.reshape( - (bs_embed * num_images_per_prompt, seq_len, -1) - ) + prompt_embeds = prompt_embeds.reshape((bs_embed * num_images_per_prompt, seq_len, -1)) - pooled_prompt_embeds = F.tile( - pooled_prompt_embeds, (1, num_images_per_prompt) - ) - pooled_prompt_embeds = pooled_prompt_embeds.reshape( - (bs_embed * num_images_per_prompt, -1) - ) + pooled_prompt_embeds = F.tile(pooled_prompt_embeds, (1, num_images_per_prompt)) + pooled_prompt_embeds = pooled_prompt_embeds.reshape((bs_embed * num_images_per_prompt, -1)) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -315,9 +282,7 @@ def _prepare_latent_image_ids( latent_image_id_channels, ), ) - latent_image_ids = ( - Tensor_v3.from_dlpack(latent_image_ids).to(device).cast(dtype) - ) + latent_image_ids = Tensor_v3.from_dlpack(latent_image_ids).to(device).cast(dtype) return latent_image_ids @@ -366,9 +331,7 @@ def _unpack_latents( ) latents = F.permute(latents, (0, 3, 1, 4, 2, 5)) - latents = F.reshape( - latents, (batch_size.dim, channels.dim // (2 * 2), height, width) - ) + latents = F.reshape(latents, (batch_size.dim, channels.dim // (2 * 2), height, width)) return latents @@ -404,19 +367,13 @@ def prepare_latents( shape = (batch_size, num_channels_latents, height, width) if latents is not None: - latent_image_ids = self._prepare_latent_image_ids( - batch_size, height // 2, width // 2, device, dtype - ) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents.to(device).cast(dtype), latent_image_ids - latents = random.normal(shape, device=device, dtype=dtype) - latents = self._pack_latents( - latents, batch_size, num_channels_latents, height, width - ) + latents = normal(shape, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids( - batch_size, height // 2, width // 2, device, dtype - ) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents, latent_image_ids @@ -560,13 +517,10 @@ def __call__( device = self._execution_device() lora_scale = ( - self._joint_attention_kwargs.get("scale", None) - if self._joint_attention_kwargs is not None - else None + self._joint_attention_kwargs.get("scale", None) if self._joint_attention_kwargs is not None else None ) has_neg_prompt = negative_prompt is not None or ( - negative_prompt_embeds is not None - and negative_pooled_prompt_embeds is not None + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt ( @@ -612,15 +566,8 @@ def __call__( ) # 5. Prepare timesteps - sigmas = ( - np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - if sigmas is None - else sigmas - ) - if ( - hasattr(self.scheduler, "use_flow_sigmas") - and self.scheduler.use_flow_sigmas - ): + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler, "use_flow_sigmas") and self.scheduler.use_flow_sigmas: sigmas = None image_seq_len = latents.shape[1].dim mu = calculate_shift( @@ -661,9 +608,7 @@ def __call__( or negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None ): - raise NotImplementedError( - "IP adapter is not supported for Max yet." - ) + raise NotImplementedError("IP adapter is not supported for Max yet.") if self._joint_attention_kwargs is None: self._joint_attention_kwargs = {} @@ -683,9 +628,7 @@ def __call__( t = timesteps[i] self._current_timestep = t if image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = ( - image_embeds - ) + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds # NOTE: Convert timesteps to a Max Tensor before denoising loop, # as in the original implementation, results in a significant slow down. @@ -707,9 +650,7 @@ def __call__( if do_true_cfg: if negative_image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = ( - negative_image_embeds - ) + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds neg_noise_pred = self.transformer( latents, @@ -723,15 +664,11 @@ def __call__( # TODO: negative prompt path is very slow, need to optimize. noise_pred = Tensor_v3.from_dlpack(noise_pred) neg_noise_pred = Tensor_v3.from_dlpack(neg_noise_pred) - noise_pred = neg_noise_pred + true_cfg_scale * ( - noise_pred - neg_noise_pred - ) + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step( - noise_pred, t, latents, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: latents = latents.to(latents_dtype) @@ -740,14 +677,10 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end( - self, i, t, callback_kwargs - ) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop( - "prompt_embeds", prompt_embeds - ) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) self._current_timestep = None @@ -755,18 +688,12 @@ def __call__( image = latents else: latents = Tensor_v3.from_dlpack(latents) # V2 Tensor to V3 Tensor - latents = self._unpack_latents( - latents, height, width, self.vae_scale_factor - ) - latents = ( - latents / self.vae.config.scaling_factor - ) + self.vae.config.shift_factor + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor image = self.vae.decode(latents)[0] image = Tensor_v3.from_dlpack(image) # V2 Tensor to V3 Tensor - image = self.image_processor.postprocess( - image, output_type=output_type - ) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/max/python/max/pipelines/lib/config.py b/max/python/max/pipelines/lib/config.py index e1d0a3e1410..0a6402a0267 100644 --- a/max/python/max/pipelines/lib/config.py +++ b/max/python/max/pipelines/lib/config.py @@ -1153,7 +1153,9 @@ def log_basic_config(self) -> None: raise ValueError( "KVCache config is not available after config resolution." ) - memory_str = to_human_readable_bytes(kv_config._available_cache_memory) + memory_str = to_human_readable_bytes( + kv_config._available_cache_memory + ) devices_str = ", ".join( f"{d.device_type}[{d.id}]" for d in self.model.device_specs @@ -1178,25 +1180,63 @@ def log_basic_config(self) -> None: @staticmethod def help() -> dict[str, str]: return { - "max_length": "Set the maximum sequence length for input data processed by the model. This must be less than the value specified in the Hugging Face configuration file. The default is derived from the Hugging Face configuration value. Larger values may consume more memory.", - "pipeline_role": "Whether the pipeline should serve both a prefill or decode role or both.", - "max_batch_size": "Define the maximum batch size to execute with the model. When not specified (None), we determine this value dynamically. For users launching in a server scenario, the expectation is that this value should be set higher based on server capacity.", - "max_queue_size_tg": "Maximum number of requests in decode queue. By default, this is max-batch-size.", - "min_batch_size_tg": "Specifies a soft floor on the decode batch size. If the TG batch size is larger than this value, the scheduler will continue to run TG batches. If it falls below, the scheduler will prioritize CE. This is an experimental flag solely for the TTS scheduler.", - "ce_delay_ms": "Duration of scheduler sleep prior to starting a prefill batch. This is an experimental flag solely for the TTS scheduler. Default is 0.0.", - "enable_prioritize_first_decode": "When enabled, the scheduler will always run a TG batch immediately after a CE batch, with the same requests. This may be useful for decreasing time-to-first-chunk latency. This is an experimental flag solely for the TTS scheduler. Default is false.", - "experimental_background_queue": "When enabled, offloads queue draining to a background thread for improved performance. This is an experimental flag. Default is false.", - "enable_chunked_prefill": "Enable chunked prefill to split context encoding requests into multiple chunks based on `prefill-chunk-size`. Default is true.", - "enable_in_flight_batching": "When enabled, prioritizes token generation by batching it with context encoding requests. Default is false.", - "max_num_steps": "Specify the number of steps to run for multi-step scheduling during inference. Default is -1 which specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models).", - "prefill_chunk_size": "The target number of un-encoded tokens to include in each batch. This value is used for chunked prefill and memory estimation. Default is 8192.", - "enable_echo": "Whether the model should be built with echo capabilities. This defaults to false.", - "pool_embeddings": "Whether to pool embedding outputs. Default is true.", - "use_experimental_kernels": "Whether to use experimental kernels. Default is false.", - "max_batch_context_length": "Ensures that the sum of the context length in a batch does not exceed max_batch_context_length. If None, the sum of the context length in batch is not limited.", - "pdl_level": "Level of overlap of kernel launch via programmatic dependent grid control. Default is 0.", - "custom_architectures": "A list of custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name.", - "kvcache_ce_watermark": "Projected cache usage threshold for scheduling CE requests, considers current + incoming request. CE is scheduled if either projected usage stays below this threshold OR no active requests exist. Greater KVCache utilization (as controlled by this parameter) was found to cause more preemptions. Default watermark value is 0.95.", + "max_length": ( + "Set the maximum sequence length for input data processed by the model. This must be less than the value specified in the Hugging Face configuration file. The default is derived from the Hugging Face configuration value. Larger values may consume more memory." + ), + "pipeline_role": ( + "Whether the pipeline should serve both a prefill or decode role or both." + ), + "max_batch_size": ( + "Define the maximum batch size to execute with the model. When not specified (None), we determine this value dynamically. For users launching in a server scenario, the expectation is that this value should be set higher based on server capacity." + ), + "max_queue_size_tg": ( + "Maximum number of requests in decode queue. By default, this is max-batch-size." + ), + "min_batch_size_tg": ( + "Specifies a soft floor on the decode batch size. If the TG batch size is larger than this value, the scheduler will continue to run TG batches. If it falls below, the scheduler will prioritize CE. This is an experimental flag solely for the TTS scheduler." + ), + "ce_delay_ms": ( + "Duration of scheduler sleep prior to starting a prefill batch. This is an experimental flag solely for the TTS scheduler. Default is 0.0." + ), + "enable_prioritize_first_decode": ( + "When enabled, the scheduler will always run a TG batch immediately after a CE batch, with the same requests. This may be useful for decreasing time-to-first-chunk latency. This is an experimental flag solely for the TTS scheduler. Default is false." + ), + "experimental_background_queue": ( + "When enabled, offloads queue draining to a background thread for improved performance. This is an experimental flag. Default is false." + ), + "enable_chunked_prefill": ( + "Enable chunked prefill to split context encoding requests into multiple chunks based on `prefill-chunk-size`. Default is true." + ), + "enable_in_flight_batching": ( + "When enabled, prioritizes token generation by batching it with context encoding requests. Default is false." + ), + "max_num_steps": ( + "Specify the number of steps to run for multi-step scheduling during inference. Default is -1 which specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models)." + ), + "prefill_chunk_size": ( + "The target number of un-encoded tokens to include in each batch. This value is used for chunked prefill and memory estimation. Default is 8192." + ), + "enable_echo": ( + "Whether the model should be built with echo capabilities. This defaults to false." + ), + "pool_embeddings": ( + "Whether to pool embedding outputs. Default is true." + ), + "use_experimental_kernels": ( + "Whether to use experimental kernels. Default is false." + ), + "max_batch_context_length": ( + "Ensures that the sum of the context length in a batch does not exceed max_batch_context_length. If None, the sum of the context length in batch is not limited." + ), + "pdl_level": ( + "Level of overlap of kernel launch via programmatic dependent grid control. Default is 0." + ), + "custom_architectures": ( + "A list of custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name." + ), + "kvcache_ce_watermark": ( + "Projected cache usage threshold for scheduling CE requests, considers current + incoming request. CE is scheduled if either projected usage stays below this threshold OR no active requests exist. Greater KVCache utilization (as controlled by this parameter) was found to cause more preemptions. Default watermark value is 0.95." + ), } @property @@ -1361,15 +1401,33 @@ def help() -> dict[str, str]: # Add AudioGenerationConfig-specific fields audio_specific_help = { - "audio_decoder": "The name of the audio decoder model architecture.", - "audio_decoder_weights": "The path to the audio decoder weights file.", - "chunk_size": "The chunk sizes to use for streaming. If this is an int, then fixed-size chunks of the given size are used. If this is a list, then variable chunk sizes are used.", - "buffer": "The number of previous speech tokens to pass to the audio decoder on each generation step. Default is 0.", - "block_causal": "Whether prior buffered tokens should attend to tokens in the current block. Has no effect if buffer is not set. Default is false.", - "prepend_prompt_speech_tokens": "Whether the prompt speech tokens should be forwarded to the audio decoder. Options: 'never', 'once', 'rolling'. Default is 'once'.", - "prepend_prompt_speech_tokens_causal": "Whether the prompt speech tokens should attend to tokens in the currently generated audio block. Has no effect if prepend_prompt_speech_tokens is 'never'. Default is false.", - "audio_decoder_config": "Parameters to pass to the audio decoder model.", - "prometheus_metrics_mode": "The mode to use for Prometheus metrics. Default is 'instrument_only'.", + "audio_decoder": ( + "The name of the audio decoder model architecture." + ), + "audio_decoder_weights": ( + "The path to the audio decoder weights file." + ), + "chunk_size": ( + "The chunk sizes to use for streaming. If this is an int, then fixed-size chunks of the given size are used. If this is a list, then variable chunk sizes are used." + ), + "buffer": ( + "The number of previous speech tokens to pass to the audio decoder on each generation step. Default is 0." + ), + "block_causal": ( + "Whether prior buffered tokens should attend to tokens in the current block. Has no effect if buffer is not set. Default is false." + ), + "prepend_prompt_speech_tokens": ( + "Whether the prompt speech tokens should be forwarded to the audio decoder. Options: 'never', 'once', 'rolling'. Default is 'once'." + ), + "prepend_prompt_speech_tokens_causal": ( + "Whether the prompt speech tokens should attend to tokens in the currently generated audio block. Has no effect if prepend_prompt_speech_tokens is 'never'. Default is false." + ), + "audio_decoder_config": ( + "Parameters to pass to the audio decoder model." + ), + "prometheus_metrics_mode": ( + "The mode to use for Prometheus metrics. Default is 'instrument_only'." + ), } # Check for conflicts @@ -1442,3 +1500,25 @@ def from_flags( prometheus_metrics_mode=prometheus_metrics_mode, **config_flags, ) + + +class ImageGenerationConfig(PipelineConfig): + model_path: str = Field(default="") + """The repo id of the image generation model.""" + + @staticmethod + def help() -> dict[str, str]: + return { + "model": "The path to the image generation model.", + } + + @classmethod + def from_flags( + cls, flags: dict[str, str], **config_flags: Any + ) -> ImageGenerationConfig: + merged_configs = {**flags, **config_flags} + return cls(**merged_configs) + + def log_basic_config(self) -> None: + logger.info(f"model_info: {self.model}") + logger.info("") diff --git a/max/python/max/serve/api_server.py b/max/python/max/serve/api_server.py index 340e5961a25..96b5e40020d 100644 --- a/max/python/max/serve/api_server.py +++ b/max/python/max/serve/api_server.py @@ -29,6 +29,7 @@ from max.interfaces import PipelinesFactory, PipelineTask, PipelineTokenizer from max.pipelines.lib import PIPELINE_REGISTRY, PipelineConfig from max.serve.config import APIType, MetricRecordingMethod, Settings +from max.serve.pipelines.diffusion import ImageGeneratorPipeline from max.serve.pipelines.llm import ( AudioGeneratorPipeline, TokenGeneratorPipeline, @@ -106,35 +107,40 @@ async def lifespan( ) METRICS.configure(client=metric_client) - # start model worker - scheduler_zmq_configs = SchedulerZmqConfigs( - serving_settings.pipeline_task, - context_type=PIPELINE_REGISTRY.retrieve_context_type( - serving_settings.pipeline_config - ), - ) - worker_monitor = await exit_stack.enter_async_context( - start_model_worker( - serving_settings.model_factory, - serving_settings.pipeline_config, - settings, - metric_client, - scheduler_zmq_configs=scheduler_zmq_configs, + # Image generation uses a direct pipeline without model worker + if serving_settings.pipeline_task == PipelineTask.IMAGE_GENERATION: + scheduler_zmq_configs = None + lora_queue = None + else: + # start model worker + scheduler_zmq_configs = SchedulerZmqConfigs( + serving_settings.pipeline_task, + context_type=PIPELINE_REGISTRY.retrieve_context_type( + serving_settings.pipeline_config + ), + ) + worker_monitor = await exit_stack.enter_async_context( + start_model_worker( + serving_settings.model_factory, + serving_settings.pipeline_config, + settings, + metric_client, + scheduler_zmq_configs=scheduler_zmq_configs, + ) ) - ) - lora_queue: LoRAQueue | None = ( - LoRAQueue( - serving_settings.pipeline_config.zmq_endpoint_base, - serving_settings.pipeline_config.lora_config.lora_paths, + lora_queue = ( + LoRAQueue( + serving_settings.pipeline_config.zmq_endpoint_base, + serving_settings.pipeline_config.lora_config.lora_paths, + ) + if serving_settings.pipeline_config.lora_config + else None ) - if serving_settings.pipeline_config.lora_config - else None - ) METRICS.pipeline_load(serving_settings.pipeline_config.model.model_name) - pipeline: TokenGeneratorPipeline | AudioGeneratorPipeline + pipeline: TokenGeneratorPipeline | AudioGeneratorPipeline | ImageGeneratorPipeline if serving_settings.pipeline_task in ( PipelineTask.TEXT_GENERATION, PipelineTask.EMBEDDINGS_GENERATION, @@ -152,6 +158,12 @@ async def lifespan( lora_queue=lora_queue, scheduler_zmq_configs=scheduler_zmq_configs, ) + elif serving_settings.pipeline_task == PipelineTask.IMAGE_GENERATION: + # Image generation uses a simpler pipeline without scheduler + pipeline = ImageGeneratorPipeline( + model_name=serving_settings.pipeline_config.model.model_name, + pipeline_config=serving_settings.pipeline_config, + ) else: raise ValueError( f"Unsupported pipeline task: {serving_settings.pipeline_task}" diff --git a/max/python/max/serve/pipelines/diffusion.py b/max/python/max/serve/pipelines/diffusion.py new file mode 100644 index 00000000000..0f3e56ba2a7 --- /dev/null +++ b/max/python/max/serve/pipelines/diffusion.py @@ -0,0 +1,67 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Image generation pipeline for serving diffusion models.""" + +from __future__ import annotations + +import logging +from typing import Any + +from typing_extensions import Self + + +class ImageGeneratorPipeline: + """Pipeline wrapper for image generation. + + This is a simplified pipeline for image generation that doesn't use + the same streaming/batching infrastructure as text generation. + """ + + def __init__( + self, + model_name: str, + pipeline_config: Any, + ) -> None: + self.model_name = model_name + self.pipeline_config = pipeline_config + self._generator: Any | None = None + self.logger = logging.getLogger( + self.__class__.__module__ + "." + self.__class__.__qualname__ + ) + + async def __aenter__(self) -> Self: + """Initialize the image generator.""" + from max.entrypoints.diffusion import ImageGenerator + + self.logger.info("Loading image generator for model: %s", self.model_name) + self._generator = ImageGenerator(self.pipeline_config) + self.logger.info("Image generator loaded successfully") + return self + + async def __aexit__( + self, et: type[BaseException] | None, exc: BaseException | None, tb: Any + ) -> bool | None: + """Clean up resources.""" + if self._generator is not None: + del self._generator + self._generator = None + self.logger.info("Image generator pipeline closed: %s", self.model_name) + return None + + @property + def generator(self) -> Any: + """Get the underlying image generator.""" + if self._generator is None: + raise RuntimeError("Image generator not initialized. Use async with.") + return self._generator diff --git a/max/python/max/serve/router/openai_routes.py b/max/python/max/serve/router/openai_routes.py index 8eb356e2620..7bfc22d3446 100644 --- a/max/python/max/serve/router/openai_routes.py +++ b/max/python/max/serve/router/openai_routes.py @@ -43,6 +43,8 @@ from max.interfaces import ( AudioGenerationRequest, GenerationStatus, + ImageGenerationRequest, + ImageGenerationResponse, LoRAOperation, LoRARequest, LoRAStatus, @@ -61,6 +63,7 @@ from max.profiler import traced from max.serve.config import Settings from max.serve.parser import LlamaToolParser, parse_json_from_text +from max.serve.pipelines.diffusion import ImageGeneratorPipeline from max.serve.pipelines.llm import ( AudioGeneratorPipeline, TokenGeneratorOutput, @@ -1429,6 +1432,102 @@ async def create_streaming_audio_speech( raise HTTPException(status_code=400, detail="Value error.") from e +@router.post("/images/generations", response_model=None) +async def create_image_generation( + request: Request, +) -> JSONResponse: + """Image generation endpoint following OpenAI /v1/images/generations API. + + This endpoint generates images from text prompts using diffusion models. + """ + request_id = request.state.request_id + record_request_start() + stopwatch = StopWatch() + + try: + # Parse request body + body = await request.body() + import json as json_module + body_dict = json_module.loads(body) + + # Create ImageGenerationRequest from body + image_request = ImageGenerationRequest( + prompt=body_dict.get("prompt", ""), + model=body_dict.get("model"), + n=body_dict.get("n", 1), + size=body_dict.get("size", "1024x1024"), + quality=body_dict.get("quality", "standard"), + response_format=body_dict.get("response_format", "b64_json"), + style=body_dict.get("style"), + num_inference_steps=body_dict.get("num_inference_steps", 50), + guidance_scale=body_dict.get("guidance_scale", 3.5), + seed=body_dict.get("seed"), + ) + + # Get the pipeline + pipeline = request.app.state.pipeline + if not isinstance(pipeline, ImageGeneratorPipeline): + raise HTTPException( + status_code=400, + detail="This server is not configured for image generation. " + "Please start with --task image_generation.", + ) + + # Generate images using the pipeline's generator + response = pipeline.generator.create(image_request) + + record_request_end( + status_code=200, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + + return JSONResponse(content=response.to_dict()) + + except JSONDecodeError as e: + logger.exception("JSONDecodeError in request %s", request_id) + record_request_end( + status_code=400, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + raise HTTPException(status_code=400, detail="Missing JSON.") from e + except (TypeError, ValidationError) as e: + logger.exception("TypeError in request %s", request_id) + record_request_end( + status_code=400, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + raise HTTPException(status_code=400, detail="Invalid JSON.") from e + except InputError as e: + logger.warning( + "Input validation error in request %s: %s", request_id, str(e) + ) + record_request_end( + status_code=400, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + raise HTTPException(status_code=400, detail=str(e)) from e + except ValueError as e: + logger.exception("ValueError in request %s", request_id) + record_request_end( + status_code=400, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + raise HTTPException(status_code=400, detail="Value error.") from e + except Exception as e: + logger.exception("Error in image generation request %s", request_id) + record_request_end( + status_code=500, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + raise HTTPException(status_code=500, detail=str(e)) from e + + @router.post("/load_lora_adapter", response_model=None) async def load_lora_adapter( request: Request, diff --git a/max/python/max/serve/scheduler/__init__.py b/max/python/max/serve/scheduler/__init__.py index fbbb04cd511..5fc4071a08e 100644 --- a/max/python/max/serve/scheduler/__init__.py +++ b/max/python/max/serve/scheduler/__init__.py @@ -16,6 +16,7 @@ from max.interfaces import ( EmbeddingsContext, + ImageGenerationContext, MAXPullQueue, Pipeline, PipelineInputsType, @@ -45,6 +46,10 @@ from .config import TokenGenerationSchedulerConfig from .decode_scheduler import load_decode_scheduler from .embeddings_scheduler import EmbeddingsScheduler, EmbeddingsSchedulerConfig +from .image_generation_scheduler import ( + ImageGenerationScheduler, + ImageGenerationSchedulerConfig, +) from .prefill_scheduler import load_prefill_scheduler from .text_generation_scheduler import load_text_generation_scheduler @@ -54,6 +59,8 @@ "CancelRequest", "EmbeddingsScheduler", "EmbeddingsSchedulerConfig", + "ImageGenerationScheduler", + "ImageGenerationSchedulerConfig", "PrefillRequest", "PrefillResponse", "TokenGenerationSchedulerConfig", @@ -124,6 +131,28 @@ def load_scheduler( paged_manager=paged_manager, offload_queue_draining=pipeline_config.experimental_background_queue, ) + elif pipeline.__class__.__name__ == "ImageGenerationPipeline": + from max.pipelines.lib.pipeline_variants.image_generation import ( + ImageGenerationPipeline, + ) + + image_scheduler_config = ImageGenerationSchedulerConfig( + max_batch_size=pipeline_config.max_batch_size + if pipeline_config.max_batch_size is not None + else 1 + ) + image_pipeline = cast(ImageGenerationPipeline, pipeline) + return ImageGenerationScheduler( + scheduler_config=image_scheduler_config, + pipeline=image_pipeline, + request_queue=cast( + MAXPullQueue[ImageGenerationContext], + request_queue, + ), + response_queue=response_queue, + cancel_queue=cancel_queue, + offload_queue_draining=pipeline_config.experimental_background_queue, + ) elif pipeline_config.pipeline_role == PipelineRole.PrefillAndDecode: assert isinstance(pipeline, Pipeline) text_pipeline = cast( diff --git a/max/python/max/serve/scheduler/image_generation_scheduler.py b/max/python/max/serve/scheduler/image_generation_scheduler.py new file mode 100644 index 00000000000..004dfc6f47b --- /dev/null +++ b/max/python/max/serve/scheduler/image_generation_scheduler.py @@ -0,0 +1,123 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Scheduler for image generation pipelines.""" + +import logging +import queue +from dataclasses import dataclass + +from max.interfaces import ( + ImageGenerationContext, + ImageGenerationInputs, + ImageGenerationOutput, + MAXPullQueue, + MAXPushQueue, + RequestID, + Scheduler, + SchedulerResult, +) +from max.pipelines.lib.pipeline_variants.image_generation import ( + ImageGenerationPipeline, +) +from max.profiler import traced + +from .base import SchedulerProgress + +logger = logging.getLogger("max.serve") + + +@dataclass +class ImageGenerationSchedulerConfig: + """Image generation scheduler configuration.""" + + # The maximum number of requests that can be in the batch. + # For image generation, typically 1 since it's memory intensive. + max_batch_size: int = 1 + + +class ImageGenerationScheduler(Scheduler): + """Scheduler for image generation requests. + + This scheduler handles image generation requests one at a time, + as diffusion models are typically memory-intensive and don't + benefit from batching in the same way as text generation. + """ + + def __init__( + self, + scheduler_config: ImageGenerationSchedulerConfig, + pipeline: ImageGenerationPipeline, + request_queue: MAXPullQueue[ImageGenerationContext], + response_queue: MAXPushQueue[ + dict[RequestID, SchedulerResult[ImageGenerationOutput]] + ], + cancel_queue: MAXPullQueue[list[RequestID]], + offload_queue_draining: bool = False, + ) -> None: + self.scheduler_config = scheduler_config + self.pipeline = pipeline + self.request_queue = request_queue + self.response_queue = response_queue + self.cancel_queue = cancel_queue + # Note: offload_queue_draining is accepted for API compatibility + # but not used since image generation is sequential. + + @traced + def _get_next_request(self) -> ImageGenerationContext | None: + """Get the next request from the queue.""" + try: + return self.request_queue.get_nowait() + except queue.Empty: + return None + + def run_iteration(self) -> SchedulerProgress: + """Process one image generation request. + + Returns: + SchedulerProgress: Indicates whether work was performed. + """ + request = self._get_next_request() + if request is None: + return SchedulerProgress.NO_PROGRESS + + self._execute_request(request) + return SchedulerProgress.MADE_PROGRESS + + @traced + def _execute_request(self, request: ImageGenerationContext) -> None: + """Execute a single image generation request.""" + try: + # Create inputs from context + inputs = ImageGenerationInputs( + prompt=request.prompt, + height=request.height, + width=request.width, + num_inference_steps=request.num_inference_steps, + guidance_scale=request.guidance_scale, + num_images_per_prompt=request.num_images_per_prompt, + ) + + # Execute the pipeline + output: ImageGenerationOutput = self.pipeline.execute(inputs) + + # Send the response + self.response_queue.put_nowait( + {request.request_id: SchedulerResult.create(output)} + ) + except Exception as e: + logger.error(f"Error executing image generation: {e}") + # Send cancelled response on error + self.response_queue.put_nowait( + {request.request_id: SchedulerResult.cancelled()} + ) From d91fb4451573acf95f9657b405ba82d56a46bd49 Mon Sep 17 00:00:00 2001 From: jingulee Date: Mon, 19 Jan 2026 08:55:31 +0000 Subject: [PATCH 14/18] fix: temp params true_cfg_scale --- .gitignore | 1 + max/python/max/entrypoints/diffusion.py | 41 +++++++++++++++---- max/python/max/entrypoints/pipelines.py | 3 +- .../pipeline_variants/image_generation.py | 18 ++++---- 4 files changed, 46 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 91b516aeb7e..284a1a8f97d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +*.venv/ # C extensions *.so diff --git a/max/python/max/entrypoints/diffusion.py b/max/python/max/entrypoints/diffusion.py index c8219d74812..691d9bacf14 100644 --- a/max/python/max/entrypoints/diffusion.py +++ b/max/python/max/entrypoints/diffusion.py @@ -77,12 +77,13 @@ class _ImageRequest: id: RequestID prompts: Sequence[str] - height: int - width: int - num_inference_steps: int - guidance_scale: float - num_images_per_prompt: int - use_tqdm: bool + negative_prompts: Sequence[str] | None = None + height: int = 1024 + width: int = 1024 + num_inference_steps: int = 50 + guidance_scale: float = 3.5 + num_images_per_prompt: int = 1 + use_tqdm: bool = True @dataclass @@ -353,19 +354,43 @@ def _process_request( all_images: list[Image] = [] # Create iterator with optional progress bar - prompt_iter = request.prompts + if request.negative_prompts is None or len(request.prompts) != len( + request.negative_prompts + ): + if ( + request.negative_prompts is None + or len(request.negative_prompts) == 0 + ): + request.negative_prompts = [None] * len(request.prompts) + else: + raise ValueError( + "Number of prompts and negative prompts must be the same." + ) + + # TODO: temp hard coding for true cfg scale. Need to be removed. + if all( + negative_prompt is not None + for negative_prompt in request.negative_prompts + ): + true_cfg_scale = 1.0 + else: + true_cfg_scale = 4.0 + + prompt_iter = zip(request.prompts, request.negative_prompts) if request.use_tqdm: prompt_iter = tqdm.tqdm(prompt_iter, desc="Generating images") # Generate images for each prompt - for prompt in prompt_iter: + for prompt, negative_prompt in prompt_iter: inputs = ImageGenerationInputs( prompt=prompt, + negative_prompt=negative_prompt, height=request.height, width=request.width, num_inference_steps=request.num_inference_steps, guidance_scale=request.guidance_scale, num_images_per_prompt=request.num_images_per_prompt, + true_cfg_scale=true_cfg_scale, ) output: ImageGenerationOutput = pipeline.execute(inputs) diff --git a/max/python/max/entrypoints/pipelines.py b/max/python/max/entrypoints/pipelines.py index 1b7c11a25cb..620c617b641 100644 --- a/max/python/max/entrypoints/pipelines.py +++ b/max/python/max/entrypoints/pipelines.py @@ -506,7 +506,8 @@ def images_generate( from max.pipelines.lib.config import ImageGenerationConfig # Set random seed if provided - set_seed(seed) + if seed is not None: + set_seed(seed) """ TODO: diff --git a/max/python/max/interfaces/pipeline_variants/image_generation.py b/max/python/max/interfaces/pipeline_variants/image_generation.py index 2e3cd739701..493407e1760 100644 --- a/max/python/max/interfaces/pipeline_variants/image_generation.py +++ b/max/python/max/interfaces/pipeline_variants/image_generation.py @@ -96,13 +96,13 @@ class ImageGenerationInputs(PipelineInputs): """Inputs for image-generation pipelines.""" prompt: str - negative_prompt: str | None - true_cfg_scale: float - height: int - width: int - num_inference_steps: int - guidance_scale: float - num_images_per_prompt: int + negative_prompt: str | None = None + true_cfg_scale: float | None = None + height: int = 1024 + width: int = 1024 + num_inference_steps: int = 50 + guidance_scale: float = 3.5 + num_images_per_prompt: int = 1 @dataclass(kw_only=True) @@ -364,7 +364,9 @@ def to_dict(self) -> dict[str, Any]: if self.usage.input_tokens_details is not None: usage_dict["input_tokens_details"] = { "text_tokens": self.usage.input_tokens_details.text_tokens, - "image_tokens": self.usage.input_tokens_details.image_tokens, + "image_tokens": ( + self.usage.input_tokens_details.image_tokens + ), } result["usage"] = usage_dict From 1778973a025d5a6534678967b006c0721d09c045 Mon Sep 17 00:00:00 2001 From: jingulee Date: Mon, 19 Jan 2026 23:03:41 +0000 Subject: [PATCH 15/18] update: formatting --- max/examples/diffusion/README.md | 1 + max/examples/diffusion/client_example.py | 11 +- max/examples/diffusion/offline_generation.py | 4 +- max/examples/diffusion/openai_api_example.py | 12 +- .../cli/serve/serve_diffusion_api.py | 9 +- max/python/max/entrypoints/diffusion.py | 5 +- .../max/entrypoints/pipelines_diffusion.py | 13 ++ max/python/max/experimental/BUILD.bazel | 2 +- .../architectures/flux1/pipeline_flux.py | 145 +++++++++++++----- max/python/max/serve/api_server.py | 6 +- max/python/max/serve/pipelines/diffusion.py | 8 +- max/python/max/serve/router/openai_routes.py | 2 +- 12 files changed, 162 insertions(+), 56 deletions(-) diff --git a/max/examples/diffusion/README.md b/max/examples/diffusion/README.md index 96ae9f8c999..3e219bd1d8d 100644 --- a/max/examples/diffusion/README.md +++ b/max/examples/diffusion/README.md @@ -5,6 +5,7 @@ This directory contains examples for using the MAX diffusion pipeline for image ## Overview The MAX diffusion pipeline supports: + - **Offline generation**: Direct Python API for generating images - **OpenAI-compatible API**: Server with `/v1/images/generations` endpoint - **Multiple models**: FLUX.1-dev, and other diffusion models diff --git a/max/examples/diffusion/client_example.py b/max/examples/diffusion/client_example.py index 8b9345d0ba3..d16fb63535b 100644 --- a/max/examples/diffusion/client_example.py +++ b/max/examples/diffusion/client_example.py @@ -28,9 +28,8 @@ """ import base64 -from pathlib import Path - from argparse import ArgumentParser +from pathlib import Path parser = ArgumentParser() parser.add_argument("--port", type=int, default=8000) @@ -150,7 +149,7 @@ def example_with_httpx_async() -> None: print("Please install httpx: pip install httpx") return - async def generate_image(): + async def generate_image() -> None: async with httpx.AsyncClient(timeout=300.0) as client: base_url = f"http://localhost:{PORT}" @@ -206,7 +205,7 @@ def example_curl_commands() -> None: f""" curl http://localhost:{PORT}/v1/images/generations \\ -H "Content-Type: application/json" \\ -d '{ - "prompt": "A beautiful sunset over mountains", + "prompt": "A beautiful sunset over mountains", "size": "1024x1024", "n": 1, "response_format": "b64_json", @@ -229,7 +228,9 @@ def example_curl_commands() -> None: print("Diffusion API Client Examples") print("=" * 60) print("\nMake sure the server is running:") - print(f" max images serve --model black-forest-labs/FLUX.1-dev --port {PORT}") + print( + f" max images serve --model black-forest-labs/FLUX.1-dev --port {PORT}" + ) print() # Print curl examples first diff --git a/max/examples/diffusion/offline_generation.py b/max/examples/diffusion/offline_generation.py index 20749c2a244..f4e398891a4 100644 --- a/max/examples/diffusion/offline_generation.py +++ b/max/examples/diffusion/offline_generation.py @@ -21,7 +21,9 @@ def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument("--model-path", type=str, default="black-forest-labs/FLUX.1-dev") + parser.add_argument( + "--model-path", type=str, default="black-forest-labs/FLUX.1-dev" + ) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() diff --git a/max/examples/diffusion/openai_api_example.py b/max/examples/diffusion/openai_api_example.py index db7c4e9d514..4bfa31085b2 100644 --- a/max/examples/diffusion/openai_api_example.py +++ b/max/examples/diffusion/openai_api_example.py @@ -20,9 +20,9 @@ python openai_api_example.py --seed 42 --prompt "A futuristic city skyline at sunset with flying cars" --model-path "black-forest-labs/FLUX.1-dev" """ +import argparse import base64 import os -import argparse from pathlib import Path from max.entrypoints.diffusion import ImageGenerator @@ -34,8 +34,14 @@ def main() -> None: # Configure random seed for reproducibility parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--prompt", type=str, default="A futuristic city skyline at sunset with flying cars") - parser.add_argument("--model-path", type=str, default="black-forest-labs/FLUX.1-dev") + parser.add_argument( + "--prompt", + type=str, + default="A futuristic city skyline at sunset with flying cars", + ) + parser.add_argument( + "--model-path", type=str, default="black-forest-labs/FLUX.1-dev" + ) args = parser.parse_args() seed = args.seed os.environ["SEED"] = str(seed) diff --git a/max/python/max/entrypoints/cli/serve/serve_diffusion_api.py b/max/python/max/entrypoints/cli/serve/serve_diffusion_api.py index 34cfadaad54..b71c778ddb7 100644 --- a/max/python/max/entrypoints/cli/serve/serve_diffusion_api.py +++ b/max/python/max/entrypoints/cli/serve/serve_diffusion_api.py @@ -28,7 +28,6 @@ import time from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import Any import uvloop from fastapi import FastAPI, HTTPException, Request @@ -171,7 +170,9 @@ async def create_image(request: Request) -> JSONResponse: logger.debug( "Processing image generation request: prompt=%r, size=%s, n=%d", - internal_request.prompt[:50] if len(internal_request.prompt) > 50 else internal_request.prompt, + internal_request.prompt[:50] + if len(internal_request.prompt) > 50 + else internal_request.prompt, internal_request.size, internal_request.n or 1, ) @@ -226,7 +227,9 @@ async def get_model(model_id: str, request: Request) -> JSONResponse: } ) - raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found") + raise HTTPException( + status_code=404, detail=f"Model '{model_id}' not found" + ) return app diff --git a/max/python/max/entrypoints/diffusion.py b/max/python/max/entrypoints/diffusion.py index 691d9bacf14..9f0d8950f0c 100644 --- a/max/python/max/entrypoints/diffusion.py +++ b/max/python/max/entrypoints/diffusion.py @@ -53,8 +53,6 @@ from typing import TYPE_CHECKING import tqdm -from PIL.Image import Image - from max.interfaces import ( ImageGenerationInputs, ImageGenerationOutput, @@ -64,6 +62,7 @@ RequestID, ) from max.pipelines.lib import PIPELINE_REGISTRY, PipelineConfig +from PIL.Image import Image if TYPE_CHECKING: from max.pipelines.lib.pipeline_variants.image_generation import ( @@ -376,7 +375,7 @@ def _process_request( else: true_cfg_scale = 4.0 - prompt_iter = zip(request.prompts, request.negative_prompts) + prompt_iter = zip(request.prompts, request.negative_prompts, strict=False) if request.use_tqdm: prompt_iter = tqdm.tqdm(prompt_iter, desc="Generating images") diff --git a/max/python/max/entrypoints/pipelines_diffusion.py b/max/python/max/entrypoints/pipelines_diffusion.py index 1fec7400445..e2d9dcb06b5 100644 --- a/max/python/max/entrypoints/pipelines_diffusion.py +++ b/max/python/max/entrypoints/pipelines_diffusion.py @@ -1,3 +1,16 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + """Diffusion-only CLI wrapper. This exists so Bazel can keep `//max/python/max/entrypoints:pipelines` lean, diff --git a/max/python/max/experimental/BUILD.bazel b/max/python/max/experimental/BUILD.bazel index 946a73ca810..575bbbc0a09 100644 --- a/max/python/max/experimental/BUILD.bazel +++ b/max/python/max/experimental/BUILD.bazel @@ -8,8 +8,8 @@ modular_py_library( "__init__.py", "_passes.py", "_tensor_repr.py", - "functional.py", "compile_utils.py", + "functional.py", "random.py", "realization_context.py", "support.py", diff --git a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py index d0e830dbdba..a518f66ef4a 100644 --- a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py +++ b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py @@ -34,7 +34,6 @@ from max.pipelines.lib.interfaces.diffusion_pipeline import ( DiffusionPipeline, ) -from max.experimental.realization_context import set_seed from tqdm import tqdm from transformers import ( CLIPTokenizer, @@ -81,9 +80,13 @@ def retrieve_timesteps( second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -93,7 +96,9 @@ def retrieve_timesteps( timesteps = scheduler.timesteps num_inference_steps = int(timesteps.shape[0]) elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -150,9 +155,17 @@ class FluxPipeline(DiffusionPipeline): } def init_remaining_components(self) -> None: - image_processor_class = self.components.get("image_processor", VaeImageProcessor) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 - image_processor = image_processor_class(vae_scale_factor=self.vae_scale_factor * 2) + image_processor_class = self.components.get( + "image_processor", VaeImageProcessor + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) + if getattr(self, "vae", None) + else 8 + ) + image_processor = image_processor_class( + vae_scale_factor=self.vae_scale_factor * 2 + ) self.image_processor = image_processor def encode_prompt( @@ -191,9 +204,13 @@ def encode_prompt( if lora_scale is not None and isinstance(self, FluxPipeline): self._lora_scale = lora_scale - if self.text_encoder is not None and hasattr(self.text_encoder, "set_lora_scale"): + if self.text_encoder is not None and hasattr( + self.text_encoder, "set_lora_scale" + ): self.text_encoder.set_lora_scale(lora_scale) - if self.text_encoder_2 is not None and hasattr(self.text_encoder_2, "set_lora_scale"): + if self.text_encoder_2 is not None and hasattr( + self.text_encoder_2, "set_lora_scale" + ): self.text_encoder_2.set_lora_scale(lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -202,12 +219,16 @@ def encode_prompt( text_inputs = self.tokenizer( prompt, padding="max_length", - max_length=min(max_sequence_length, self.tokenizer.model_max_length), + max_length=min( + max_sequence_length, self.tokenizer.model_max_length + ), truncation=True, return_length=False, return_overflowing_tokens=False, ) - text_input_ids = Tensor_v3.constant(text_inputs.input_ids, device=device, dtype=DType.int64) + text_input_ids = Tensor_v3.constant( + text_inputs.input_ids, device=device, dtype=DType.int64 + ) text_encoder_outputs = self.text_encoder(text_input_ids) prompt_embeds = text_encoder_outputs[0] @@ -225,7 +246,9 @@ def encode_prompt( return_length=False, return_overflowing_tokens=False, ) - text_input_ids_2 = Tensor_v3.constant(text_inputs_2.input_ids, device=device, dtype=DType.int64) + text_input_ids_2 = Tensor_v3.constant( + text_inputs_2.input_ids, device=device, dtype=DType.int64 + ) prompt_embeds_2 = self.text_encoder_2(text_input_ids_2)[0] else: @@ -241,14 +264,24 @@ def encode_prompt( ) bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = Tensor_v3.from_dlpack(prompt_embeds) # V2 Tensor to V3 Tensor - pooled_prompt_embeds = Tensor_v3.from_dlpack(pooled_prompt_embeds) # V2 Tensor to V3 Tensor + prompt_embeds = Tensor_v3.from_dlpack( + prompt_embeds + ) # V2 Tensor to V3 Tensor + pooled_prompt_embeds = Tensor_v3.from_dlpack( + pooled_prompt_embeds + ) # V2 Tensor to V3 Tensor prompt_embeds = F.tile(prompt_embeds, (1, num_images_per_prompt, 1)) - prompt_embeds = prompt_embeds.reshape((bs_embed * num_images_per_prompt, seq_len, -1)) + prompt_embeds = prompt_embeds.reshape( + (bs_embed * num_images_per_prompt, seq_len, -1) + ) - pooled_prompt_embeds = F.tile(pooled_prompt_embeds, (1, num_images_per_prompt)) - pooled_prompt_embeds = pooled_prompt_embeds.reshape((bs_embed * num_images_per_prompt, -1)) + pooled_prompt_embeds = F.tile( + pooled_prompt_embeds, (1, num_images_per_prompt) + ) + pooled_prompt_embeds = pooled_prompt_embeds.reshape( + (bs_embed * num_images_per_prompt, -1) + ) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -282,7 +315,9 @@ def _prepare_latent_image_ids( latent_image_id_channels, ), ) - latent_image_ids = Tensor_v3.from_dlpack(latent_image_ids).to(device).cast(dtype) + latent_image_ids = ( + Tensor_v3.from_dlpack(latent_image_ids).to(device).cast(dtype) + ) return latent_image_ids @@ -331,7 +366,9 @@ def _unpack_latents( ) latents = F.permute(latents, (0, 3, 1, 4, 2, 5)) - latents = F.reshape(latents, (batch_size.dim, channels.dim // (2 * 2), height, width)) + latents = F.reshape( + latents, (batch_size.dim, channels.dim // (2 * 2), height, width) + ) return latents @@ -367,13 +404,19 @@ def prepare_latents( shape = (batch_size, num_channels_latents, height, width) if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + latent_image_ids = self._prepare_latent_image_ids( + batch_size, height // 2, width // 2, device, dtype + ) return latents.to(device).cast(dtype), latent_image_ids latents = normal(shape, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents( + latents, batch_size, num_channels_latents, height, width + ) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + latent_image_ids = self._prepare_latent_image_ids( + batch_size, height // 2, width // 2, device, dtype + ) return latents, latent_image_ids @@ -517,10 +560,13 @@ def __call__( device = self._execution_device() lora_scale = ( - self._joint_attention_kwargs.get("scale", None) if self._joint_attention_kwargs is not None else None + self._joint_attention_kwargs.get("scale", None) + if self._joint_attention_kwargs is not None + else None ) has_neg_prompt = negative_prompt is not None or ( - negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + negative_prompt_embeds is not None + and negative_pooled_prompt_embeds is not None ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt ( @@ -566,8 +612,15 @@ def __call__( ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - if hasattr(self.scheduler, "use_flow_sigmas") and self.scheduler.use_flow_sigmas: + sigmas = ( + np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if sigmas is None + else sigmas + ) + if ( + hasattr(self.scheduler, "use_flow_sigmas") + and self.scheduler.use_flow_sigmas + ): sigmas = None image_seq_len = latents.shape[1].dim mu = calculate_shift( @@ -608,7 +661,9 @@ def __call__( or negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None ): - raise NotImplementedError("IP adapter is not supported for Max yet.") + raise NotImplementedError( + "IP adapter is not supported for Max yet." + ) if self._joint_attention_kwargs is None: self._joint_attention_kwargs = {} @@ -628,7 +683,9 @@ def __call__( t = timesteps[i] self._current_timestep = t if image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + self._joint_attention_kwargs["ip_adapter_image_embeds"] = ( + image_embeds + ) # NOTE: Convert timesteps to a Max Tensor before denoising loop, # as in the original implementation, results in a significant slow down. @@ -650,7 +707,9 @@ def __call__( if do_true_cfg: if negative_image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + self._joint_attention_kwargs["ip_adapter_image_embeds"] = ( + negative_image_embeds + ) neg_noise_pred = self.transformer( latents, @@ -664,11 +723,15 @@ def __call__( # TODO: negative prompt path is very slow, need to optimize. noise_pred = Tensor_v3.from_dlpack(noise_pred) neg_noise_pred = Tensor_v3.from_dlpack(neg_noise_pred) - noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + noise_pred = neg_noise_pred + true_cfg_scale * ( + noise_pred - neg_noise_pred + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False + )[0] if latents.dtype != latents_dtype: latents = latents.to(latents_dtype) @@ -677,10 +740,14 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) self._current_timestep = None @@ -688,12 +755,18 @@ def __call__( image = latents else: latents = Tensor_v3.from_dlpack(latents) # V2 Tensor to V3 Tensor - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = self._unpack_latents( + latents, height, width, self.vae_scale_factor + ) + latents = ( + latents / self.vae.config.scaling_factor + ) + self.vae.config.shift_factor image = self.vae.decode(latents)[0] image = Tensor_v3.from_dlpack(image) # V2 Tensor to V3 Tensor - image = self.image_processor.postprocess(image, output_type=output_type) + image = self.image_processor.postprocess( + image, output_type=output_type + ) if not return_dict: return (image,) diff --git a/max/python/max/serve/api_server.py b/max/python/max/serve/api_server.py index 96b5e40020d..c175ecc5edb 100644 --- a/max/python/max/serve/api_server.py +++ b/max/python/max/serve/api_server.py @@ -140,7 +140,11 @@ async def lifespan( METRICS.pipeline_load(serving_settings.pipeline_config.model.model_name) - pipeline: TokenGeneratorPipeline | AudioGeneratorPipeline | ImageGeneratorPipeline + pipeline: ( + TokenGeneratorPipeline + | AudioGeneratorPipeline + | ImageGeneratorPipeline + ) if serving_settings.pipeline_task in ( PipelineTask.TEXT_GENERATION, PipelineTask.EMBEDDINGS_GENERATION, diff --git a/max/python/max/serve/pipelines/diffusion.py b/max/python/max/serve/pipelines/diffusion.py index 0f3e56ba2a7..76601d0a226 100644 --- a/max/python/max/serve/pipelines/diffusion.py +++ b/max/python/max/serve/pipelines/diffusion.py @@ -44,7 +44,9 @@ async def __aenter__(self) -> Self: """Initialize the image generator.""" from max.entrypoints.diffusion import ImageGenerator - self.logger.info("Loading image generator for model: %s", self.model_name) + self.logger.info( + "Loading image generator for model: %s", self.model_name + ) self._generator = ImageGenerator(self.pipeline_config) self.logger.info("Image generator loaded successfully") return self @@ -63,5 +65,7 @@ async def __aexit__( def generator(self) -> Any: """Get the underlying image generator.""" if self._generator is None: - raise RuntimeError("Image generator not initialized. Use async with.") + raise RuntimeError( + "Image generator not initialized. Use async with." + ) return self._generator diff --git a/max/python/max/serve/router/openai_routes.py b/max/python/max/serve/router/openai_routes.py index 7bfc22d3446..a1ba8b8c0d7 100644 --- a/max/python/max/serve/router/openai_routes.py +++ b/max/python/max/serve/router/openai_routes.py @@ -44,7 +44,6 @@ AudioGenerationRequest, GenerationStatus, ImageGenerationRequest, - ImageGenerationResponse, LoRAOperation, LoRARequest, LoRAStatus, @@ -1448,6 +1447,7 @@ async def create_image_generation( # Parse request body body = await request.body() import json as json_module + body_dict = json_module.loads(body) # Create ImageGenerationRequest from body From af8b4a2909f448320927a318360a054730bc2997 Mon Sep 17 00:00:00 2001 From: jingulee Date: Mon, 19 Jan 2026 23:08:16 +0000 Subject: [PATCH 16/18] fix: rollback architectures --- max/examples/diffusion/README.md | 219 ------------------ .../max/pipelines/architectures/flux1/arch.py | 1 + .../architectures/flux1/pipeline_flux.py | 4 +- 3 files changed, 3 insertions(+), 221 deletions(-) delete mode 100644 max/examples/diffusion/README.md diff --git a/max/examples/diffusion/README.md b/max/examples/diffusion/README.md deleted file mode 100644 index 3e219bd1d8d..00000000000 --- a/max/examples/diffusion/README.md +++ /dev/null @@ -1,219 +0,0 @@ -# Diffusion Image Generation Examples - -This directory contains examples for using the MAX diffusion pipeline for image generation. - -## Overview - -The MAX diffusion pipeline supports: - -- **Offline generation**: Direct Python API for generating images -- **OpenAI-compatible API**: Server with `/v1/images/generations` endpoint -- **Multiple models**: FLUX.1-dev, and other diffusion models - -## Examples - -### 1. Offline Generation (`offline_generation.py`) - -Basic example using the `ImageGenerator` directly: - -```bash -python offline_generation.py -``` - -```python -from max.entrypoints.diffusion import ImageGenerator -from max.pipelines import PipelineConfig - -config = PipelineConfig(model_path="black-forest-labs/FLUX.1-dev") -generator = ImageGenerator(config) - -# Generate returns a list of PIL Images -images = generator.generate( - "A cat holding a sign that says hello world", - height=1024, - width=1024, - num_inference_steps=50, - guidance_scale=3.5, -) -images[0].save("output.png") -``` - -### 2. OpenAI API Example (`openai_api_example.py`) - -Using the OpenAI-compatible `ImageGenerationRequest`: - -```bash -python openai_api_example.py -``` - -```python -from max.entrypoints.diffusion import ImageGenerator -from max.interfaces import ImageGenerationRequest -from max.pipelines import PipelineConfig - -config = PipelineConfig(model_path="black-forest-labs/FLUX.1-dev") -generator = ImageGenerator(config) - -# Use OpenAI-compatible request format -request = ImageGenerationRequest( - prompt="A futuristic city skyline at sunset", - size="1024x1024", - n=1, - response_format="b64_json", - num_inference_steps=30, - guidance_scale=3.5, -) - -# create() returns an OpenAI-compatible response -response = generator.create(request) -# response.data[0].b64_json contains the base64-encoded image -``` - -### 3. Client Example (`client_example.py`) - -Connecting to the OpenAI-compatible server: - -```bash -# Start the server -max images serve --model black-forest-labs/FLUX.1-dev --port 8000 - -# Run the client -python client_example.py -``` - -## CLI Commands - -### Generate Images - -```bash -# Basic generation -max images generate \ - --model black-forest-labs/FLUX.1-dev \ - --prompt "A beautiful sunset over mountains" \ - --size 1024x1024 \ - --output output.png - -# With custom parameters -max images generate \ - --model black-forest-labs/FLUX.1-dev \ - --prompt "A cyberpunk city" \ - --size 1792x1024 \ - --num-inference-steps 50 \ - --guidance-scale 7.5 \ - --seed 42 \ - --output landscape.png -``` - -### Start Server - -```bash -# Start OpenAI-compatible API server -max images serve \ - --model black-forest-labs/FLUX.1-dev \ - --port 8000 -``` - -## API Reference - -### Request Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `prompt` | string | Required | Text description of the desired image | -| `model` | string | null | Model to use for generation | -| `n` | integer | 1 | Number of images to generate (1-10) | -| `size` | string | "1024x1024" | Image size (e.g., "1024x1024", "1792x1024") | -| `quality` | string | "standard" | Image quality | -| `response_format` | string | "b64_json" | Response format ("url" or "b64_json") | -| `output_format` | string | "png" | Output format ("png", "jpeg", "webp") | -| `num_inference_steps` | integer | 50 | Number of denoising steps | -| `guidance_scale` | float | 3.5 | Classifier-free guidance scale | -| `seed` | integer | null | Random seed for reproducibility | - -### Response Format - -```json -{ - "created": 1713833628, - "data": [ - { - "b64_json": "iVBORw0KGgo...", - "revised_prompt": "A beautiful sunset over mountains" - } - ] -} -``` - -## curl Examples - -### Generate Image - -```bash -curl http://localhost:8000/v1/images/generations \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "A beautiful sunset over mountains", - "size": "1024x1024", - "n": 1, - "response_format": "b64_json", - "num_inference_steps": 30, - "guidance_scale": 3.5 - }' -``` - -### Generate and Save - -```bash -curl -s http://localhost:8000/v1/images/generations \ - -H "Content-Type: application/json" \ - -d '{"prompt": "A cat", "size": "512x512"}' \ - | jq -r '.data[0].b64_json' | base64 -d > output.png -``` - -### List Models - -```bash -curl http://localhost:8000/v1/models -``` - -### Health Check - -```bash -curl http://localhost:8000/health -``` - -## Using with OpenAI Python Client - -```python -from openai import OpenAI - -client = OpenAI( - base_url="http://localhost:8000/v1", - api_key="not-needed", -) - -response = client.images.generate( - model="black-forest-labs/FLUX.1-dev", - prompt="A majestic dragon flying over a castle", - size="1024x1024", - n=1, - response_format="b64_json", -) - -# Save the image -import base64 -image_bytes = base64.b64decode(response.data[0].b64_json) -with open("dragon.png", "wb") as f: - f.write(image_bytes) -``` - -## Supported Models - -- `black-forest-labs/FLUX.1-dev` - Flux 1 Dev - -## Environment Variables - -| Variable | Description | -|----------|-------------| -| `USE_TORCH_RANDN` | Set to "1" to use torch-based random latents | -| `SEED` | Random seed for reproducibility | diff --git a/max/python/max/pipelines/architectures/flux1/arch.py b/max/python/max/pipelines/architectures/flux1/arch.py index 7dac24af036..aea17022398 100644 --- a/max/python/max/pipelines/architectures/flux1/arch.py +++ b/max/python/max/pipelines/architectures/flux1/arch.py @@ -29,6 +29,7 @@ supported_encodings={SupportedEncoding.bfloat16: []}, example_repo_ids=[ "black-forest-labs/FLUX.1-dev", + "black-forest-labs/FLUX.1-schnell", ], pipeline_model=FluxPipeline, tokenizer=TextTokenizer, diff --git a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py index a518f66ef4a..85298caf20e 100644 --- a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py +++ b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py @@ -22,7 +22,7 @@ from max.dtype import DType from max.experimental import Tensor as Tensor_v3 from max.experimental import functional as F -from max.experimental.random import normal +from max.experimental import random from max.graph import DeviceRef from max.pipelines.lib.diffusion_schedulers import ( FlowMatchEulerDiscreteScheduler, @@ -409,7 +409,7 @@ def prepare_latents( ) return latents.to(device).cast(dtype), latent_image_ids - latents = normal(shape, device=device, dtype=dtype) + latents = random.normal(shape, device=device, dtype=dtype) latents = self._pack_latents( latents, batch_size, num_channels_latents, height, width ) From 7f9cbac0d3bc4654ece44344da21e3f0847e329f Mon Sep 17 00:00:00 2001 From: jingulee Date: Tue, 20 Jan 2026 00:24:10 +0000 Subject: [PATCH 17/18] fix: formatting --- max/python/max/entrypoints/pipelines.py | 5 - max/python/max/pipelines/lib/config.py | 118 ++++++------------------ 2 files changed, 30 insertions(+), 93 deletions(-) diff --git a/max/python/max/entrypoints/pipelines.py b/max/python/max/entrypoints/pipelines.py index 620c617b641..571d8516fe9 100644 --- a/max/python/max/entrypoints/pipelines.py +++ b/max/python/max/entrypoints/pipelines.py @@ -385,11 +385,6 @@ def cli_pipeline( ) -# ============================================================================ -# Images Group (OpenAI-compatible /v1/images/* endpoints) -# ============================================================================ - - @main.group(name="images", cls=ModelGroup) def images_group() -> None: """Commands for image generation (OpenAI-compatible /v1/images/* API).""" diff --git a/max/python/max/pipelines/lib/config.py b/max/python/max/pipelines/lib/config.py index 0a6402a0267..098fae555f0 100644 --- a/max/python/max/pipelines/lib/config.py +++ b/max/python/max/pipelines/lib/config.py @@ -1153,9 +1153,7 @@ def log_basic_config(self) -> None: raise ValueError( "KVCache config is not available after config resolution." ) - memory_str = to_human_readable_bytes( - kv_config._available_cache_memory - ) + memory_str = to_human_readable_bytes(kv_config._available_cache_memory) devices_str = ", ".join( f"{d.device_type}[{d.id}]" for d in self.model.device_specs @@ -1180,63 +1178,25 @@ def log_basic_config(self) -> None: @staticmethod def help() -> dict[str, str]: return { - "max_length": ( - "Set the maximum sequence length for input data processed by the model. This must be less than the value specified in the Hugging Face configuration file. The default is derived from the Hugging Face configuration value. Larger values may consume more memory." - ), - "pipeline_role": ( - "Whether the pipeline should serve both a prefill or decode role or both." - ), - "max_batch_size": ( - "Define the maximum batch size to execute with the model. When not specified (None), we determine this value dynamically. For users launching in a server scenario, the expectation is that this value should be set higher based on server capacity." - ), - "max_queue_size_tg": ( - "Maximum number of requests in decode queue. By default, this is max-batch-size." - ), - "min_batch_size_tg": ( - "Specifies a soft floor on the decode batch size. If the TG batch size is larger than this value, the scheduler will continue to run TG batches. If it falls below, the scheduler will prioritize CE. This is an experimental flag solely for the TTS scheduler." - ), - "ce_delay_ms": ( - "Duration of scheduler sleep prior to starting a prefill batch. This is an experimental flag solely for the TTS scheduler. Default is 0.0." - ), - "enable_prioritize_first_decode": ( - "When enabled, the scheduler will always run a TG batch immediately after a CE batch, with the same requests. This may be useful for decreasing time-to-first-chunk latency. This is an experimental flag solely for the TTS scheduler. Default is false." - ), - "experimental_background_queue": ( - "When enabled, offloads queue draining to a background thread for improved performance. This is an experimental flag. Default is false." - ), - "enable_chunked_prefill": ( - "Enable chunked prefill to split context encoding requests into multiple chunks based on `prefill-chunk-size`. Default is true." - ), - "enable_in_flight_batching": ( - "When enabled, prioritizes token generation by batching it with context encoding requests. Default is false." - ), - "max_num_steps": ( - "Specify the number of steps to run for multi-step scheduling during inference. Default is -1 which specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models)." - ), - "prefill_chunk_size": ( - "The target number of un-encoded tokens to include in each batch. This value is used for chunked prefill and memory estimation. Default is 8192." - ), - "enable_echo": ( - "Whether the model should be built with echo capabilities. This defaults to false." - ), - "pool_embeddings": ( - "Whether to pool embedding outputs. Default is true." - ), - "use_experimental_kernels": ( - "Whether to use experimental kernels. Default is false." - ), - "max_batch_context_length": ( - "Ensures that the sum of the context length in a batch does not exceed max_batch_context_length. If None, the sum of the context length in batch is not limited." - ), - "pdl_level": ( - "Level of overlap of kernel launch via programmatic dependent grid control. Default is 0." - ), - "custom_architectures": ( - "A list of custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name." - ), - "kvcache_ce_watermark": ( - "Projected cache usage threshold for scheduling CE requests, considers current + incoming request. CE is scheduled if either projected usage stays below this threshold OR no active requests exist. Greater KVCache utilization (as controlled by this parameter) was found to cause more preemptions. Default watermark value is 0.95." - ), + "max_length": "Set the maximum sequence length for input data processed by the model. This must be less than the value specified in the Hugging Face configuration file. The default is derived from the Hugging Face configuration value. Larger values may consume more memory.", + "pipeline_role": "Whether the pipeline should serve both a prefill or decode role or both.", + "max_batch_size": "Define the maximum batch size to execute with the model. When not specified (None), we determine this value dynamically. For users launching in a server scenario, the expectation is that this value should be set higher based on server capacity.", + "max_queue_size_tg": "Maximum number of requests in decode queue. By default, this is max-batch-size.", + "min_batch_size_tg": "Specifies a soft floor on the decode batch size. If the TG batch size is larger than this value, the scheduler will continue to run TG batches. If it falls below, the scheduler will prioritize CE. This is an experimental flag solely for the TTS scheduler.", + "ce_delay_ms": "Duration of scheduler sleep prior to starting a prefill batch. This is an experimental flag solely for the TTS scheduler. Default is 0.0.", + "enable_prioritize_first_decode": "When enabled, the scheduler will always run a TG batch immediately after a CE batch, with the same requests. This may be useful for decreasing time-to-first-chunk latency. This is an experimental flag solely for the TTS scheduler. Default is false.", + "experimental_background_queue": "When enabled, offloads queue draining to a background thread for improved performance. This is an experimental flag. Default is false.", + "enable_chunked_prefill": "Enable chunked prefill to split context encoding requests into multiple chunks based on `prefill-chunk-size`. Default is true.", + "enable_in_flight_batching": "When enabled, prioritizes token generation by batching it with context encoding requests. Default is false.", + "max_num_steps": "Specify the number of steps to run for multi-step scheduling during inference. Default is -1 which specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models).", + "prefill_chunk_size": "The target number of un-encoded tokens to include in each batch. This value is used for chunked prefill and memory estimation. Default is 8192.", + "enable_echo": "Whether the model should be built with echo capabilities. This defaults to false.", + "pool_embeddings": "Whether to pool embedding outputs. Default is true.", + "use_experimental_kernels": "Whether to use experimental kernels. Default is false.", + "max_batch_context_length": "Ensures that the sum of the context length in a batch does not exceed max_batch_context_length. If None, the sum of the context length in batch is not limited.", + "pdl_level": "Level of overlap of kernel launch via programmatic dependent grid control. Default is 0.", + "custom_architectures": "A list of custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name.", + "kvcache_ce_watermark": "Projected cache usage threshold for scheduling CE requests, considers current + incoming request. CE is scheduled if either projected usage stays below this threshold OR no active requests exist. Greater KVCache utilization (as controlled by this parameter) was found to cause more preemptions. Default watermark value is 0.95.", } @property @@ -1401,33 +1361,15 @@ def help() -> dict[str, str]: # Add AudioGenerationConfig-specific fields audio_specific_help = { - "audio_decoder": ( - "The name of the audio decoder model architecture." - ), - "audio_decoder_weights": ( - "The path to the audio decoder weights file." - ), - "chunk_size": ( - "The chunk sizes to use for streaming. If this is an int, then fixed-size chunks of the given size are used. If this is a list, then variable chunk sizes are used." - ), - "buffer": ( - "The number of previous speech tokens to pass to the audio decoder on each generation step. Default is 0." - ), - "block_causal": ( - "Whether prior buffered tokens should attend to tokens in the current block. Has no effect if buffer is not set. Default is false." - ), - "prepend_prompt_speech_tokens": ( - "Whether the prompt speech tokens should be forwarded to the audio decoder. Options: 'never', 'once', 'rolling'. Default is 'once'." - ), - "prepend_prompt_speech_tokens_causal": ( - "Whether the prompt speech tokens should attend to tokens in the currently generated audio block. Has no effect if prepend_prompt_speech_tokens is 'never'. Default is false." - ), - "audio_decoder_config": ( - "Parameters to pass to the audio decoder model." - ), - "prometheus_metrics_mode": ( - "The mode to use for Prometheus metrics. Default is 'instrument_only'." - ), + "audio_decoder": "The name of the audio decoder model architecture.", + "audio_decoder_weights": "The path to the audio decoder weights file.", + "chunk_size": "The chunk sizes to use for streaming. If this is an int, then fixed-size chunks of the given size are used. If this is a list, then variable chunk sizes are used.", + "buffer": "The number of previous speech tokens to pass to the audio decoder on each generation step. Default is 0.", + "block_causal": "Whether prior buffered tokens should attend to tokens in the current block. Has no effect if buffer is not set. Default is false.", + "prepend_prompt_speech_tokens": "Whether the prompt speech tokens should be forwarded to the audio decoder. Options: 'never', 'once', 'rolling'. Default is 'once'.", + "prepend_prompt_speech_tokens_causal": "Whether the prompt speech tokens should attend to tokens in the currently generated audio block. Has no effect if prepend_prompt_speech_tokens is 'never'. Default is false.", + "audio_decoder_config": "Parameters to pass to the audio decoder model.", + "prometheus_metrics_mode": "The mode to use for Prometheus metrics. Default is 'instrument_only'.", } # Check for conflicts @@ -1521,4 +1463,4 @@ def from_flags( def log_basic_config(self) -> None: logger.info(f"model_info: {self.model}") - logger.info("") + logger.info("") \ No newline at end of file From ef1fa77f49a7d4046f773498be1d4f99f4350af7 Mon Sep 17 00:00:00 2001 From: jingulee Date: Tue, 20 Jan 2026 01:44:34 +0000 Subject: [PATCH 18/18] fix: client example --- max/examples/diffusion/client_example.py | 2 +- max/python/max/entrypoints/pipelines_diffusion.py | 2 +- max/python/max/serve/router/openai_routes.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/max/examples/diffusion/client_example.py b/max/examples/diffusion/client_example.py index d16fb63535b..dd3e3e46e91 100644 --- a/max/examples/diffusion/client_example.py +++ b/max/examples/diffusion/client_example.py @@ -46,7 +46,7 @@ def example_with_requests() -> None: # Check server health response = requests.get(f"{base_url}/health") - print(f"Server health: {response.json()}") + print(f"Server health: {response.status_code} OK") # List available models response = requests.get(f"{base_url}/v1/models") diff --git a/max/python/max/entrypoints/pipelines_diffusion.py b/max/python/max/entrypoints/pipelines_diffusion.py index e2d9dcb06b5..009bebd07e1 100644 --- a/max/python/max/entrypoints/pipelines_diffusion.py +++ b/max/python/max/entrypoints/pipelines_diffusion.py @@ -32,7 +32,7 @@ def main() -> None: pipelines_cli.main( prog_name="pipelines", - args=["images", *sys.argv[1:]], + args=[*sys.argv[1:]], ) diff --git a/max/python/max/serve/router/openai_routes.py b/max/python/max/serve/router/openai_routes.py index a1ba8b8c0d7..61076cbf306 100644 --- a/max/python/max/serve/router/openai_routes.py +++ b/max/python/max/serve/router/openai_routes.py @@ -1345,7 +1345,7 @@ async def openai_get_models(request: Request) -> ListModelsResponse: Model(id=pipeline.model_name, object="model", created=None, owned_by="") ] - if lora_queue := request.app.state.pipeline.lora_queue: + if lora_queue := getattr(request.app.state.pipeline, "lora_queue", None): model_list += [ Model(id=lora, object="model", created=None, owned_by="") for lora in lora_queue.list_loras()