diff --git a/.gitignore b/.gitignore index 91b516aeb7e..284a1a8f97d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +*.venv/ # C extensions *.so diff --git a/max/examples/diffusion/client_example.py b/max/examples/diffusion/client_example.py new file mode 100644 index 00000000000..dd3e3e46e91 --- /dev/null +++ b/max/examples/diffusion/client_example.py @@ -0,0 +1,253 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Example: Client code for connecting to the diffusion API server. + +This example demonstrates how to connect to the OpenAI-compatible +image generation server using various client methods. + +Prerequisites: + 1. Start the server: + max images serve --model black-forest-labs/FLUX.1-dev --port 8000 + + 2. Run this client: + python client_example.py + +Dependencies: + pip install requests openai httpx +""" + +import base64 +from argparse import ArgumentParser +from pathlib import Path + +parser = ArgumentParser() +parser.add_argument("--port", type=int, default=8000) +args = parser.parse_args() + +PORT = args.port + + +def example_with_requests() -> None: + """Example using the requests library.""" + import requests + + base_url = f"http://localhost:{PORT}" + + # Check server health + response = requests.get(f"{base_url}/health") + print(f"Server health: {response.status_code} OK") + + # List available models + response = requests.get(f"{base_url}/v1/models") + models = response.json() + print(f"Available models: {[m['id'] for m in models['data']]}") + + # Generate an image + request_data = { + "prompt": "A beautiful sunset over the ocean with palm trees", + "size": "1024x1024", + "n": 1, + "response_format": "b64_json", + # Diffusion-specific parameters + "num_inference_steps": 30, + "guidance_scale": 3.5, + } + + print(f"\nGenerating image: {request_data['prompt']}") + response = requests.post( + f"{base_url}/v1/images/generations", + json=request_data, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code == 200: + result = response.json() + print(f"Created at: {result['created']}") + + # Save the image + output_dir = Path("outputs/client") + output_dir.mkdir(parents=True, exist_ok=True) + + for i, img_data in enumerate(result["data"]): + if "b64_json" in img_data: + image_bytes = base64.b64decode(img_data["b64_json"]) + output_path = output_dir / f"requests_output_{i}.png" + with open(output_path, "wb") as f: + f.write(image_bytes) + print(f"Image saved to: {output_path}") + else: + print(f"Error: {response.status_code} - {response.text}") + + +def example_with_openai_client() -> None: + """Example using the official OpenAI Python client. + + Note: The OpenAI client can connect to any OpenAI-compatible server. + """ + try: + from openai import OpenAI + except ImportError: + print("Please install openai: pip install openai") + return + + # Point to local server + client = OpenAI( + base_url=f"http://localhost:{PORT}/v1", + api_key="not-needed", # API key not required for local server + ) + + # List models + models = client.models.list() + print(f"Available models: {[m.id for m in models.data]}") + + # Generate an image + print("\nGenerating image with OpenAI client...") + + # Note: The OpenAI client's images.generate() may not support + # all diffusion-specific parameters. Use raw HTTP for full control. + response = client.images.generate( + model="black-forest-labs/FLUX.1-dev", + prompt="A majestic dragon flying over a medieval castle", + size="1024x1024", + n=1, + response_format="b64_json", + ) + + print(f"Created at: {response.created}") + + # Save the image + output_dir = Path("outputs/client") + output_dir.mkdir(parents=True, exist_ok=True) + + for i, img_data in enumerate(response.data): + if img_data.b64_json: + image_bytes = base64.b64decode(img_data.b64_json) + output_path = output_dir / f"openai_client_output_{i}.png" + with open(output_path, "wb") as f: + f.write(image_bytes) + print(f"Image saved to: {output_path}") + + +def example_with_httpx_async() -> None: + """Example using httpx for async requests.""" + import asyncio + + try: + import httpx + except ImportError: + print("Please install httpx: pip install httpx") + return + + async def generate_image() -> None: + async with httpx.AsyncClient(timeout=300.0) as client: + base_url = f"http://localhost:{PORT}" + + # Generate image + request_data = { + "prompt": "A cyberpunk street scene at night with neon lights", + "size": "1024x1024", + "n": 1, + "response_format": "b64_json", + "num_inference_steps": 30, + "guidance_scale": 3.5, + } + + print(f"Generating: {request_data['prompt']}") + response = await client.post( + f"{base_url}/v1/images/generations", + json=request_data, + ) + + if response.status_code == 200: + result = response.json() + + output_dir = Path("outputs/client") + output_dir.mkdir(parents=True, exist_ok=True) + + for i, img_data in enumerate(result["data"]): + if "b64_json" in img_data: + image_bytes = base64.b64decode(img_data["b64_json"]) + output_path = output_dir / f"httpx_output_{i}.png" + with open(output_path, "wb") as f: + f.write(image_bytes) + print(f"Image saved to: {output_path}") + else: + print(f"Error: {response.status_code}") + + asyncio.run(generate_image()) + + +def example_curl_commands() -> None: + """Print curl commands for reference.""" + print("\n" + "=" * 60) + print("CURL COMMAND EXAMPLES") + print("=" * 60) + + print("\n1. Check server health:") + print(f" curl http://localhost:{PORT}/health") + + print("\n2. List available models:") + print(f" curl http://localhost:{PORT}/v1/models") + + print("\n3. Generate an image:") + print( + f""" curl http://localhost:{PORT}/v1/images/generations \\ + -H "Content-Type: application/json" \\ + -d '{ + "prompt": "A beautiful sunset over mountains", + "size": "1024x1024", + "n": 1, + "response_format": "b64_json", + "num_inference_steps": 30, + "guidance_scale": 3.5 + }'""" + ) + + print("\n4. Generate and save image (with jq):") + print( + f""" curl -s http://localhost:{PORT}/v1/images/generations \\ + -H "Content-Type: application/json" \\ + -d '{"prompt": "A cat", "size": "512x512"}' \\ + | jq -r '.data[0].b64_json' | base64 -d > output.png""" + ) + + +if __name__ == "__main__": + print("=" * 60) + print("Diffusion API Client Examples") + print("=" * 60) + print("\nMake sure the server is running:") + print( + f" max images serve --model black-forest-labs/FLUX.1-dev --port {PORT}" + ) + print() + + # Print curl examples first + # example_curl_commands() + + # Uncomment to run the examples: + print("\n" + "=" * 60) + print("Running requests example...") + print("=" * 60) + example_with_requests() + + print("\n" + "=" * 60) + print("Running OpenAI client example...") + print("=" * 60) + example_with_openai_client() + + print("\n" + "=" * 60) + print("Running httpx async example...") + print("=" * 60) + example_with_httpx_async() diff --git a/max/examples/diffusion/offline_generation.py b/max/examples/diffusion/offline_generation.py new file mode 100644 index 00000000000..f4e398891a4 --- /dev/null +++ b/max/examples/diffusion/offline_generation.py @@ -0,0 +1,55 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import argparse +from pathlib import Path + +from max.entrypoints.diffusion import ImageGenerator +from max.experimental.realization_context import set_seed +from max.pipelines import PipelineConfig + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-path", type=str, default="black-forest-labs/FLUX.1-dev" + ) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + model_path = args.model_path + set_seed(args.seed) + pipeline_config = PipelineConfig(model_path=model_path) + pipe = ImageGenerator(pipeline_config) + + prompt = "A cat holding a sign that says hello world" + print(f"Prompt: {prompt}") + print(f"Seed: {args.seed}") + + # Generate images using the new API + images = pipe.generate( + prompt, + height=1024, + width=1024, + num_inference_steps=28, + guidance_scale=3.5, + ) + + output_path = Path("output.png") + output_path.parent.mkdir(parents=True, exist_ok=True) + images[0].save(output_path) + print(f"Image saved to: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/max/examples/diffusion/openai_api_example.py b/max/examples/diffusion/openai_api_example.py new file mode 100644 index 00000000000..4bfa31085b2 --- /dev/null +++ b/max/examples/diffusion/openai_api_example.py @@ -0,0 +1,98 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Example: Using OpenAI-compatible API for image generation. + +This example demonstrates how to use the ImageGenerationRequest and +ImageGenerationResponse classes for OpenAI-compatible image generation. + +Usage: + python openai_api_example.py --seed 42 --prompt "A futuristic city skyline at sunset with flying cars" --model-path "black-forest-labs/FLUX.1-dev" +""" + +import argparse +import base64 +import os +from pathlib import Path + +from max.entrypoints.diffusion import ImageGenerator +from max.interfaces import ImageGenerationRequest +from max.pipelines import PipelineConfig + + +def main() -> None: + # Configure random seed for reproducibility + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--prompt", + type=str, + default="A futuristic city skyline at sunset with flying cars", + ) + parser.add_argument( + "--model-path", type=str, default="black-forest-labs/FLUX.1-dev" + ) + args = parser.parse_args() + seed = args.seed + os.environ["SEED"] = str(seed) + model_path = args.model_path + + # Initialize the generator + pipeline_config = PipelineConfig(model_path=model_path) + generator = ImageGenerator(pipeline_config) + + print(f"Model loaded: {generator.model_name}") + print(f"Seed: {os.getenv('SEED', 'not set')}") + + # Create an OpenAI-compatible request + request = ImageGenerationRequest( + prompt="A futuristic city skyline at sunset with flying cars", + size="1024x1024", + n=1, + quality="standard", + response_format="b64_json", + output_format="png", + # Diffusion-specific parameters + num_inference_steps=28, + guidance_scale=3.5, + seed=seed, + ) + + print(f"Generating image with prompt: {request.prompt}") + print(f"Size: {request.size}, Steps: {request.num_inference_steps}") + + # Generate using OpenAI-compatible API (create method) + response = generator.create(request) + + print(f"Response created at: {response.created}") + print(f"Number of images: {len(response.data)}") + + # Save the generated image + output_dir = Path("outputs") + output_dir.mkdir(parents=True, exist_ok=True) + + for i, image_data in enumerate(response.data): + if image_data.b64_json: + # Decode base64 and save + image_bytes = base64.b64decode(image_data.b64_json) + output_path = output_dir / f"openai_api_output_{i}.png" + with open(output_path, "wb") as f: + f.write(image_bytes) + print(f"Image saved to: {output_path}") + + if image_data.revised_prompt: + print(f"Revised prompt: {image_data.revised_prompt}") + + +if __name__ == "__main__": + main() diff --git a/max/python/max/config/__init__.py b/max/python/max/config/__init__.py index 4b10a5e89d1..440822a3478 100644 --- a/max/python/max/config/__init__.py +++ b/max/python/max/config/__init__.py @@ -16,7 +16,9 @@ import argparse import enum +import json import logging +import os import types from abc import abstractmethod from collections.abc import Mapping @@ -506,6 +508,7 @@ def _extract_max_config_data( config_dict: The loaded YAML configuration dictionary. config_class: The config class we're extracting data for. section_name: Optional specific section name to look for. + config_file_path: Path to the config file for resolving inheritance. Returns: Configuration data for the specific config class. @@ -854,9 +857,9 @@ def _add_field_as_argument( ): # For enums, use the string value as default but we'll need to convert back arg_kwargs = { - "default": field_value.value - if field_value - else field_obj.default + "default": ( + field_value.value if field_value else field_obj.default + ) } else: arg_kwargs = {"default": field_value} @@ -1071,6 +1074,19 @@ def parse_args( # type: ignore[override] # noqa: ANN202 return MAXConfigArgumentParser(parser, self) +def load_config(config_path: str | os.PathLike) -> dict: + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + try: + with open(config_path, encoding="utf-8") as f: + config_dict = json.loads(f.read()) + except Exception as e: + raise ValueError( + f"Failed to load configuration from {config_path}: {e}" + ) from e + return config_dict + + all = [ "MAXBaseModel", "ConfigFileModel", diff --git a/max/python/max/dtype/__init__.py b/max/python/max/dtype/__init__.py index 864514236b8..ad702907d8c 100644 --- a/max/python/max/dtype/__init__.py +++ b/max/python/max/dtype/__init__.py @@ -11,4 +11,5 @@ # limitations under the License. # ===----------------------------------------------------------------------=== # +from . import dtype_extension from .dtype import DType diff --git a/max/python/max/dtype/dtype_extension.py b/max/python/max/dtype/dtype_extension.py new file mode 100644 index 00000000000..fba3de7e83f --- /dev/null +++ b/max/python/max/dtype/dtype_extension.py @@ -0,0 +1,56 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Extension for max.dtype to support additional attributes.""" + +from numpy import finfo as np_finfo + +from .dtype import DType + + +class finfo: + """A numerical properties of a floating point max.dtype.DType. + + This class mimics torch.finfo behavior without torch dependency, + including support for bfloat16. + + NOTE: Currently, it's applied through patching. + This extension is better to be implemented in dtype library itself. + """ + + def __init__(self, dtype: DType): + """Initialize finfo for a given max.dtype.DType. + + Args: + dtype: The data type to get limits for. + """ + if dtype == DType.bfloat16: + self.min = -3.38953e38 + self.max = 3.38953e38 + self.bits = 16 + self.eps = 0.0078125 + self.resolution = 0.01 + self.tiny = 1.17549e-38 + self.dtype = "bfloat16" + else: + np_finfo_obj = np_finfo(dtype.to_numpy()) + self.min = float(np_finfo_obj.min) + self.max = float(np_finfo_obj.max) + self.bits = np_finfo_obj.bits + self.eps = float(np_finfo_obj.eps) + self.resolution = float(np_finfo_obj.resolution) + self.tiny = float(np_finfo_obj.tiny) + self.dtype = str(np_finfo_obj.dtype) + + +DType.finfo = finfo diff --git a/max/python/max/entrypoints/BUILD.bazel b/max/python/max/entrypoints/BUILD.bazel index 6bef0b817c7..caf523b6598 100644 --- a/max/python/max/entrypoints/BUILD.bazel +++ b/max/python/max/entrypoints/BUILD.bazel @@ -93,6 +93,42 @@ modular_py_binary( ], ) +modular_py_binary( + name = "pipelines_diffusion", + srcs = [ + "pipelines_diffusion.py", + ], + data = [ + "@nvshmem_prebuilt//:host", + ], + env = { + "OTEL_EXPORTER_OTLP_METRICS_DEFAULT_HISTOGRAM_AGGREGATION": "base2_exponential_bucket_histogram", + "MODULAR_SHMEM_LIB_DIR": "../+http_archive+nvshmem_prebuilt", + }, + mojo_deps = select({ + "//:emit_mojo_enabled": PROD_MOJOPKGS, + "//conditions:default": [], + }), + deps = [ + # Provides the `max.entrypoints.pipelines` module for the wrapper to import. + ":_pipelines", + ":entrypoints", + "//max/python/max:_core", + "//max/python/max/benchmark:benchmark_serving_lib", + "//max/python/max/interfaces", + "//max/python/max/pipelines", + "//max/python/max/serve:config", + "//max/python/max/serve/telemetry", + requirement("typing-extensions"), + requirement("click"), + ] + select({ + "//:nvidia_gpu": [ + requirement("torch"), + ], + "//conditions:default": [], + }), +) + modular_py_binary( name = "replay_recording", srcs = ["replay_recording.py"], diff --git a/max/python/max/entrypoints/cli/__init__.py b/max/python/max/entrypoints/cli/__init__.py index 0aa58bb5569..33d030adf22 100644 --- a/max/python/max/entrypoints/cli/__init__.py +++ b/max/python/max/entrypoints/cli/__init__.py @@ -30,7 +30,7 @@ from .generate import generate_text_for_pipeline, stream_text_to_console from .list import list_pipelines_to_console, list_pipelines_to_json from .metrics import TextGenerationMetrics -from .serve import serve_api_server_and_model_worker +from .serve import serve_api_server_and_model_worker, serve_diffusion_api_server __all__ = [ "DevicesOptionType", @@ -48,6 +48,7 @@ "pipeline_encode", "sampling_params_options", "serve_api_server_and_model_worker", + "serve_diffusion_api_server", "stream_text_to_console", "validate_field_type", ] diff --git a/max/python/max/entrypoints/cli/generate.py b/max/python/max/entrypoints/cli/generate.py index 9dd3c93d5ad..01f38fc0086 100644 --- a/max/python/max/entrypoints/cli/generate.py +++ b/max/python/max/entrypoints/cli/generate.py @@ -19,6 +19,7 @@ import dataclasses import logging from collections.abc import Iterable +from pathlib import Path from typing import Any import requests @@ -148,3 +149,41 @@ def generate_text_for_pipeline( print_tokens=True, ) ) + + +def generate_image( + pipeline_config: PipelineConfig, + prompt: str, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + num_images_per_prompt: int, + output: Path, +) -> None: + from ..diffusion import DiffusionPipeline + + pipeline = DiffusionPipeline(pipeline_config) + result = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + ) + + images = result.images + assert images, "Expected at least one generated image." + + output.parent.mkdir(parents=True, exist_ok=True) + if num_images_per_prompt == 1: + images[0].save(output) + logger.info(f"Image saved to: {output}") + else: + stem = output.stem + suffix = output.suffix + for i, image in enumerate(images): + numbered_path = output.parent / f"{stem}_{i + 1}{suffix}" + image.save(numbered_path) + logger.info(f"{len(images)} images saved to: {output.parent}") diff --git a/max/python/max/entrypoints/cli/serve/__init__.py b/max/python/max/entrypoints/cli/serve/__init__.py index 7b2da021d95..58ae44d4633 100644 --- a/max/python/max/entrypoints/cli/serve/__init__.py +++ b/max/python/max/entrypoints/cli/serve/__init__.py @@ -12,5 +12,6 @@ # ===----------------------------------------------------------------------=== # from .serve_api_and_model_worker import serve_api_server_and_model_worker +from .serve_diffusion_api import serve_diffusion_api_server -__all__ = ["serve_api_server_and_model_worker"] +__all__ = ["serve_api_server_and_model_worker", "serve_diffusion_api_server"] diff --git a/max/python/max/entrypoints/cli/serve/serve_diffusion_api.py b/max/python/max/entrypoints/cli/serve/serve_diffusion_api.py new file mode 100644 index 00000000000..b71c778ddb7 --- /dev/null +++ b/max/python/max/entrypoints/cli/serve/serve_diffusion_api.py @@ -0,0 +1,266 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""API server for diffusion image generation. + +This module provides an OpenAI-compatible API server for image generation +using diffusion models (e.g., FLUX.1). + +The server implements the /v1/images/generations endpoint following the +OpenAI API specification. +""" + +from __future__ import annotations + +import logging +import os +import signal +import time +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +import uvloop +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse +from max.entrypoints.diffusion import ImageGenerator +from max.interfaces import ImageGenerationRequest +from max.pipelines import PipelineConfig +from max.profiler import Tracer +from max.serve.api_server import validate_port_is_free +from max.serve.config import Settings +from uvicorn import Config, Server + +logger = logging.getLogger("max.serve.diffusion") + + +@asynccontextmanager +async def lifespan( + app: FastAPI, + settings: Settings, + pipeline_config: PipelineConfig, +) -> AsyncGenerator[None]: + """Manage the lifecycle of the diffusion server.""" + logger.info("Starting diffusion image generation server...") + + # Initialize the diffusion pipeline + try: + pipeline = ImageGenerator(pipeline_config) + app.state.pipeline = pipeline + app.state.pipeline_config = pipeline_config + app.state.settings = settings + except Exception: + logger.exception("Failed to initialize diffusion pipeline") + raise + + logger.info( + f"\n\n{'*' * 80}\n\n" + f"{'Image generation server ready on http://' + settings.host + ':' + str(settings.port) + ' (Press CTRL+C to quit)'.center(80)}\n\n" + f"{'*' * 80}\n" + ) + + yield + + logger.info("Shutting down diffusion server...") + + +def create_diffusion_app( + settings: Settings, + pipeline_config: PipelineConfig, +) -> FastAPI: + """Create the FastAPI application for diffusion serving.""" + + @asynccontextmanager + async def lifespan_wrap(app: FastAPI) -> AsyncGenerator[None, None]: + try: + async with lifespan(app, settings, pipeline_config): + yield + except Exception: + logger.exception("Server exception, shutting down...") + os.kill(os.getpid(), signal.SIGINT) + os.kill(os.getpid(), signal.SIGINT) + + app = FastAPI(title="MAX Serve - Image Generation", lifespan=lifespan_wrap) + + # Health check + @app.get("/health") + async def health() -> JSONResponse: + return JSONResponse({"status": "ok"}) + + @app.get("/v1/health") + async def v1_health() -> JSONResponse: + return JSONResponse({"status": "ok"}) + + # Version endpoint + @app.get("/version") + async def version() -> JSONResponse: + from importlib.metadata import PackageNotFoundError, version + + try: + package_version = version("max") + return JSONResponse({"version": package_version}) + except PackageNotFoundError: + return JSONResponse({"version": "unknown"}) + + # OpenAI-compatible image generation endpoint + @app.post("/v1/images/generations") + async def create_image(request: Request) -> JSONResponse: + """Generate images from a text prompt (OpenAI-compatible). + + Request body follows the OpenAI /v1/images/generations schema: + - prompt (required): A text description of the desired image(s) + - model: The model to use for image generation + - n: Number of images to generate (1-10) + - quality: Image quality ('standard', 'hd', 'high', 'medium', 'low') + - response_format: 'url' or 'b64_json' + - size: Image size (e.g., '1024x1024') + - style: Image style ('vivid', 'natural') + - user: End-user identifier + - background: Background transparency ('transparent', 'opaque', 'auto') + - output_format: Output format ('png', 'jpeg', 'webp') + - num_inference_steps: Number of denoising steps (extension) + - guidance_scale: Classifier-free guidance scale (extension) + - seed: Random seed for reproducibility (extension) + + Returns: + OpenAI-compatible ImagesResponse with created timestamp and + image data (b64_json or url). + """ + try: + # Parse request body + body = await request.json() + + # Validate required field + if "prompt" not in body: + raise ValueError("'prompt' is a required field") + + # Get pipeline from app state + pipeline: ImageGenerator = request.app.state.pipeline + + # Build internal request from OpenAI-compatible fields + internal_request = ImageGenerationRequest( + prompt=body["prompt"], + model=body.get("model"), + n=body.get("n", 1), + quality=body.get("quality", "standard"), + response_format=body.get("response_format", "b64_json"), + size=body.get("size", "1024x1024"), + style=body.get("style"), + user=body.get("user"), + background=body.get("background"), + moderation=body.get("moderation"), + output_compression=body.get("output_compression"), + output_format=body.get("output_format", "png"), + partial_images=body.get("partial_images"), + stream=body.get("stream"), + # Extension parameters for diffusion models + num_inference_steps=body.get("num_inference_steps", 50), + guidance_scale=body.get("guidance_scale", 3.5), + seed=body.get("seed"), + ) + + logger.debug( + "Processing image generation request: prompt=%r, size=%s, n=%d", + internal_request.prompt[:50] + if len(internal_request.prompt) > 50 + else internal_request.prompt, + internal_request.size, + internal_request.n or 1, + ) + + # Generate images + response = pipeline.generate(internal_request) + + # Return OpenAI-compatible response + return JSONResponse(response.to_dict()) + + except ValueError as e: + logger.warning("Invalid request: %s", str(e)) + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + logger.exception("Image generation failed") + raise HTTPException( + status_code=500, detail="Image generation failed" + ) from e + + # Model info endpoint + @app.get("/v1/models") + async def list_models(request: Request) -> JSONResponse: + """List available models.""" + pipeline: ImageGenerator = request.app.state.pipeline + return JSONResponse( + { + "object": "list", + "data": [ + { + "id": pipeline.model_name, + "object": "model", + "created": int(time.time()), + "owned_by": "modular", + } + ], + } + ) + + @app.get("/v1/models/{model_id}") + async def get_model(model_id: str, request: Request) -> JSONResponse: + """Get model information.""" + pipeline: ImageGenerator = request.app.state.pipeline + + # Check if the model_id matches + if model_id == pipeline.model_name or model_id in pipeline.model_name: + return JSONResponse( + { + "id": pipeline.model_name, + "object": "model", + "created": int(time.time()), + "owned_by": "modular", + } + ) + + raise HTTPException( + status_code=404, detail=f"Model '{model_id}' not found" + ) + + return app + + +def serve_diffusion_api_server( + settings: Settings, + pipeline_config: PipelineConfig, +) -> None: + """Start the diffusion API server. + + Args: + settings: Server settings (port, host, etc.). + pipeline_config: Configuration for the diffusion pipeline. + """ + # Create the FastAPI app + app = create_diffusion_app(settings, pipeline_config) + + # Configure uvicorn + config = Config( + app=app, + log_config=None, + loop="uvloop", + host=settings.host, + port=settings.port, + timeout_graceful_shutdown=5, + ) + + # Validate port before loading models + validate_port_is_free(settings.port) + + server = Server(config) + + with Tracer("diffusion_api_server"): + uvloop.run(server.serve()) diff --git a/max/python/max/entrypoints/diffusion.py b/max/python/max/entrypoints/diffusion.py new file mode 100644 index 00000000000..9f0d8950f0c --- /dev/null +++ b/max/python/max/entrypoints/diffusion.py @@ -0,0 +1,398 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""High-level interface for image generation using diffusion models. + +This module provides both programmatic and OpenAI-compatible API access +to diffusion-based image generation pipelines. + +Example (Direct API): + ```python + from max.entrypoints.diffusion import ImageGenerator + from max.pipelines import PipelineConfig + + config = PipelineConfig(model="black-forest-labs/FLUX.1-schnell") + generator = ImageGenerator(config) + + images = generator.generate("A beautiful sunset over mountains") + images[0].save("output.png") + ``` + +Example (OpenAI-compatible API): + ```python + from max.entrypoints.diffusion import ImageGenerator + from max.interfaces import ImageGenerationRequest + + generator = ImageGenerator(config) + request = ImageGenerationRequest( + prompt="A beautiful sunset", + size="1024x1024", + n=1, + ) + response = generator.create(request) + # response.data[0].b64_json contains the base64-encoded image + ``` +""" + +from __future__ import annotations + +import queue +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from threading import Event, Thread +from typing import TYPE_CHECKING + +import tqdm +from max.interfaces import ( + ImageGenerationInputs, + ImageGenerationOutput, + ImageGenerationRequest, + ImageGenerationResponse, + PipelineTask, + RequestID, +) +from max.pipelines.lib import PIPELINE_REGISTRY, PipelineConfig +from PIL.Image import Image + +if TYPE_CHECKING: + from max.pipelines.lib.pipeline_variants.image_generation import ( + ImageGenerationPipeline, + ) + + +@dataclass +class _ImageRequest: + """Internal request object for the image generation queue.""" + + id: RequestID + prompts: Sequence[str] + negative_prompts: Sequence[str] | None = None + height: int = 1024 + width: int = 1024 + num_inference_steps: int = 50 + guidance_scale: float = 3.5 + num_images_per_prompt: int = 1 + use_tqdm: bool = True + + +@dataclass +class _ImageResponse: + """Internal response object from the image generation queue.""" + + images: list[Image] + + +@dataclass +class _ThreadControl: + """Thread synchronization primitives.""" + + ready: Event = field(default_factory=Event) + cancel: Event = field(default_factory=Event) + + +class ImageGenerator: + """High-level interface for generating images using diffusion models. + + This class provides a thread-safe interface for image generation with + support for both direct API calls and OpenAI-compatible request/response. + + The generator runs a background worker thread that processes requests + from a queue, allowing for concurrent request handling. + + Attributes: + model_name: The name/path of the loaded model. + pipeline_config: The configuration for the pipeline. + """ + + # Thread control and communication + _thread_control: _ThreadControl + _worker_thread: Thread + _request_queue: queue.Queue[_ImageRequest] + _pending_requests: dict[RequestID, queue.Queue[_ImageResponse]] + + # Configuration + pipeline_config: PipelineConfig + model_name: str + + def __init__(self, pipeline_config: PipelineConfig) -> None: + """Initialize the image generator. + + Args: + pipeline_config: Configuration specifying the model and parameters. + """ + self.pipeline_config = pipeline_config + self.model_name = pipeline_config.model.model_path + + # Initialize thread control and queues + self._thread_control = _ThreadControl() + self._request_queue = queue.Queue() + self._pending_requests = {} + + # Start background worker + self._worker_thread = Thread( + target=_run_worker, + args=( + self._thread_control, + self.pipeline_config, + self._request_queue, + self._pending_requests, + ), + daemon=True, + ) + self._worker_thread.start() + + # Wait for worker to be ready + self._thread_control.ready.wait() + + def __del__(self) -> None: + """Clean up resources.""" + self._thread_control.cancel.set() + if self._worker_thread.is_alive(): + self._worker_thread.join(timeout=5.0) + + def generate( + self, + prompts: str | Sequence[str], + *, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 3.5, + num_images_per_prompt: int = 1, + use_tqdm: bool = True, + ) -> list[Image]: + """Generate images from text prompts. + + This method is thread-safe and can be called from multiple threads. + + Args: + prompts: Single prompt string or sequence of prompts. + height: Image height in pixels. + width: Image width in pixels. + num_inference_steps: Number of denoising steps. + guidance_scale: Classifier-free guidance scale. + num_images_per_prompt: Number of images per prompt. + use_tqdm: Show progress bar. + + Returns: + List of generated PIL Images. + + Example: + ```python + images = generator.generate( + "A cat sitting on a couch", + height=1024, + width=1024, + num_inference_steps=30, + ) + images[0].save("cat.png") + ``` + """ + # Normalize prompts to sequence + if isinstance(prompts, str): + prompts = [prompts] + + # Create internal request + request = _ImageRequest( + id=RequestID(), + prompts=prompts, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + use_tqdm=use_tqdm, + ) + + # Submit request and wait for response + return self._submit_and_wait(request) + + def create( + self, + request: ImageGenerationRequest, + ) -> ImageGenerationResponse: + """Generate images using OpenAI-compatible request format. + + Args: + request: OpenAI-compatible image generation request. + + Returns: + OpenAI-compatible response with base64-encoded images. + + Example: + ```python + request = ImageGenerationRequest( + prompt="A beautiful landscape", + size="1024x1024", + n=2, + response_format="b64_json", + ) + response = generator.create(request) + print(f"Generated {len(response.data)} images") + ``` + """ + # Parse dimensions from size string + width, height = request.get_dimensions() + + # Generate images + images = self.generate( + prompts=request.prompt, + height=height, + width=width, + num_inference_steps=request.num_inference_steps, + guidance_scale=request.guidance_scale, + num_images_per_prompt=request.n or 1, + use_tqdm=False, + ) + + # Convert to OpenAI response format + output = ImageGenerationOutput(images=images) + return ImageGenerationResponse.from_pipeline_output( + output=output, + response_format=request.response_format, + output_format=request.get_output_format(), + prompt=request.prompt, + ) + + def _submit_and_wait(self, request: _ImageRequest) -> list[Image]: + """Submit a request to the queue and wait for response.""" + response_queue: queue.Queue[_ImageResponse] = queue.Queue() + self._pending_requests[request.id] = response_queue + + try: + self._request_queue.put_nowait(request) + response = response_queue.get() + return response.images + finally: + self._pending_requests.pop(request.id, None) + + @classmethod + def from_model(cls, model: str, **kwargs) -> ImageGenerator: + """Create an ImageGenerator from a model identifier. + + Args: + model: Model identifier (e.g., "black-forest-labs/FLUX.1-schnell"). + **kwargs: Additional PipelineConfig arguments. + + Returns: + Configured ImageGenerator instance. + + Example: + ```python + generator = ImageGenerator.from_model( + "black-forest-labs/FLUX.1-schnell" + ) + ``` + """ + config = PipelineConfig(model=model, **kwargs) + return cls(config) + + +DiffusionPipeline = ImageGenerator + + +def _run_worker( + thread_control: _ThreadControl, + pipeline_config: PipelineConfig, + request_queue: queue.Queue[_ImageRequest], + pending_requests: Mapping[RequestID, queue.Queue[_ImageResponse]], +) -> None: + """Background worker that processes image generation requests. + + This function runs in a separate thread and continuously processes + requests from the queue until cancellation is signaled. + """ + # Load the pipeline + _, model_factory = PIPELINE_REGISTRY.retrieve_factory( + pipeline_config, + task=PipelineTask.IMAGE_GENERATION, + ) + pipeline: ImageGenerationPipeline = model_factory() + + # Signal that we're ready + thread_control.ready.set() + + # Main processing loop + while not thread_control.cancel.is_set(): + try: + request = request_queue.get(timeout=0.3) + except queue.Empty: + continue + + # Process the request + images = _process_request(pipeline, request) + + # Send response + if response_queue := pending_requests.get(request.id): + response_queue.put(_ImageResponse(images=images)) + + +def _process_request( + pipeline: ImageGenerationPipeline, + request: _ImageRequest, +) -> list[Image]: + """Process a single image generation request. + + Args: + pipeline: The image generation pipeline. + request: The request to process. + + Returns: + List of generated images. + """ + all_images: list[Image] = [] + + # Create iterator with optional progress bar + if request.negative_prompts is None or len(request.prompts) != len( + request.negative_prompts + ): + if ( + request.negative_prompts is None + or len(request.negative_prompts) == 0 + ): + request.negative_prompts = [None] * len(request.prompts) + else: + raise ValueError( + "Number of prompts and negative prompts must be the same." + ) + + # TODO: temp hard coding for true cfg scale. Need to be removed. + if all( + negative_prompt is not None + for negative_prompt in request.negative_prompts + ): + true_cfg_scale = 1.0 + else: + true_cfg_scale = 4.0 + + prompt_iter = zip(request.prompts, request.negative_prompts, strict=False) + if request.use_tqdm: + prompt_iter = tqdm.tqdm(prompt_iter, desc="Generating images") + + # Generate images for each prompt + for prompt, negative_prompt in prompt_iter: + inputs = ImageGenerationInputs( + prompt=prompt, + negative_prompt=negative_prompt, + height=request.height, + width=request.width, + num_inference_steps=request.num_inference_steps, + guidance_scale=request.guidance_scale, + num_images_per_prompt=request.num_images_per_prompt, + true_cfg_scale=true_cfg_scale, + ) + + output: ImageGenerationOutput = pipeline.execute(inputs) + all_images.extend(output.images) + + return all_images diff --git a/max/python/max/entrypoints/pipelines.py b/max/python/max/entrypoints/pipelines.py index c8dcada7415..571d8516fe9 100644 --- a/max/python/max/entrypoints/pipelines.py +++ b/max/python/max/entrypoints/pipelines.py @@ -18,7 +18,8 @@ import os import sys from collections.abc import Callable, Sequence -from typing import Any, TypeVar +from pathlib import Path +from typing import Any, Literal, TypeVar import click from click import shell_completion @@ -384,6 +385,246 @@ def cli_pipeline( ) +@main.group(name="images", cls=ModelGroup) +def images_group() -> None: + """Commands for image generation (OpenAI-compatible /v1/images/* API).""" + + +@images_group.command(name="generate", cls=WithLazyPipelineOptions) +@click.option( + "--model", + type=str, + required=True, + help="Specify the repository ID of a Hugging Face model to use for image generation (e.g., 'black-forest-labs/FLUX.1-dev').", +) +@click.option( + "--prompt", + type=str, + required=True, + help="A text description of the desired image(s).", +) +@click.option( + "--n", + type=click.IntRange(min=1, max=10), + default=1, + show_default=True, + help="The number of images to generate (1-10).", +) +@click.option( + "--size", + type=str, + default="1024x1024", + show_default=True, + help="The size of generated images (e.g., '1024x1024', '1792x1024', '1024x1792').", +) +@click.option( + "--quality", + type=click.Choice(["auto", "standard", "hd", "high", "medium", "low"]), + default="auto", + show_default=True, + help="The quality of the image.", +) +@click.option( + "--response-format", + type=click.Choice(["b64_json", "url"]), + default="b64_json", + show_default=True, + help="The format in which generated images are returned.", +) +@click.option( + "--output-format", + type=click.Choice(["png", "jpeg", "webp"]), + default="png", + show_default=True, + help="The output image format.", +) +@click.option( + "--style", + type=click.Choice(["vivid", "natural"]), + default=None, + help="The style of generated images.", +) +@click.option( + "--output", + type=click.Path(dir_okay=False, path_type=Path), + default="output.png", + show_default=True, + help="Output image path (numbered if multiple images are generated).", +) +@click.option( + "--seed", + type=int, + default=None, + help="Random seed for reproducibility.", +) +@click.option( + "--num-inference-steps", + type=click.IntRange(min=1), + default=50, + show_default=True, + help="Number of denoising steps (diffusion model parameter).", +) +@click.option( + "--guidance-scale", + type=float, + default=3.5, + show_default=True, + help="Classifier-free guidance scale (diffusion model parameter).", +) +def images_generate( + model: str, + prompt: str, + n: int, + size: str, + quality: str, + response_format: Literal["url", "b64_json"], + output_format: str, + style: str | None, + output: Path, + seed: int | None, + num_inference_steps: int, + guidance_scale: float, + **config_kwargs: Any, +) -> None: + """Generate images from a text prompt. + + This command follows the OpenAI /v1/images/generations API schema. + + Example: + max images generate --model black-forest-labs/FLUX.1-dev \\ + --prompt "A beautiful sunset over mountains" \\ + --size 1024x1024 --n 1 --output sunset.png + """ + from max.entrypoints.diffusion import ImageGenerator + from max.experimental.realization_context import set_seed + from max.interfaces import ImageGenerationRequest + from max.pipelines.lib.config import ImageGenerationConfig + + # Set random seed if provided + if seed is not None: + set_seed(seed) + + """ + TODO: + - This configuration is dummy for now. Just for logging purpose. + - Modifications are required to enable the use of pipeline config. + """ + config_kwargs["model"] = model + pipeline_config = ImageGenerationConfig(**config_kwargs) + pipeline_config.log_basic_config() + + try: + generator = ImageGenerator(pipeline_config) + + # Create OpenAI-compatible request + request = ImageGenerationRequest( + prompt=prompt, + n=n, + size=size, + quality=quality, + response_format=response_format, + output_format=output_format, + style=style, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + seed=seed, + ) + + logger.info(f"Generating {n} image(s) with prompt: {prompt[:50]}...") + + # Generate images directly (not using OpenAI response format for CLI) + images = generator.generate( + prompts=prompt, + height=request.get_dimensions()[1], + width=request.get_dimensions()[0], + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=n, + use_tqdm=True, + ) + + # Save images + output.parent.mkdir(parents=True, exist_ok=True) + if len(images) == 1: + images[0].save(output) + logger.info(f"Image saved to: {output}") + else: + for i, img in enumerate(images): + numbered_output = output.with_stem(f"{output.stem}_{i}") + img.save(numbered_output) + logger.info(f"Image saved to: {numbered_output}") + + except Exception as exc: + logger.exception( + "Image generation failed for model %s with prompt %r", + pipeline_config.model.model_path, + prompt, + ) + raise click.ClickException("Image generation failed.") from exc + + +@images_group.command(name="serve", cls=WithLazyPipelineOptions) +@common_server_options +def images_serve( + profile_serve: bool, + sim_failure: int, + port: int, + headless: bool, + log_prefix: str | None, + **config_kwargs: Any, +) -> None: + """Start an OpenAI-compatible image generation server. + + This command launches a server with the following endpoints: + - POST /v1/images/generations (create image) + - GET /v1/models (list models) + - GET /health (health check) + + Example: + max images serve --model black-forest-labs/FLUX.1-dev --port 8000 + + Then use curl to generate images: + curl http://localhost:8000/v1/images/generations \\ + -H "Content-Type: application/json" \\ + -d '{"prompt": "A cat", "size": "1024x1024"}' + """ + from max.entrypoints.cli import serve_diffusion_api_server + from max.pipelines import PipelineConfig + from max.serve.config import Settings + from max.serve.telemetry.common import configure_logging + + # Initialize Settings + setting_kwargs: dict[str, Any] = {} + if port is not None: + setting_kwargs["MAX_SERVE_PORT"] = port + if log_prefix is not None: + setting_kwargs["MAX_SERVE_LOG_PREFIX"] = log_prefix + if headless is not None: + setting_kwargs["MAX_SERVE_HEADLESS"] = headless + + settings = Settings(**setting_kwargs) + + # Initialize pipeline config + pipeline_config = PipelineConfig(**config_kwargs) + pipeline_config.log_basic_config() + + # Configure logging + configure_logging(settings) + + if headless: + logger.error("Headless mode is not supported for image serving yet") + raise click.ClickException("Headless mode not supported") + + serve_diffusion_api_server( + settings=settings, + pipeline_config=pipeline_config, + ) + + +# Legacy alias for backward compatibility +diffusion_group = images_group + + @main.command(name="encode", cls=WithLazyPipelineOptions) @click.option( "--prompt", diff --git a/max/python/max/entrypoints/pipelines_diffusion.py b/max/python/max/entrypoints/pipelines_diffusion.py new file mode 100644 index 00000000000..009bebd07e1 --- /dev/null +++ b/max/python/max/entrypoints/pipelines_diffusion.py @@ -0,0 +1,40 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Diffusion-only CLI wrapper. + +This exists so Bazel can keep `//max/python/max/entrypoints:pipelines` lean, +while allowing `//max/python/max/entrypoints:pipelines_diffusion` to pull in +extra runtime deps. +""" + +from __future__ import annotations + +import sys + + +def main() -> None: + # Import the main pipelines CLI and dispatch into the `diffusion` group. + # + # NOTE: `max.entrypoints.pipelines.main` is a click command object. Calling it + # with `args=[...]` is equivalent to invoking the CLI with those argv tokens. + import max.entrypoints.pipelines as pipelines_cli + + pipelines_cli.main( + prog_name="pipelines", + args=[*sys.argv[1:]], + ) + + +if __name__ == "__main__": + main() diff --git a/max/python/max/experimental/BUILD.bazel b/max/python/max/experimental/BUILD.bazel index 9c95184007c..575bbbc0a09 100644 --- a/max/python/max/experimental/BUILD.bazel +++ b/max/python/max/experimental/BUILD.bazel @@ -8,6 +8,7 @@ modular_py_library( "__init__.py", "_passes.py", "_tensor_repr.py", + "compile_utils.py", "functional.py", "random.py", "realization_context.py", diff --git a/max/python/max/experimental/compile_utils.py b/max/python/max/experimental/compile_utils.py new file mode 100644 index 00000000000..13c10b93d83 --- /dev/null +++ b/max/python/max/experimental/compile_utils.py @@ -0,0 +1,97 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from collections.abc import Callable, Iterable +from typing import Any + +from max.driver import CPU, Accelerator +from max.engine import InferenceSession +from max.graph import Graph, TensorType +from max.nn.module_v3 import Module + + +class CompileWrapper: + def __init__( + self, + compile_target: Callable | Module, + input_types: Iterable[TensorType] | None = None, + ) -> None: + """Initialize the CompileWrapper. + + Args: + compile_target: The function or module to be compiled. + input_types: A list of input types (TensorTypes) required for compilation. + + Raises: + ValueError: If input_types is not provided. + """ + if input_types is None: + raise ValueError( + f"input_types must be provided for compilation of {compile_target.__name__}." + ) + + self.is_module = False + if isinstance(compile_target, Module): + self.is_module = True + self.session = compile_target.compile(input_types) + return + + with Graph(compile_target.__name__, input_types=input_types) as graph: + output = compile_target(*graph.inputs) + graph.output(output) + compiled_graph = graph + + if any(input_type.device.is_gpu() for input_type in input_types): + device = Accelerator() + else: + device = CPU() + session = InferenceSession([device]) + loaded_session = session.load(compiled_graph) + self.session = loaded_session + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Execute the compiled session with the given arguments. + + Args: + *args: Positional arguments to pass to the session. + **kwargs: Keyword arguments to pass to the session. + + Returns: + The result of the session execution. + """ + if self.is_module: + return self.session(*args, **kwargs) + return self.session.execute(*args, **kwargs) + + +def max_compile( + compile_target: Callable | Module | None = None, + input_types: Iterable[TensorType] | None = None, +) -> Callable[[Callable | Module], CompileWrapper] | CompileWrapper: + """Decorator or function to compile a target with specified input types. + + Args: + compile_target: The function or module to compile. If None, returns a decorator. + input_types: The input types for the compilation. + + Returns: + A CompileWrapper instance if compile_target is provided, otherwise a decorator. + """ + if compile_target is None: + + def decorator(f: Callable | Module) -> CompileWrapper: + return CompileWrapper(f, input_types) + + return decorator + + return CompileWrapper(compile_target, input_types) diff --git a/max/python/max/interfaces/__init__.py b/max/python/max/interfaces/__init__.py index 3d3062549e5..41c32dcf7ce 100644 --- a/max/python/max/interfaces/__init__.py +++ b/max/python/max/interfaces/__init__.py @@ -49,7 +49,16 @@ EmbeddingsGenerationInputs, EmbeddingsGenerationOutput, ImageContentPart, + ImageData, + ImageGenerationContext, + ImageGenerationContextType, + ImageGenerationInputs, + ImageGenerationOutput, + ImageGenerationRequest, + ImageGenerationResponse, + ImageGenerationUsage, ImageMetadata, + InputTokensDetails, TextContentPart, TextGenerationContext, TextGenerationContextType, @@ -109,7 +118,16 @@ def create_text_pipeline() -> Pipeline[TextGenerationInputs, TextGenerationOutpu "EmbeddingsGenerationOutput", "GenerationStatus", "ImageContentPart", + "ImageData", + "ImageGenerationContext", + "ImageGenerationContextType", + "ImageGenerationInputs", + "ImageGenerationOutput", + "ImageGenerationRequest", + "ImageGenerationResponse", + "ImageGenerationUsage", "ImageMetadata", + "InputTokensDetails", "LoRAOperation", "LoRARequest", "LoRAResponse", diff --git a/max/python/max/interfaces/pipeline_variants/__init__.py b/max/python/max/interfaces/pipeline_variants/__init__.py index d5a7b10d1f3..22913eeb74f 100644 --- a/max/python/max/interfaces/pipeline_variants/__init__.py +++ b/max/python/max/interfaces/pipeline_variants/__init__.py @@ -24,6 +24,17 @@ EmbeddingsGenerationInputs, EmbeddingsGenerationOutput, ) +from .image_generation import ( + ImageData, + ImageGenerationContext, + ImageGenerationContextType, + ImageGenerationInputs, + ImageGenerationOutput, + ImageGenerationRequest, + ImageGenerationResponse, + ImageGenerationUsage, + InputTokensDetails, +) from .text_generation import ( BatchType, ImageContentPart, @@ -54,7 +65,16 @@ "EmbeddingsGenerationInputs", "EmbeddingsGenerationOutput", "ImageContentPart", + "ImageData", + "ImageGenerationContext", + "ImageGenerationContextType", + "ImageGenerationInputs", + "ImageGenerationOutput", + "ImageGenerationRequest", + "ImageGenerationResponse", + "ImageGenerationUsage", "ImageMetadata", + "InputTokensDetails", "TextContentPart", "TextGenerationContext", "TextGenerationContextType", diff --git a/max/python/max/interfaces/pipeline_variants/image_generation.py b/max/python/max/interfaces/pipeline_variants/image_generation.py new file mode 100644 index 00000000000..493407e1760 --- /dev/null +++ b/max/python/max/interfaces/pipeline_variants/image_generation.py @@ -0,0 +1,373 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""OpenAI-compatible image generation request/response models. + +This module provides dataclasses that map to the OpenAI /v1/images/generations +API schema for seamless integration with OpenAI-compatible clients. +""" + +from __future__ import annotations + +import base64 +import io +import time +from dataclasses import dataclass, field +from typing import Any, Literal, Protocol, TypeVar, runtime_checkable + +from max.interfaces.context import BaseContext +from max.interfaces.pipeline import PipelineInputs +from max.interfaces.request import RequestID +from PIL.Image import Image + + +@runtime_checkable +class ImageGenerationContext(BaseContext, Protocol): + """Protocol defining the interface for image generation contexts. + + An ``ImageGenerationContext`` represents model inputs for image generation + pipelines, managing the state and parameters needed for generating images + from text prompts using diffusion models. + """ + + @property + def request_id(self) -> RequestID: + """Unique identifier for this request.""" + ... + + @property + def prompt(self) -> str: + """The text prompt for image generation.""" + ... + + @property + def height(self) -> int: + """The height of the generated image in pixels.""" + ... + + @property + def width(self) -> int: + """The width of the generated image in pixels.""" + ... + + @property + def num_inference_steps(self) -> int: + """Number of denoising steps.""" + ... + + @property + def guidance_scale(self) -> float: + """Classifier-free guidance scale.""" + ... + + @property + def num_images_per_prompt(self) -> int: + """Number of images to generate per prompt.""" + ... + + +ImageGenerationContextType = TypeVar( + "ImageGenerationContextType", bound=ImageGenerationContext +) + + +# Default image generation parameters +DEFAULT_SIZE = "1024x1024" +DEFAULT_NUM_IMAGES = 1 +DEFAULT_RESPONSE_FORMAT = "b64_json" +DEFAULT_QUALITY = "standard" +DEFAULT_OUTPUT_FORMAT = "png" +DEFAULT_NUM_INFERENCE_STEPS = 50 +DEFAULT_GUIDANCE_SCALE = 3.5 + + +@dataclass(eq=True) +class ImageGenerationInputs(PipelineInputs): + """Inputs for image-generation pipelines.""" + + prompt: str + negative_prompt: str | None = None + true_cfg_scale: float | None = None + height: int = 1024 + width: int = 1024 + num_inference_steps: int = 50 + guidance_scale: float = 3.5 + num_images_per_prompt: int = 1 + + +@dataclass(kw_only=True) +class ImageGenerationOutput: + """Output container for generated images.""" + + images: list[Image] + """List of generated images.""" + + @property + def is_done(self) -> bool: + """Indicates whether image generation is complete. + + Returns: + bool: Always True, as image generation is a single-step operation. + """ + return True + + +@dataclass +class ImageGenerationRequest: + """OpenAI-compatible image generation request. + + This maps to the OpenAI /v1/images/generations API schema. + See: https://platform.openai.com/docs/api-reference/images/create + """ + + # Required field + prompt: str + """A text description of the desired image(s). Required.""" + + # OpenAI standard fields + model: str | None = None + """The model to use for image generation.""" + + n: int | None = DEFAULT_NUM_IMAGES + """The number of images to generate. Must be between 1 and 10.""" + + quality: str | None = DEFAULT_QUALITY + """The quality of the image (e.g., 'standard', 'hd', 'high', 'medium', 'low').""" + + response_format: Literal["url", "b64_json"] | None = DEFAULT_RESPONSE_FORMAT + """The format in which generated images are returned.""" + + size: str | None = DEFAULT_SIZE + """The size of the generated images (e.g., '1024x1024').""" + + style: str | None = None + """The style of the generated images (e.g., 'vivid', 'natural').""" + + user: str | None = None + """A unique identifier representing your end-user.""" + + # Extended fields for GPT image models + background: str | None = None + """Background transparency ('transparent', 'opaque', 'auto').""" + + moderation: str | None = None + """Content-moderation level ('low', 'auto').""" + + output_compression: int | None = None + """Compression level (0-100%).""" + + output_format: str | None = DEFAULT_OUTPUT_FORMAT + """Output format ('png', 'jpeg', 'webp').""" + + partial_images: int | None = None + """Number of partial images for streaming (0-3).""" + + stream: bool | None = None + """Generate in streaming mode.""" + + # Extended parameters for diffusion models (not in OpenAI spec) + num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS + """Number of denoising steps. Extension for diffusion models.""" + + guidance_scale: float = DEFAULT_GUIDANCE_SCALE + """Classifier-free guidance scale. Extension for diffusion models.""" + + seed: int | None = None + """Random seed for reproducibility.""" + + def to_pipeline_inputs(self) -> ImageGenerationInputs: + """Convert OpenAI request to pipeline-native inputs.""" + width, height = self.get_dimensions() + return ImageGenerationInputs( + prompt=self.prompt, + height=height, + width=width, + num_inference_steps=self.num_inference_steps, + guidance_scale=self.guidance_scale, + num_images_per_prompt=self.n or 1, + ) + + def get_dimensions(self) -> tuple[int, int]: + """Parse size string and return (width, height) tuple.""" + size = self.size or DEFAULT_SIZE + + # Handle 'auto' as default + if size == "auto": + size = DEFAULT_SIZE + + # Parse WIDTHxHEIGHT format + try: + parts = size.lower().split("x") + if len(parts) == 2: + width, height = int(parts[0]), int(parts[1]) + if width >= 64 and height >= 64: + return width, height + except (ValueError, IndexError): + pass + + raise ValueError( + f"Invalid size '{size}'. Use format 'WIDTHxHEIGHT' (e.g., '1024x1024')." + ) + + def get_output_format(self) -> str: + """Get the output image format.""" + return self.output_format or DEFAULT_OUTPUT_FORMAT + + +@dataclass +class ImageData: + """Individual image data in the response.""" + + b64_json: str | None = None + """The base64-encoded image data.""" + + url: str | None = None + """The URL of the generated image (valid for 60 minutes).""" + + revised_prompt: str | None = None + """The prompt that was used to generate the image, if revised.""" + + +@dataclass +class InputTokensDetails: + """Details about input token usage.""" + + text_tokens: int = 0 + """Number of text tokens in the input.""" + + image_tokens: int = 0 + """Number of image tokens in the input.""" + + +@dataclass +class ImageGenerationUsage: + """Token usage statistics for image generation.""" + + total_tokens: int = 0 + """Total number of tokens used.""" + + input_tokens: int = 0 + """Number of input tokens.""" + + output_tokens: int = 0 + """Number of output tokens.""" + + input_tokens_details: InputTokensDetails | None = None + """Detailed breakdown of input tokens.""" + + +@dataclass +class ImageGenerationResponse: + """OpenAI-compatible image generation response. + + This maps to the OpenAI ImagesResponse schema. + See: https://platform.openai.com/docs/api-reference/images/object + """ + + created: int + """Unix timestamp when the response was created.""" + + data: list[ImageData] = field(default_factory=list) + """List of generated image data.""" + + usage: ImageGenerationUsage | None = None + """Token usage statistics.""" + + @classmethod + def from_pipeline_output( + cls, + output: ImageGenerationOutput, + response_format: Literal["url", "b64_json"] | None = "b64_json", + output_format: str = "png", + prompt: str | None = None, + ) -> ImageGenerationResponse: + """Convert pipeline output to OpenAI-compatible response. + + Args: + output: The raw pipeline output containing PIL images. + response_format: The desired response format ('url' or 'b64_json'). + output_format: The image format ('png', 'jpeg', 'webp'). + prompt: The original prompt, included as revised_prompt if provided. + + Returns: + An OpenAI-compatible ImageGenerationResponse. + """ + data: list[ImageData] = [] + fmt = response_format or "b64_json" + + # Map output_format to PIL format + pil_format_map = { + "png": "PNG", + "jpeg": "JPEG", + "webp": "WEBP", + } + pil_format = pil_format_map.get(output_format.lower(), "PNG") + + for image in output.images: + image_data = ImageData(revised_prompt=prompt) + + if fmt == "b64_json": + buffer = io.BytesIO() + image.save(buffer, format=pil_format) + buffer.seek(0) + image_data.b64_json = base64.b64encode(buffer.read()).decode( + "utf-8" + ) + elif fmt == "url": + raise ValueError( + "response_format='url' requires external storage. " + "Use 'b64_json' for local image generation." + ) + + data.append(image_data) + + return cls( + created=int(time.time()), + data=data, + usage=None, + ) + + def to_dict(self) -> dict[str, Any]: + """Convert response to dictionary for JSON serialization.""" + result: dict[str, Any] = { + "created": self.created, + "data": [ + { + k: v + for k, v in { + "b64_json": img.b64_json, + "url": img.url, + "revised_prompt": img.revised_prompt, + }.items() + if v is not None + } + for img in self.data + ], + } + + if self.usage is not None: + usage_dict: dict[str, Any] = { + "total_tokens": self.usage.total_tokens, + "input_tokens": self.usage.input_tokens, + "output_tokens": self.usage.output_tokens, + } + if self.usage.input_tokens_details is not None: + usage_dict["input_tokens_details"] = { + "text_tokens": self.usage.input_tokens_details.text_tokens, + "image_tokens": ( + self.usage.input_tokens_details.image_tokens + ), + } + result["usage"] = usage_dict + + return result diff --git a/max/python/max/interfaces/task.py b/max/python/max/interfaces/task.py index 477b77451e7..243d86ca443 100644 --- a/max/python/max/interfaces/task.py +++ b/max/python/max/interfaces/task.py @@ -58,6 +58,8 @@ class PipelineTask(str, Enum): """Task for generating audio.""" SPEECH_TOKEN_GENERATION = "speech_token_generation" """Task for generating speech tokens.""" + IMAGE_GENERATION = "image_generation" + """Task for generating images.""" @property def output_type( @@ -72,6 +74,7 @@ def output_type( from .pipeline_variants import ( AudioGenerationOutput, EmbeddingsGenerationOutput, + ImageGenerationOutput, TextGenerationOutput, ) from .scheduler import SchedulerResult @@ -85,6 +88,8 @@ def output_type( return dict[RequestID, SchedulerResult[EmbeddingsGenerationOutput]] elif self == PipelineTask.AUDIO_GENERATION: return dict[RequestID, SchedulerResult[AudioGenerationOutput]] + elif self == PipelineTask.IMAGE_GENERATION: + return dict[RequestID, SchedulerResult[ImageGenerationOutput]] else: raise ValueError( f"PipelineTask ({self}) does not have an output_type defined." diff --git a/max/python/max/nn/norm/group_norm.py b/max/python/max/nn/norm/group_norm.py index f1241d1a9aa..7ffc38e3ea6 100644 --- a/max/python/max/nn/norm/group_norm.py +++ b/max/python/max/nn/norm/group_norm.py @@ -45,6 +45,7 @@ def __init__( eps: float = 1e-5, affine: bool = True, device: DeviceRef = DeviceRef.GPU(), + dtype: DType = DType.float32, ) -> None: super().__init__() self.num_groups = num_groups @@ -65,13 +66,13 @@ def __init__( self.weight = Weight( name="weight", shape=(self.num_channels,), - dtype=DType.float32, + dtype=dtype, device=device, ) self.bias = Weight( name="bias", shape=(self.num_channels,), - dtype=DType.float32, + dtype=dtype, device=device, ) diff --git a/max/python/max/nn/norm/layer_norm.py b/max/python/max/nn/norm/layer_norm.py index 47a328c0faa..3e247bf7689 100644 --- a/max/python/max/nn/norm/layer_norm.py +++ b/max/python/max/nn/norm/layer_norm.py @@ -36,37 +36,56 @@ def __init__( dtype: DType, eps: float = 1e-5, use_bias: bool = True, + keep_dtype: bool = False, + elementwise_affine: bool = True, ) -> None: super().__init__() self.devices = devices - self.weight = Weight("weight", dtype, (dims,), device=self.devices[0]) - self.bias = ( - Weight("bias", dtype, (dims,), device=self.devices[0]) - if use_bias - else None - ) + if elementwise_affine: + self.weight = Weight( + "weight", dtype, (dims,), device=self.devices[0] + ) + self.bias = ( + Weight("bias", dtype, (dims,), device=self.devices[0]) + if use_bias + else None + ) + else: + self.weight = None + self.bias = None self.eps = eps self.dim = dims self.dtype = dtype + self.keep_dtype = keep_dtype self._sharding_strategy: ShardingStrategy | None = None def __call__(self, input: TensorValue): # TODO: AIPIPE-95 Replace with a broadcasting rmo.layer_norm bias = ( - ops.cast(self.bias, DType.float32) + self.bias if self.bias # If bias wasn't passed then use bias-less layer norm (beta = 0). else ops.broadcast_to( - ops.constant(0.0, DType.float32, self.weight.device), + ops.constant(0.0, self.dtype, input.device), + shape=(input.shape[-1],), + ) + ) + gamma = ( + self.weight + if self.weight + else ops.broadcast_to( + ops.constant(1.0, self.dtype, input.device), shape=(input.shape[-1],), ) ) - return ops.layer_norm( - input.cast(DType.float32), - gamma=ops.cast(self.weight, DType.float32), - beta=bias, + + output = ops.layer_norm( + input=input if self.keep_dtype else input.cast(DType.float32), + gamma=gamma if self.keep_dtype else ops.cast(gamma, DType.float32), + beta=bias if self.keep_dtype else ops.cast(bias, DType.float32), epsilon=self.eps, - ).cast(input.dtype) + ) + return output if self.keep_dtype else output.cast(input.dtype) @property def sharding_strategy(self) -> ShardingStrategy | None: diff --git a/max/python/max/pipelines/architectures/__init__.py b/max/python/max/pipelines/architectures/__init__.py index 3745c11317b..cdc222baa83 100644 --- a/max/python/max/pipelines/architectures/__init__.py +++ b/max/python/max/pipelines/architectures/__init__.py @@ -28,6 +28,7 @@ def register_all_models() -> None: from .deepseekV3 import deepseekV3_arch from .eagle_llama3 import eagle_llama_arch from .exaone import exaone_arch + from .flux1 import flux1_arch from .gemma3 import gemma3_arch from .gemma3multimodal import gemma3_multimodal_arch from .gpt_oss import gpt_oss_arch @@ -54,6 +55,7 @@ def register_all_models() -> None: deepseekV2_arch, deepseekV3_arch, eagle_llama_arch, + flux1_arch, gemma3_arch, gemma3_multimodal_arch, granite_arch, diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/__init__.py b/max/python/max/pipelines/architectures/autoencoder_kl/__init__.py new file mode 100644 index 00000000000..e18af050cf2 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .model import AutoencoderKLModel diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/autoencoder_kl.py b/max/python/max/pipelines/architectures/autoencoder_kl/autoencoder_kl.py new file mode 100644 index 00000000000..ef80846d949 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/autoencoder_kl.py @@ -0,0 +1,750 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass + +import max.nn as nn +from max.dtype import DType +from max.graph import DeviceRef, TensorType, TensorValue, ops +from max.nn import GroupNorm +from max.nn.layer.layer_list import LayerList + +from .layers import Upsample2D +from .model_config import AutoencoderKLConfig + + +class ResnetBlock2D(nn.Module): + """Residual block for 2D VAE decoder. + + This module implements a residual block with two convolutional layers, + group normalization, and optional shortcut connection. It supports + time embedding conditioning and configurable activation functions. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int | None, + groups: int, + groups_out: int, + eps: float = 1e-6, + non_linearity: str = "silu", + use_conv_shortcut: bool = False, + conv_shortcut_bias: bool = True, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize ResnetBlock2D module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + temb_channels: Number of time embedding channels (None if not used). + groups: Number of groups for first GroupNorm. + groups_out: Number of groups for second GroupNorm. + eps: Epsilon value for GroupNorm layers. + non_linearity: Activation function name (e.g., "silu"). + use_conv_shortcut: Whether to use convolutional shortcut. + conv_shortcut_bias: Whether to use bias in shortcut convolution. + device: Device reference for module placement. + dtype: Data type for module parameters. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.use_conv_shortcut = use_conv_shortcut + + self.norm1 = GroupNorm( + num_groups=groups, + num_channels=in_channels, + eps=eps, + affine=True, + device=device, + dtype=dtype, + ) + + self.conv1 = nn.Conv2d( + kernel_size=3, + in_channels=in_channels, + out_channels=out_channels, + dtype=dtype, + stride=1, + padding=1, + dilation=1, + num_groups=1, + has_bias=True, + device=device, + permute=True, + ) + + self.norm2 = GroupNorm( + num_groups=groups_out, + num_channels=out_channels, + eps=eps, + affine=True, + device=device, + dtype=dtype, + ) + + self.conv2 = nn.Conv2d( + kernel_size=3, + in_channels=out_channels, + out_channels=out_channels, + dtype=dtype, + stride=1, + padding=1, + dilation=1, + num_groups=1, + has_bias=True, + device=device, + permute=True, + ) + + self.conv_shortcut = None + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d( + kernel_size=1, + in_channels=in_channels, + out_channels=out_channels, + dtype=dtype, + stride=1, + padding=0, + dilation=1, + num_groups=1, + has_bias=conv_shortcut_bias, + device=device, + permute=True, + ) + elif in_channels != out_channels: + self.conv_shortcut = nn.Conv2d( + kernel_size=1, + in_channels=in_channels, + out_channels=out_channels, + dtype=dtype, + stride=1, + padding=0, + dilation=1, + num_groups=1, + has_bias=conv_shortcut_bias, + device=device, + permute=True, + ) + + def __call__( + self, x: TensorValue, temb: TensorValue | None = None + ) -> TensorValue: + """Apply ResnetBlock2D forward pass. + + Args: + x: Input tensor of shape [N, C, H, W]. + temb: Optional time embedding tensor (currently unused). + + Returns: + Output tensor of shape [N, C_out, H, W] with residual connection. + """ + shortcut = ( + self.conv_shortcut(x) if self.conv_shortcut is not None else x + ) + + h = ops.silu(self.norm1(x)) + h = self.conv1(h) + + h = ops.silu(self.norm2(h)) + h = self.conv2(h) + + return h + shortcut + + +class UpDecoderBlock2D(nn.Module): + """Upsampling decoder block for 2D VAE. + + This module consists of multiple ResNet blocks followed by an optional + upsampling layer. It progressively increases spatial resolution while + processing features through residual connections. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: int | None = None, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize UpDecoderBlock2D module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + resolution_idx: Optional resolution index for tracking. + dropout: Dropout rate (currently unused). + num_layers: Number of ResNet blocks in this decoder block. + resnet_eps: Epsilon value for ResNet GroupNorm layers. + resnet_time_scale_shift: Time embedding scale/shift mode. + resnet_act_fn: Activation function for ResNet blocks. + resnet_groups: Number of groups for ResNet GroupNorm. + resnet_pre_norm: Whether to apply normalization before ResNet. + output_scale_factor: Scaling factor for output (currently unused). + add_upsample: Whether to add upsampling layer after ResNet blocks. + temb_channels: Number of time embedding channels (None if not used). + device: Device reference for module placement. + dtype: Data type for module parameters. + """ + super().__init__() + resnets_list = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnet = ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=resnet_groups, + groups_out=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + use_conv_shortcut=False, + conv_shortcut_bias=True, + device=device, + dtype=dtype, + ) + resnets_list.append(resnet) + self.resnets = LayerList(resnets_list) + + if add_upsample: + upsampler = Upsample2D( + channels=out_channels, + use_conv=True, + out_channels=out_channels, + name="conv", + kernel_size=3, + padding=1, + bias=True, + interpolate=True, + device=device, + dtype=dtype, + ) + self.upsamplers = LayerList([upsampler]) + else: + self.upsamplers = None + + def __call__( + self, hidden_states: TensorValue, temb: TensorValue | None = None + ) -> TensorValue: + """Apply UpDecoderBlock2D forward pass. + + Args: + hidden_states: Input tensor of shape [N, C_in, H, W]. + temb: Optional time embedding tensor. + + Returns: + Output tensor of shape [N, C_out, H*2, W*2] (if upsampling) or + [N, C_out, H, W] (if no upsampling). + """ + # Process through all resnet blocks + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + + # Apply upsampling if configured (compile-time decision) + if self.upsamplers is not None: + hidden_states = self.upsamplers[0](hidden_states) + + return hidden_states + + +class VAEAttention(nn.Module): + """Spatial attention module for VAE models. + + This module performs self-attention on 2D spatial features by: + 1. Converting [N, C, H, W] to [N, H*W, C] sequence format + 2. Applying scaled dot-product attention (optimized for small sequences) + 3. Converting back to [N, C, H, W] format + + Note: Manual attention is used instead of flash_attention_gpu because + VAE attention typically has small sequence lengths (H*W) where flash + attention overhead outweighs benefits. + """ + + def __init__( + self, + query_dim: int, + heads: int, + dim_head: int, + num_groups: int = 32, + eps: float = 1e-6, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize VAE attention module. + + Args: + query_dim: Dimension of query (number of channels). + heads: Number of attention heads. + dim_head: Dimension of each attention head. + num_groups: Number of groups for GroupNorm. + eps: Epsilon value for GroupNorm. + device: Device reference. + dtype: Data type. + """ + super().__init__() + self.query_dim = query_dim + self.heads = heads + self.dim_head = dim_head + self.inner_dim = heads * dim_head + + self.group_norm = GroupNorm( + num_groups=num_groups, + num_channels=query_dim, + eps=eps, + affine=True, + device=device, + dtype=dtype, + ) + + self.to_q = nn.Linear( + query_dim, self.inner_dim, has_bias=True, device=device, dtype=dtype + ) + self.to_k = nn.Linear( + query_dim, self.inner_dim, has_bias=True, device=device, dtype=dtype + ) + self.to_v = nn.Linear( + query_dim, self.inner_dim, has_bias=True, device=device, dtype=dtype + ) + self.to_out = LayerList( + [ + nn.Linear( + self.inner_dim, + query_dim, + has_bias=True, + device=device, + dtype=dtype, + ) + ] + ) + + self.scale = 1.0 / math.sqrt(dim_head) + + def __call__(self, x: TensorValue) -> TensorValue: + """Apply spatial attention to 2D image tensor. + + Args: + x: Input tensor of shape [N, C, H, W]. + + Returns: + Output tensor of shape [N, C, H, W] with residual connection. + """ + residual = x + + x = self.group_norm(x) + + n, c, h, w = x.shape + seq_len = h * w + + x = ops.reshape(x, (n, c, seq_len)) + x = ops.permute(x, (0, 2, 1)) + + q = self.to_q(x) + k = self.to_k(x) + v = self.to_v(x) + + q = ops.reshape(q, (n, seq_len, self.heads, self.dim_head)) + q = ops.permute(q, (0, 2, 1, 3)) + k = ops.reshape(k, (n, seq_len, self.heads, self.dim_head)) + k = ops.permute(k, (0, 2, 1, 3)) + v = ops.reshape(v, (n, seq_len, self.heads, self.dim_head)) + v = ops.permute(v, (0, 2, 1, 3)) + + attn = q @ ops.permute(k, (0, 1, 3, 2)) * self.scale + attn = ops.softmax(attn, axis=-1) + out = attn @ v + + out = ops.permute(out, (0, 2, 1, 3)) + out = ops.reshape(out, (n, seq_len, self.inner_dim)) + + out = self.to_out[0](out) + + out = ops.permute(out, (0, 2, 1)) + out = ops.reshape(out, (n, c, h, w)) + + return residual + out + + +class MidBlock2D(nn.Module): + """Internal MAX module for MidBlock2D graph generation.""" + + def __init__( + self, + in_channels: int, + temb_channels: int | None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize MidBlock2D module.""" + super().__init__() + resnets_list = [] + attentions_list = [] + + resnet = ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + groups=resnet_groups, + groups_out=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + use_conv_shortcut=False, + conv_shortcut_bias=True, + device=device, + dtype=dtype, + ) + resnets_list.append(resnet) + + for _i in range(num_layers): + if add_attention: + attn = VAEAttention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + num_groups=resnet_groups, + eps=resnet_eps, + device=device, + dtype=dtype, + ) + attentions_list.append(attn) + else: + attentions_list.append(None) + + resnet = ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + groups=resnet_groups, + groups_out=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + use_conv_shortcut=False, + conv_shortcut_bias=True, + device=device, + dtype=dtype, + ) + resnets_list.append(resnet) + + self.resnets = LayerList(resnets_list) + self.attentions = ( + LayerList(attentions_list) if attentions_list else None + ) + + def __call__( + self, hidden_states: TensorValue, temb: TensorValue | None = None + ) -> TensorValue: + """Apply MidBlock2D forward pass. + + Args: + hidden_states: Input tensor of shape [N, C, H, W]. + temb: Optional time embedding tensor. + + Returns: + Output tensor of shape [N, C, H, W] with same spatial dimensions. + """ + hidden_states = self.resnets[0](hidden_states, temb) + + for i in range(len(self.resnets) - 1): + if self.attentions is not None and self.attentions[i] is not None: + hidden_states = self.attentions[i](hidden_states) + hidden_states = self.resnets[i + 1](hidden_states, temb) + + return hidden_states + + +@dataclass +class DecoderOutput: + r"""Output of decoding method. + + Args: + sample (`TensorValue` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: TensorValue + commit_loss: TensorValue | None = None + + +class Decoder(nn.Module): + """VAE decoder for generating images from latent representations. + + This decoder progressively upsamples latent features through multiple + decoder blocks, applying ResNet layers and attention mechanisms to + reconstruct high-resolution images from compressed latent codes. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", + mid_block_add_attention: bool = True, + use_post_quant_conv: bool = True, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize Decoder module. + + Args: + in_channels: Number of input channels (latent channels). + out_channels: Number of output channels (image channels). + up_block_types: Tuple of upsampling block types. + block_out_channels: Tuple of channel counts for each decoder block. + layers_per_block: Number of ResNet layers per decoder block. + norm_num_groups: Number of groups for GroupNorm layers. + act_fn: Activation function name (e.g., "silu"). + norm_type: Normalization type ("group" or "spatial"). + mid_block_add_attention: Whether to add attention in middle block. + use_post_quant_conv: Whether to use post-quantization convolution. + device: Device reference for module placement. + dtype: Data type for module parameters. + """ + super().__init__() + self.layers_per_block = layers_per_block + self.session = None + self.in_channels = in_channels + self.device = device + self.dtype = dtype + + self.post_quant_conv = None + if use_post_quant_conv: + self.post_quant_conv = nn.Conv2d( + kernel_size=1, + in_channels=in_channels, + out_channels=in_channels, + dtype=dtype, + stride=1, + padding=0, + dilation=1, + num_groups=1, + has_bias=True, + device=device, + permute=True, + ) + + self.conv_in = nn.Conv2d( + kernel_size=3, + in_channels=in_channels, + out_channels=block_out_channels[-1], + dtype=dtype, + stride=1, + padding=1, + dilation=1, + num_groups=1, + has_bias=True, + device=device, + permute=True, + ) + + temb_channels = in_channels if norm_type == "spatial" else None + self.mid_block = MidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=temb_channels, + dropout=0.0, + num_layers=1, + resnet_eps=1e-6, + resnet_time_scale_shift=( + "default" if norm_type == "group" else norm_type + ), + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + add_attention=mid_block_add_attention, + attention_head_dim=block_out_channels[-1], + output_scale_factor=1.0, + device=device, + dtype=dtype, + ) + + up_blocks_list = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "UpDecoderBlock2D": + up_block = UpDecoderBlock2D( + in_channels=prev_output_channel, + out_channels=output_channel, + resolution_idx=i, + dropout=0.0, + num_layers=self.layers_per_block + 1, + resnet_eps=1e-6, + resnet_time_scale_shift=norm_type, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_pre_norm=True, + output_scale_factor=1.0, + add_upsample=not is_final_block, + temb_channels=temb_channels, + device=device, + dtype=dtype, + ) + up_blocks_list.append(up_block) + else: + raise ValueError(f"Unsupported up_block_type: {up_block_type}") + + prev_output_channel = output_channel + + self.up_blocks = LayerList(up_blocks_list) + + if norm_type == "spatial": + raise NotImplementedError("SpatialNorm not implemented in MAX VAE") + else: + self.conv_norm_out = GroupNorm( + num_groups=norm_num_groups, + num_channels=block_out_channels[0], + eps=1e-6, + affine=True, + device=device, + dtype=dtype, + ) + + self.conv_out = nn.Conv2d( + kernel_size=3, + in_channels=block_out_channels[0], + out_channels=out_channels, + dtype=dtype, + stride=1, + padding=1, + dilation=1, + num_groups=1, + has_bias=True, + device=device, + permute=True, + ) + + def __call__( + self, z: TensorValue, temb: TensorValue | None = None + ) -> TensorValue: + """Apply Decoder forward pass. + + Args: + z: Input latent tensor of shape [N, C_latent, H_latent, W_latent]. + temb: Optional time embedding tensor. + + Returns: + Decoded image tensor of shape [N, C_out, H, W] where H and W are + upsampled from H_latent and W_latent. + """ + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + sample = self.conv_in(z) + sample = self.mid_block(sample, temb) + + for up_block in self.up_blocks: + sample = up_block(sample, temb) + + sample = self.conv_norm_out(sample) + sample = ops.silu(sample) + sample = self.conv_out(sample) + + return sample + + def input_types(self) -> tuple[TensorType, ...]: + """Define input tensor types for the decoder model. + + Returns: + Tuple of TensorType specifications for decoder input. + """ + latent_type = TensorType( + self.dtype, + shape=[ + "batch_size", + self.in_channels, + "latent_height", + "latent_width", + ], + device=self.device, + ) + + return (latent_type,) + + +class AutoencoderKL(nn.Module): + r"""A VAE model with KL loss for encoding images into latents and decoding latent representations into images.""" + + def __init__( + self, + config: AutoencoderKLConfig, + ): + """Initialize VAE AutoencoderKL model. + + Args: + config: Autoencoder configuration containing channel sizes, block + structure, normalization settings, and device/dtype information. + """ + super().__init__() + self.decoder = Decoder( + in_channels=config.latent_channels, + out_channels=config.out_channels, + up_block_types=config.up_block_types, + block_out_channels=config.block_out_channels, + layers_per_block=config.layers_per_block, + norm_num_groups=config.norm_num_groups, + act_fn=config.act_fn, + norm_type="group", + mid_block_add_attention=config.mid_block_add_attention, + use_post_quant_conv=config.use_post_quant_conv, + device=config.device, + dtype=config.dtype, + ) + + def __call__(self, *args, **kwargs): + pass diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/layers/__init__.py b/max/python/max/pipelines/architectures/autoencoder_kl/layers/__init__.py new file mode 100644 index 00000000000..ba1d3f82854 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/layers/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .upsampling import Upsample2D diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/layers/upsampling.py b/max/python/max/pipelines/architectures/autoencoder_kl/layers/upsampling.py new file mode 100644 index 00000000000..35a526b32b0 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/layers/upsampling.py @@ -0,0 +1,168 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Upsampling utilities for MAX framework.""" + +import max.nn as nn +from max.dtype import DType +from max.experimental import tensor +from max.graph import DeviceRef, TensorValue, ops + + +class Interpolate2DNearest(nn.Module): + """2D nearest-neighbor upsampling module. + + This is a workaround implementation because MAX framework does not have + a native `interpolate` operation. The workaround uses reshape and broadcast + operations to achieve nearest-neighbor upsampling by a factor of 2. + + Note: + This workaround can be removed once MAX framework adds native interpolate support. + """ + + def __init__( + self, + scale_factor: int = 2, + device: DeviceRef = None, + dtype: DType = None, + ): + """Initialize 2D nearest-neighbor interpolation module. + + Args: + scale_factor: Upsampling factor (currently only 2 is supported). + device: Device reference for creating intermediate tensors. + dtype: Data type for intermediate tensors. + """ + super().__init__() + if scale_factor != 2: + raise NotImplementedError( + f"Only scale_factor=2 is currently supported, got {scale_factor}" + ) + + self.scale_factor = scale_factor + self.device = device + self.dtype = dtype + + def __call__(self, x: TensorValue) -> TensorValue: + """Upsample a 2D tensor using nearest-neighbor interpolation. + + Args: + x: Input tensor of shape [N, C, H, W]. + + Returns: + Upsampled tensor of shape [N, C, H*scale_factor, W*scale_factor]. + """ + n, c, h, w = x.shape + target_shape = [n, c, h * self.scale_factor, w * self.scale_factor] + + x_reshaped = ops.reshape(x, [n, c, h, 1, w, 1]) + + ones = tensor.Tensor.ones( + shape=(1, 1, 1, self.scale_factor, 1, self.scale_factor), + dtype=self.dtype, + device=self.device, + ) + x_expanded = x_reshaped * ones + + x = ops.reshape(x_expanded, target_shape) + + return x + + +class Upsample2D(nn.Module): + """2D upsampling module with optional convolution. + + This module performs 2D upsampling using nearest-neighbor interpolation + (via Interpolate2DNearest workaround) followed by an optional convolution layer. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: int | None = None, + name: str = "conv", + kernel_size: int | None = None, + padding: int = 1, + bias: bool = True, + interpolate: bool = True, + device: DeviceRef | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize 2D upsampling module. + + Args: + channels: Number of input channels. + use_conv: Whether to apply a convolution after upsampling. + use_conv_transpose: Whether to use transposed convolution (not supported yet). + out_channels: Number of output channels. If None, uses channels. + name: Name for the convolution layer (unused, kept for compatibility). + kernel_size: Kernel size for the convolution. + padding: Padding for the convolution. + bias: Whether to use bias in the convolution. + interpolate: Whether to perform interpolation upsampling. + device: Device reference. + dtype: Data type. + """ + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.interpolate = interpolate + self.device = device + self.dtype = dtype + + self.interpolate_module = None + if interpolate: + self.interpolate_module = Interpolate2DNearest( + scale_factor=2, device=device, dtype=dtype + ) + + self.conv = None + if use_conv_transpose: + raise NotImplementedError( + "Upsample2D does not support use_conv_transpose=True yet." + ) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + self.conv = nn.Conv2d( + kernel_size=kernel_size, + in_channels=self.channels, + out_channels=self.out_channels, + dtype=dtype, + stride=1, + padding=padding, + has_bias=bias, + device=device, + permute=True, + ) + + def __call__(self, x: TensorValue) -> TensorValue: + """Apply 2D upsampling with optional convolution. + + Args: + x: Input tensor of shape [N, C, H, W]. + + Returns: + Upsampled tensor, optionally convolved. + """ + if self.interpolate_module is not None: + x = self.interpolate_module(x) + + if self.use_conv: + x = self.conv(x) + + return x diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/model.py b/max/python/max/pipelines/architectures/autoencoder_kl/model.py new file mode 100644 index 00000000000..04ba6006d3d --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/model.py @@ -0,0 +1,74 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.driver import CPU, Accelerator, Device, Tensor +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.max_model import MaxModel + +from .autoencoder_kl import AutoencoderKL +from .model_config import AutoencoderKLConfig + + +class AutoencoderKLModel(MaxModel): + config_name = AutoencoderKLConfig.config_name + + def __init__( + self, + config: dict, + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + ) -> None: + super().__init__(config, encoding, devices, weights) + self.config = AutoencoderKLConfig.generate( + config, + encoding, + devices, + ) + self.load_model() + + def load_model(self) -> None: + autoencoder_kl = AutoencoderKL(self.config) + + if self.config.device.is_cpu(): + session = InferenceSession([CPU()]) + else: + session = InferenceSession([Accelerator()]) + + self.load_decoder(session, autoencoder_kl) + + def load_decoder( + self, session: InferenceSession, autoencoder_kl: AutoencoderKL + ) -> Model: + state_dict = { + key: value.data() + for key, value in self.weights.items() + if not key.startswith("encoder.") + } + autoencoder_kl.load_state_dict(state_dict) + with Graph( + "autoencoder_kl_decoder", + input_types=autoencoder_kl.decoder.input_types(), + ) as graph: + outputs = autoencoder_kl.decoder(*graph.inputs) + graph.output(outputs) + compiled_graph = graph + self.decode_session = session.load( + compiled_graph, weights_registry=autoencoder_kl.state_dict() + ) + + def decode(self, *args, **kwargs) -> list[Tensor]: + return self.decode_session.execute(*args, **kwargs) diff --git a/max/python/max/pipelines/architectures/autoencoder_kl/model_config.py b/max/python/max/pipelines/architectures/autoencoder_kl/model_config.py new file mode 100644 index 00000000000..f3b28565fc4 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoder_kl/model_config.py @@ -0,0 +1,66 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from typing import ClassVar + +from max.driver import Device +from max.dtype import DType +from max.graph import DeviceRef +from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding +from pydantic import Field + + +class AutoencoderKLConfigBase(MAXModelConfigBase): + in_channels: int = 3 + out_channels: int = 3 + down_block_types: list[str] = Field(default_factory=list, max_length=4) + up_block_types: list[str] = Field(default_factory=list, max_length=4) + block_out_channels: list[int] = Field(default_factory=list, max_length=4) + layers_per_block: int = 1 + act_fn: str = "silu" + latent_channels: int = 4 + norm_num_groups: int = 32 + sample_size: int = 32 + scaling_factor: float = 0.18215 + shift_factor: float | None = None + latents_mean: tuple[float] | None = None + latents_std: tuple[float] | None = None + force_upcast: bool = True + use_quant_conv: bool = True + use_post_quant_conv: bool = True + mid_block_add_attention: bool = True + device: DeviceRef = Field(default_factory=DeviceRef.CPU) + dtype: DType = DType.bfloat16 + + +class AutoencoderKLConfig(AutoencoderKLConfigBase): + config_name: ClassVar[str] = "config.json" + + @staticmethod + def generate( + config_dict: dict, + encoding: SupportedEncoding, + devices: list[Device], + ) -> AutoencoderKLConfigBase: + init_dict = { + key: value + for key, value in config_dict.items() + if key in AutoencoderKLConfigBase.__annotations__ + } + init_dict.update( + { + "dtype": encoding.dtype, + "device": DeviceRef.from_device(devices[0]), + } + ) + return AutoencoderKLConfigBase(**init_dict) diff --git a/max/python/max/pipelines/architectures/clip/__init__.py b/max/python/max/pipelines/architectures/clip/__init__.py new file mode 100644 index 00000000000..32cb1def84e --- /dev/null +++ b/max/python/max/pipelines/architectures/clip/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .model import ClipModel diff --git a/max/python/max/pipelines/architectures/clip/clip.py b/max/python/max/pipelines/architectures/clip/clip.py new file mode 100644 index 00000000000..95642bfbd93 --- /dev/null +++ b/max/python/max/pipelines/architectures/clip/clip.py @@ -0,0 +1,517 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from functools import partial + +import max.nn as nn +from max.dtype import DType +from max.graph import TensorType, TensorValue, ops +from max.nn import LayerNorm, Module + +from .model_config import ClipConfig + + +class CLIPTextEmbeddings(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP text embeddings. + + Args: + config: CLIP configuration for embedding dimensions and device/dtype. + """ + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.position_embedding = nn.Embedding( + config.max_position_embeddings, + self.embed_dim, + device=config.device, + dtype=config.dtype, + ) + self.token_embedding = nn.Embedding( + config.vocab_size, + self.embed_dim, + device=config.device, + dtype=config.dtype, + ) + + def __call__( + self, + input_ids: TensorValue | None = None, + position_ids: TensorValue | None = None, + inputs_embeds: TensorValue | None = None, + ) -> TensorValue: + """Apply embeddings to input tokens. + + Args: + input_ids: Input token IDs. + position_ids: Position IDs. + inputs_embeds: Pre-computed input embeddings. + + Returns: + Combined embeddings. + """ + if input_ids is None and inputs_embeds is None: + raise ValueError( + "You have to specify either input_ids or inputs_embeds" + ) + + if input_ids is not None: + seq_length = input_ids.shape[-1] + else: + seq_length = inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = ops.range( + 0, + seq_length, + step=1, + dtype=DType.int32, + device=self.config.device, + ) + position_ids = ops.unsqueeze(position_ids, 0) + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class CLIPAttention(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP attention module. + + Args: + config: CLIP configuration for attention dimensions and device/dtype. + """ + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear( + self.embed_dim, + self.embed_dim, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + self.v_proj = nn.Linear( + self.embed_dim, + self.embed_dim, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + self.q_proj = nn.Linear( + self.embed_dim, + self.embed_dim, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + self.out_proj = nn.Linear( + self.embed_dim, + self.embed_dim, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + + def __call__( + self, + hidden_states: TensorValue, + attention_mask: TensorValue | None = None, + causal_attention_mask: TensorValue | None = None, + ) -> TensorValue: + """Apply multi-head attention. + + Args: + hidden_states: Input hidden states. + attention_mask: Attention mask. + causal_attention_mask: Causal attention mask. + + Returns: + Attention output. + """ + batch_size, seq_length, embed_dim = hidden_states.shape + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = ops.reshape( + query, (batch_size, seq_length, self.num_heads, self.head_dim) + ) + query = ops.transpose(query, 1, 2) + + key = ops.reshape( + key, (batch_size, seq_length, self.num_heads, self.head_dim) + ) + key = ops.transpose(key, 1, 2) + + value = ops.reshape( + value, (batch_size, seq_length, self.num_heads, self.head_dim) + ) + value = ops.transpose(value, 1, 2) + + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask + + attn_weights = ( + ops.matmul(query, ops.transpose(key, -1, -2)) * self.scale + ) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = ops.softmax( + ops.cast(attn_weights, DType.float32), axis=-1 + ) + attn_weights = ops.cast(attn_weights, hidden_states.dtype) + + attn_output = ops.matmul(attn_weights, value) + attn_output = ops.transpose(attn_output, 1, 2) + attn_output = ops.reshape( + attn_output, (batch_size, seq_length, embed_dim) + ) + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class CLIPMLP(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP MLP. + + Args: + config: CLIP configuration for MLP dimensions and device/dtype. + """ + super().__init__() + self.config = config + self.fc1 = nn.Linear( + config.hidden_size, + config.intermediate_size, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + self.fc2 = nn.Linear( + config.intermediate_size, + config.hidden_size, + has_bias=True, + device=config.device, + dtype=config.dtype, + ) + self.act_fn = partial(ops.gelu, approximate="quick") + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Apply MLP block. + + Args: + hidden_states: Input hidden states. + + Returns: + Output hidden states. + """ + hidden_states = self.fc1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class CLIPEncoderLayer(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP encoder layer. + + Args: + config: CLIP configuration for encoder layer structure. + """ + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = LayerNorm( + self.embed_dim, + eps=config.layer_norm_eps, + devices=[config.device], + dtype=config.dtype, + keep_dtype=True, + ) + self.mlp = CLIPMLP(config) + self.layer_norm2 = LayerNorm( + self.embed_dim, + eps=config.layer_norm_eps, + devices=[config.device], + dtype=config.dtype, + keep_dtype=True, + ) + + def __call__( + self, + hidden_states: TensorValue, + attention_mask: TensorValue, + causal_attention_mask: TensorValue, + ) -> TensorValue: + """Apply encoder layer. + + Args: + hidden_states: Input hidden states. + attention_mask: Attention mask. + causal_attention_mask: Causal attention mask. + + Returns: + Output hidden states. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPEncoder(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP encoder. + + Args: + config: CLIP configuration for encoder depth and dimensions. + """ + super().__init__() + self.layers = nn.LayerList( + [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + def __call__( + self, + inputs_embeds: TensorValue, + attention_mask: TensorValue | None = None, + causal_attention_mask: TensorValue | None = None, + ) -> TensorValue: + """Apply encoder (stack of layers). + + Args: + inputs_embeds: Input embeddings. + attention_mask: Attention mask. + causal_attention_mask: Causal attention mask. + + Returns: + Encoded hidden states. + """ + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + return hidden_states + + +class CLIPTextTransformer(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP text transformer. + + Args: + config: CLIP configuration for embeddings, encoder, and device/dtype. + """ + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = LayerNorm( + self.embed_dim, + eps=config.layer_norm_eps, + devices=[config.device], + dtype=config.dtype, + keep_dtype=True, + ) + self.eos_token_id = config.eos_token_id + + def _create_causal_mask(self, input_shape: tuple[int, int]) -> TensorValue: + """Create causal mask for the transformer. + + Args: + input_shape: Shape of the input tensor. + + Returns: + Causal mask tensor. + """ + _, seq_length = input_shape + + rows = ops.range( + 0, seq_length, step=1, dtype=DType.int32, device=self.config.device + ) + rows = ops.unsqueeze(rows, 1) + cols = ops.range( + 0, seq_length, step=1, dtype=DType.int32, device=self.config.device + ) + cols = ops.unsqueeze(cols, 0) + mask = ops.greater(cols, rows) + mask_float = mask.cast(self.config.dtype) + + min_val = DType.finfo(self.config.dtype).min + + causal_mask = mask_float * min_val + causal_mask = ops.unsqueeze(causal_mask, 0) + causal_mask = ops.unsqueeze(causal_mask, 1) + return causal_mask + + def __call__( + self, + input_ids: TensorValue | None = None, + attention_mask: TensorValue | None = None, + position_ids: TensorValue | None = None, + ) -> TensorValue: + """Apply text transformer. + + Args: + input_ids: Input token IDs. + attention_mask: Attention mask. + position_ids: Position IDs. + + Returns: + Tuple of (last_hidden_state, pooled_output). + """ + if input_ids is None: + raise ValueError("You have to specify input_ids") + + hidden_states = self.embeddings( + input_ids=input_ids, position_ids=position_ids + ) + + input_shape = input_ids.shape + causal_attention_mask = self._create_causal_mask(input_shape) + + if attention_mask is not None: + inverted_mask = ( + 1.0 - attention_mask.cast(hidden_states.dtype) + ) * DType.finfo(hidden_states.dtype).min + attention_mask = ops.unsqueeze(inverted_mask, 1) + attention_mask = ops.unsqueeze(attention_mask, 1) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + + last_hidden_state = self.final_layer_norm(encoder_outputs) + + if self.eos_token_id == 2: + eos_token_indices = ops.argmax(input_ids, axis=-1).cast(DType.int32) + else: + eos_token_indices = ops.argmax( + ops.equal(input_ids, self.eos_token_id).cast(DType.int32), + axis=-1, + ).cast(DType.int32) + + pooled_output = ops.gather_nd( + last_hidden_state, eos_token_indices, batch_dims=1 + ) + + return last_hidden_state, pooled_output + + +class CLIPTextModel(Module): + def __init__( + self, + config: ClipConfig, + ): + """Initialize CLIP text model with MAX. + + Args: + config: CLIP configuration for vocabulary size, dimensions, and + device/dtype settings. + """ + super().__init__() + self.text_model = CLIPTextTransformer(config) + self.device = config.device + + def input_types(self) -> tuple[TensorType, ...]: + """Define input tensor types for the model. + + Returns: + Tuple of TensorType specifications for model inputs. + """ + return ( + TensorType( + DType.int64, + shape=["batch_size", "sequence_length"], + device=self.device, + ), + ) + + def __call__( + self, + input_ids: TensorValue | None = None, + attention_mask: TensorValue | None = None, + position_ids: TensorValue | None = None, + ) -> tuple[TensorValue, TensorValue]: + """Apply CLIP text model forward pass. + + Args: + input_ids: Input token IDs. + attention_mask: Attention mask. + position_ids: Position IDs. + + Returns: + Tuple of (last_hidden_state, pooled_output). + """ + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) diff --git a/max/python/max/pipelines/architectures/clip/model.py b/max/python/max/pipelines/architectures/clip/model.py new file mode 100644 index 00000000000..df9f467591b --- /dev/null +++ b/max/python/max/pipelines/architectures/clip/model.py @@ -0,0 +1,70 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.driver import CPU, Accelerator, Device +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.max_model import MaxModel + +from .clip import CLIPTextModel +from .model_config import ClipConfig + + +class ClipModel(MaxModel): + config_name = ClipConfig.config_name + + def __init__( + self, + config: dict, + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + ) -> None: + super().__init__( + config, + encoding, + devices, + weights, + ) + self.config = ClipConfig.generate( + config, + encoding, + devices, + ) + self.load_model() + + def load_model(self) -> Model: + clip = CLIPTextModel(self.config) + + if self.config.device.is_cpu(): + session = InferenceSession([CPU()]) + else: + session = InferenceSession([Accelerator()]) + state_dict = {key: value.data() for key, value in self.weights.items()} + clip.load_state_dict(state_dict) + with Graph("clip_text_model", input_types=clip.input_types()) as graph: + outputs = clip( + *graph.inputs, + attention_mask=None, + position_ids=None, + ) + graph.output(*outputs) + compiled_graph = graph + self.session = session.load( + compiled_graph, weights_registry=clip.state_dict() + ) + + def __call__(self, *args, **kwargs): + return self.session.execute(*args, **kwargs) diff --git a/max/python/max/pipelines/architectures/clip/model_config.py b/max/python/max/pipelines/architectures/clip/model_config.py new file mode 100644 index 00000000000..ccf9ba35083 --- /dev/null +++ b/max/python/max/pipelines/architectures/clip/model_config.py @@ -0,0 +1,63 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from typing import ClassVar + +from max.driver import Device +from max.dtype import DType +from max.graph import DeviceRef +from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding +from pydantic import Field + + +class ClipConfigBase(MAXModelConfigBase): + vocab_size: int = 49408 + hidden_size: int = 512 + intermediate_size: int = 2048 + projection_dim: int = 512 + num_hidden_layers: int = 12 + num_attention_heads: int = 8 + max_position_embeddings: int = 77 + hidden_act: str = "quick_gelu" + layer_norm_eps: float = 1e-5 + attention_dropout: float = 0.0 + initializer_range: float = 0.02 + initializer_factor: float = 1.0 + pad_token_id: int = 1 + bos_token_id: int = 49406 + eos_token_id: int = 49407 + dtype: DType = DType.bfloat16 + device: DeviceRef = Field(default_factory=DeviceRef.GPU) + + +class ClipConfig(ClipConfigBase): + config_name: ClassVar[str] = "config.json" + + @staticmethod + def generate( + config_dict: dict, + encoding: SupportedEncoding, + devices: list[Device], + ) -> ClipConfigBase: + init_dict = { + key: value + for key, value in config_dict.items() + if key in ClipConfigBase.__annotations__ + } + init_dict.update( + { + "dtype": encoding.dtype, + "device": DeviceRef.from_device(devices[0]), + } + ) + return ClipConfigBase(**init_dict) diff --git a/max/python/max/pipelines/architectures/flux1/__init__.py b/max/python/max/pipelines/architectures/flux1/__init__.py new file mode 100644 index 00000000000..2325700031e --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .arch import flux1_arch diff --git a/max/python/max/pipelines/architectures/flux1/arch.py b/max/python/max/pipelines/architectures/flux1/arch.py new file mode 100644 index 00000000000..aea17022398 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/arch.py @@ -0,0 +1,38 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.graph.weights import WeightsFormat +from max.interfaces import BaseContext, PipelineTask +from max.pipelines.lib import ( + SupportedArchitecture, + SupportedEncoding, + TextTokenizer, +) + +from .pipeline_flux import FluxPipeline + +# TODO(minkyu): revisit default_encoding, supported_encodings, tokenizer. +flux1_arch = SupportedArchitecture( + name="FluxPipeline", + task=PipelineTask.IMAGE_GENERATION, + default_encoding=SupportedEncoding.bfloat16, + supported_encodings={SupportedEncoding.bfloat16: []}, + example_repo_ids=[ + "black-forest-labs/FLUX.1-dev", + "black-forest-labs/FLUX.1-schnell", + ], + pipeline_model=FluxPipeline, + tokenizer=TextTokenizer, + context_type=BaseContext, + default_weights_format=WeightsFormat.safetensors, +) diff --git a/max/python/max/pipelines/architectures/flux1/flux1.py b/max/python/max/pipelines/architectures/flux1/flux1.py new file mode 100644 index 00000000000..ed553ef159a --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/flux1.py @@ -0,0 +1,544 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import logging +import os +from collections.abc import Generator +from contextlib import contextmanager +from os import PathLike +from typing import Any + +import max.nn as nn +from max.driver import DLPackArray +from max.dtype import DType +from max.graph import DeviceRef, TensorType, TensorValue, ops +from max.graph.weights import SafetensorWeights +from max.nn import LayerNorm, Module + +from .layers.embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, +) +from .layers.flux_attention import FeedForward, FluxAttention, FluxPosEmbed +from .layers.normalizations import ( + AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle, +) +from .model_config import FluxConfig + +logger = logging.getLogger(__name__) + + +def get_weight_registry_from_diffusers( + safe_tensor_folder: PathLike, +) -> dict[str, DLPackArray]: + weight_files = [ + os.path.join(safe_tensor_folder, f) + for f in os.listdir(safe_tensor_folder) + if f.endswith(".safetensors") + ] + weights = SafetensorWeights(weight_files) + return {name: weight.data().data for name, weight in weights.items()} + + +class FluxSingleTransformerBlock(Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize Flux single transformer block. + + Args: + dim: Dimension of the input/output. + num_attention_heads: Number of attention heads. + attention_head_dim: Dimension of each attention head. + mlp_ratio: Ratio for MLP hidden dimension. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim, device=device, dtype=dtype) + self.proj_mlp = nn.Linear( + dim, self.mlp_hidden_dim, has_bias=True, device=device, dtype=dtype + ) + self.act_mlp = ops.gelu + self.proj_out = nn.Linear( + dim + self.mlp_hidden_dim, + dim, + has_bias=True, + device=device, + dtype=dtype, + ) + self.attn = FluxAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=1e-6, + pre_only=True, + device=device, + dtype=dtype, + ) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + temb: TensorValue, + image_rotary_emb: tuple[TensorValue, TensorValue] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[TensorValue, TensorValue]: + """Apply single transformer block with attention and MLP. + + Args: + hidden_states: Input hidden states. + encoder_hidden_states: Encoder hidden states for cross-attention. + temb: Time embedding. + image_rotary_emb: Optional rotary position embeddings. + joint_attention_kwargs: Optional attention kwargs. + + Returns: + Tuple of (encoder_hidden_states, hidden_states). + """ + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = ops.concat( + [encoder_hidden_states, hidden_states], axis=1 + ) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp( + self.proj_mlp(norm_hidden_states), approximate="tanh" + ) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = ops.concat([attn_output, mlp_hidden_states], axis=2) + gate = ops.unsqueeze(gate, 1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == DType.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = ( + hidden_states[:, :text_seq_len], + hidden_states[:, text_seq_len:], + ) + return encoder_hidden_states, hidden_states + + +class FluxTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize Flux transformer block. + + Args: + dim: Dimension of the input/output. + num_attention_heads: Number of attention heads. + attention_head_dim: Dimension of each attention head. + qk_norm: Type of normalization for query and key ("rms_norm"). + eps: Epsilon for normalization layers. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + + self.norm1 = AdaLayerNormZero(dim, device=device, dtype=dtype) + self.norm1_context = AdaLayerNormZero(dim, device=device, dtype=dtype) + + self.attn = FluxAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + eps=eps, + device=device, + dtype=dtype, + ) + + self.norm2 = LayerNorm( + dim, + eps=1e-6, + devices=[device], + dtype=dtype, + keep_dtype=True, + elementwise_affine=False, + ) + self.ff = FeedForward( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + device=device, + dtype=dtype, + ) + + self.norm2_context = LayerNorm( + dim, + eps=1e-6, + devices=[device], + dtype=dtype, + keep_dtype=True, + elementwise_affine=False, + ) + self.ff_context = FeedForward( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + device=device, + dtype=dtype, + ) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + temb: TensorValue, + image_rotary_emb: tuple[TensorValue, TensorValue] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[TensorValue, TensorValue]: + """Apply transformer block with dual-stream attention and feedforward. + + Args: + hidden_states: Input hidden states. + encoder_hidden_states: Encoder hidden states for cross-attention. + temb: Time embedding. + image_rotary_emb: Optional rotary position embeddings. + joint_attention_kwargs: Optional attention kwargs. + + Returns: + Tuple of (encoder_hidden_states, hidden_states). + """ + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.norm1(hidden_states, emb=temb) + ) + + ( + norm_encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1_context(encoder_hidden_states, emb=temb) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + attn_output, context_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = ops.unsqueeze(gate_msa, 1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + ff_output = self.ff(norm_hidden_states) + ff_output = ops.unsqueeze(gate_mlp, 1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = ops.unsqueeze(c_gate_msa, 1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + + c_shift_mlp[:, None] + ) + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = ( + encoder_hidden_states + + ops.unsqueeze(c_gate_mlp, 1) * context_ff_output + ) + if encoder_hidden_states.dtype == DType.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FluxTransformer2DModel(nn.Module): + def __init__( + self, + config: FluxConfig, + ): + """Initialize Flux Transformer 2D model. + + Args: + config: Flux configuration containing model dimensions, attention + settings, and device/dtype information. + """ + super().__init__() + patch_size = config.patch_size + in_channels = config.in_channels + out_channels = config.out_channels + num_layers = config.num_layers + num_single_layers = config.num_single_layers + attention_head_dim = config.attention_head_dim + num_attention_heads = config.num_attention_heads + joint_attention_dim = config.joint_attention_dim + pooled_projection_dim = config.pooled_projection_dim + guidance_embeds = config.guidance_embeds + axes_dims_rope = config.axes_dims_rope + device = config.device + dtype = config.dtype + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + self.guidance_embeds = guidance_embeds + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings + if guidance_embeds + else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, + pooled_projection_dim=pooled_projection_dim, + device=device, + dtype=dtype, + ) + self.context_embedder = nn.Linear( + joint_attention_dim, + self.inner_dim, + has_bias=True, + device=device, + dtype=dtype, + ) + self.x_embedder = nn.Linear( + in_channels, + self.inner_dim, + has_bias=True, + device=device, + dtype=dtype, + ) + + self.transformer_blocks = nn.Sequential( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + device=device, + dtype=dtype, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.Sequential( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + device=device, + dtype=dtype, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, eps=1e-6, device=device, dtype=dtype + ) + self.proj_out = nn.Linear( + self.inner_dim, + patch_size * patch_size * self.out_channels, + has_bias=True, + device=device, + dtype=dtype, + ) + + self.gradient_checkpointing = False + + self.max_device = device + self.max_dtype = dtype + self.in_channels = in_channels + self.joint_attention_dim = joint_attention_dim + self.pooled_projection_dim = pooled_projection_dim + + self._cache_context_warning_shown = False + + def input_types(self) -> tuple[TensorType, ...]: + """Define input tensor types for the model. + + Returns: + Tuple of TensorType specifications for all model inputs. + """ + hidden_states_type = TensorType( + self.max_dtype, + shape=["batch_size", "image_seq_len", self.in_channels], + device=self.max_device, + ) + encoder_hidden_states_type = TensorType( + self.max_dtype, + shape=["batch_size", "text_seq_len", self.joint_attention_dim], + device=self.max_device, + ) + pooled_projections_type = TensorType( + self.max_dtype, + shape=["batch_size", self.pooled_projection_dim], + device=self.max_device, + ) + timestep_type = TensorType( + DType.float32, shape=["batch_size"], device=self.max_device + ) + img_ids_type = TensorType( + self.max_dtype, shape=["image_seq_len", 3], device=self.max_device + ) + txt_ids_type = TensorType( + self.max_dtype, shape=["text_seq_len", 3], device=self.max_device + ) + guidance_type = TensorType( + self.max_dtype, shape=["batch_size"], device=self.max_device + ) + + return ( + hidden_states_type, + encoder_hidden_states_type, + pooled_projections_type, + timestep_type, + img_ids_type, + txt_ids_type, + guidance_type, + ) + + @contextmanager + def cache_context(self, name: str) -> Generator[None, None, None]: + """Context manager for cache control (not implemented in MAX). + + Args: + name: Name of the cache context. + + Yields: + None. + """ + if not self._cache_context_warning_shown: + logger.warning( + "cache_context is not implemented in MAX FluxTransformer2DModel. " + "Caching optimizations are disabled." + ) + self._cache_context_warning_shown = True + yield + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue = None, + pooled_projections: TensorValue = None, + timestep: TensorValue = None, + img_ids: TensorValue = None, + txt_ids: TensorValue = None, + guidance: TensorValue = None, + joint_attention_kwargs: dict[str, Any] | None = None, + controlnet_block_samples: Any | None = None, + controlnet_single_block_samples: Any | None = None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> tuple[TensorValue]: + """Apply Flux Transformer 2D model forward pass. + + Args: + hidden_states: Input latent hidden states. + encoder_hidden_states: Text encoder hidden states. + pooled_projections: Pooled text embeddings. + timestep: Diffusion timestep. + img_ids: Image position IDs. + txt_ids: Text position IDs. + guidance: Guidance scale values. + joint_attention_kwargs: Additional attention arguments. + controlnet_block_samples: Optional controlnet block samples. + controlnet_single_block_samples: Optional controlnet single block samples. + return_dict: Whether to return as dictionary. + controlnet_blocks_repeat: Whether to repeat controlnet blocks. + + Returns: + Tuple containing output tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + + hidden_states = self.x_embedder(hidden_states) + + timestep = ops.cast(timestep, hidden_states.dtype) + timestep = timestep * 1000.0 + if guidance is not None: + guidance = guidance.cast(hidden_states.dtype) * 1000.0 + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if not self.guidance_embeds + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + ids = ops.concat((txt_ids, img_ids), axis=0) + image_rotary_emb = self.pos_embed(ids) + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + for block in self.single_transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + return (output,) diff --git a/max/python/max/pipelines/architectures/flux1/layers/__init__.py b/max/python/max/pipelines/architectures/flux1/layers/__init__.py new file mode 100644 index 00000000000..75c4f824f20 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/__init__.py @@ -0,0 +1,12 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # diff --git a/max/python/max/pipelines/architectures/flux1/layers/activations.py b/max/python/max/pipelines/architectures/flux1/layers/activations.py new file mode 100644 index 00000000000..c5fc5a80aca --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/activations.py @@ -0,0 +1,56 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import max.nn as nn +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops + + +class GELU(nn.Module): + def __init__( + self, + dim_in: int, + dim_out: int, + approximate: str = "none", + bias: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize GELU activation layer with linear projection. + + Args: + dim_in: Input dimension. + dim_out: Output dimension. + approximate: Approximation type for GELU ("none" or "tanh"). + bias: Whether to include bias in the linear projection. + device: Device to place the layer on. + dtype: Data type for the layer. + """ + super().__init__() + self.proj = nn.Linear( + dim_in, dim_out, has_bias=bias, dtype=dtype, device=device + ) + self.approximate = approximate + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Apply GELU activation to the input. + + Args: + hidden_states: Input tensor. + + Returns: + Output tensor after linear projection and GELU activation. + """ + hidden_states = self.proj(hidden_states) + hidden_states = ops.gelu(hidden_states, approximate=self.approximate) + return hidden_states diff --git a/max/python/max/pipelines/architectures/flux1/layers/embeddings.py b/max/python/max/pipelines/architectures/flux1/layers/embeddings.py new file mode 100644 index 00000000000..cb50aa3f83b --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/embeddings.py @@ -0,0 +1,471 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import math + +from max import nn +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops + + +def apply_rotary_emb( + x: TensorValue, + freqs_cis: tuple[TensorValue, TensorValue], + sequence_dim: int = 2, +) -> TensorValue: + """Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query or key 'x' tensors using the provided frequency + tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped + for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors. + + Args: + x: Query or key tensor to apply rotary embeddings. Shape depends on + caller; the last dimension is split into complex pairs. + freqs_cis: Precomputed cosine/sine frequency tensors for complex + exponentials. Shape ([S, D], [S, D]). + sequence_dim: Dimension representing the sequence (1 or 2). + + Returns: + Tensor: Tensor with rotary embeddings applied. + """ + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + # Used for flux, cogvideox, hunyuan-dit + half_last_dim = x.shape[-1] // 2 + chunks = ops.chunk( + x.reshape(list(x.shape[:-1]) + [half_last_dim, 2]), chunks=2, axis=-1 + ) + x_real = ops.squeeze(chunks[0], axis=-1) + x_imag = ops.squeeze(chunks[1], axis=-1) + # Stack and flatten: [B, S, H, D//2] -> [B, S, H, D//2, 2] -> [B, S, H, D] + x_rotated_stacked = ops.stack([-x_imag, x_real], axis=-1) + batch_sz = x_rotated_stacked.shape[0] + seq_len = x_rotated_stacked.shape[1] + heads = x_rotated_stacked.shape[2] + flattened_last_dim = x_rotated_stacked.shape[3] * x_rotated_stacked.shape[4] + x_rotated = ops.reshape( + x_rotated_stacked, (batch_sz, seq_len, heads, flattened_last_dim) + ) + + out = ( + x.cast(DType.float32) * cos + x_rotated.cast(DType.float32) * sin + ).cast(x.dtype) + + return out + + +def get_timestep_embedding( + timesteps: TensorValue, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> TensorValue: + """Create sinusoidal timestep embeddings. + + Matches the implementation in Diffusers/DDPM. + """ + half_dim = embedding_dim // 2 + + # Create exponent: -math.log(max_period) * arange(0, half_dim) + # ops.range creates a sequence tensor + exponent = ops.range( + 0, half_dim, step=1, dtype=DType.float32, device=timesteps.device + ) + exponent = exponent * (-math.log(max_period)) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = ops.exp(exponent) + + # emb = timesteps[:, None].float() * emb[None, :] + timesteps_f32 = timesteps.cast(DType.float32) + timesteps_dim = timesteps_f32.shape[0] + emb_dim = emb.shape[0] + emb = timesteps_f32.reshape((timesteps_dim, 1)) * emb.reshape((1, emb_dim)) + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = ops.concat([ops.sin(emb), ops.cos(emb)], axis=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = ops.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1) + + # zero pad if embedding_dim is odd (rare case) + if embedding_dim % 2 == 1: + # Pad with one zero column at the end + zeros = ops.zeros((emb.shape[0], 1), dtype=emb.dtype, device=emb.device) + emb = ops.concat([emb, zeros], axis=-1) + + return emb + + +class Timesteps(nn.Module): + def __init__( + self, + num_channels: int, + flip_sin_to_cos: bool, + downscale_freq_shift: float, + scale: int = 1, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.float32, + ): + """Initialize Timesteps embedding module. + + Args: + num_channels: Number of channels in the embedding. + flip_sin_to_cos: Whether to flip sine and cosine embeddings. + downscale_freq_shift: Frequency downscaling shift parameter. + scale: Scaling factor for embeddings. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def __call__(self, timesteps: TensorValue) -> TensorValue: + """Generate timestep embeddings. + + Args: + timesteps: Input timestep values. + + Returns: + Timestep embeddings. + """ + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int | None = None, + post_act_fn: str | None = None, + cond_proj_dim: int | None = None, + sample_proj_bias: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize TimestepEmbedding module. + + Args: + in_channels: Number of input channels. + time_embed_dim: Dimension of the time embedding. + act_fn: Activation function to use ("silu", "swish", or "gelu"). + out_dim: Optional output dimension. Defaults to time_embed_dim if None. + post_act_fn: Optional post-activation function. + cond_proj_dim: Optional conditional projection dimension. + sample_proj_bias: Whether to use bias in projection layers. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + + self.linear_1 = nn.Linear( + in_channels, + time_embed_dim, + has_bias=sample_proj_bias, + device=device, + dtype=dtype, + ) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear( + cond_proj_dim, + in_channels, + has_bias=False, + device=device, + dtype=dtype, + ) + else: + self.cond_proj = None + if act_fn == "silu" or act_fn == "swish": + self.act_fn = ops.silu + elif act_fn == "gelu": + self.act_fn = ops.gelu + else: + raise ValueError(f"Invalid activation function: {act_fn}") + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + + self.linear_2 = nn.Linear( + time_embed_dim, + time_embed_dim_out, + has_bias=sample_proj_bias, + device=device, + dtype=dtype, + ) + + if post_act_fn is None: + self.post_act_fn = None + elif post_act_fn == "silu" or post_act_fn == "swish": + self.post_act_fn = ops.silu + elif post_act_fn == "gelu": + self.post_act_fn = ops.gelu + else: + raise ValueError(f"Invalid post activation function: {post_act_fn}") + + def __call__( + self, sample: TensorValue, condition: TensorValue | None = None + ) -> TensorValue: + """Generate timestep embeddings with optional conditioning. + + Args: + sample: Input sample tensor. + condition: Optional conditioning tensor. + + Returns: + Timestep embeddings. + """ + if condition is not None and self.cond_proj is not None: + sample = sample + self.cond_proj(condition) + + sample = self.linear_1(sample) + + sample = self.act_fn(sample) + + sample = self.linear_2(sample) + + if self.post_act_fn is not None: + sample = self.post_act_fn(sample) + + return sample + + +class PixArtAlphaTextProjection(nn.Module): + """Projects caption embeddings. Also handles dropout for classifier-free guidance.""" + + def __init__( + self, + in_features: int, + hidden_size: int, + out_features: int | None = None, + act_fn: str = "gelu_tanh", + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize PixArtAlpha text projection module. + + Args: + in_features: Number of input features. + hidden_size: Size of the hidden layer. + out_features: Number of output features. Defaults to hidden_size if None. + act_fn: Activation function to use ("gelu_tanh" or "silu"). + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear( + in_features, hidden_size, has_bias=True, device=device, dtype=dtype + ) + self.linear_2 = nn.Linear( + hidden_size, out_features, has_bias=True, device=device, dtype=dtype + ) + if act_fn == "gelu_tanh": + self.act_fn = ops.gelu(approximate="tanh") + elif act_fn == "silu": + self.act_fn = ops.silu + else: + raise ValueError(f"Invalid activation function: {act_fn}") + + def __call__(self, caption: TensorValue) -> TensorValue: + """Project caption embeddings. + + Args: + caption: Input caption embeddings. + + Returns: + Projected caption embeddings. + """ + hidden_states = self.linear_1(caption) + + hidden_states = self.act_fn(hidden_states) + + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class CombinedTimestepTextProjEmbeddings(nn.Module): + def __init__( + self, + embedding_dim: int, + pooled_projection_dim: int, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize combined timestep and text projection embeddings module. + + Args: + embedding_dim: Dimension of the embedding. + pooled_projection_dim: Dimension of the pooled projection. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + + self.time_proj = Timesteps( + num_channels=256, + flip_sin_to_cos=True, + downscale_freq_shift=0, + device=device, + dtype=dtype, + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=embedding_dim, + device=device, + dtype=dtype, + ) + self.text_embedder = PixArtAlphaTextProjection( + pooled_projection_dim, + embedding_dim, + act_fn="silu", + device=device, + dtype=dtype, + ) + + def __call__( + self, timestep: TensorValue, pooled_projection: TensorValue + ) -> TensorValue: + """Combine timestep and text embeddings. + + Args: + timestep: Input timestep values. + pooled_projection: Pooled text projection. + + Returns: + Combined conditioning embeddings. + """ + # Timestep projection and embedding + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.cast(pooled_projection.dtype) + ) + + # Text projection + pooled_projections = self.text_embedder(pooled_projection) + + # Combine + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): + def __init__( + self, + embedding_dim: int, + pooled_projection_dim: int, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize combined timestep, guidance, and text projection embeddings module. + + Args: + embedding_dim: Dimension of the embedding. + pooled_projection_dim: Dimension of the pooled projection. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + + self.time_proj = Timesteps( + num_channels=256, + flip_sin_to_cos=True, + downscale_freq_shift=0, + device=device, + dtype=dtype, + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=embedding_dim, + device=device, + dtype=dtype, + ) + self.guidance_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=embedding_dim, + device=device, + dtype=dtype, + ) + self.text_embedder = PixArtAlphaTextProjection( + pooled_projection_dim, + embedding_dim, + act_fn="silu", + device=device, + dtype=dtype, + ) + + def __call__( + self, + timestep: TensorValue, + guidance: TensorValue, + pooled_projection: TensorValue, + ) -> TensorValue: + """Combine timestep, guidance, and text embeddings. + + Args: + timestep: Input timestep values. + guidance: Guidance values. + pooled_projection: Pooled text projection. + + Returns: + Combined conditioning embeddings. + """ + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.cast(pooled_projection.dtype) + ) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder( + guidance_proj.cast(pooled_projection.dtype) + ) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning diff --git a/max/python/max/pipelines/architectures/flux1/layers/flux_attention.py b/max/python/max/pipelines/architectures/flux1/layers/flux_attention.py new file mode 100644 index 00000000000..188050390e4 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/flux_attention.py @@ -0,0 +1,474 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import math + +import max.nn as nn +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops +from max.nn import ( + Linear, + Module, + RMSNorm, +) +from max.nn.attention.mask_config import MHAMaskVariant +from max.nn.kernels import flash_attention_gpu + +from .activations import GELU +from .embeddings import apply_rotary_emb + + +class FluxAttention(Module): + """Flux attention mechanism with QK normalization and optional dual stream.""" + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int | None = None, + context_pre_only: bool | None = None, + pre_only: bool = False, + elementwise_affine: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize Flux attention module. + + Args: + query_dim: Dimension of query vectors. + heads: Number of attention heads. + dim_head: Dimension of each attention head. + dropout: Dropout probability. + bias: Whether to use bias in projections. + added_kv_proj_dim: Optional dimension for additional key/value projections. + added_proj_bias: Whether to use bias in additional projections. + out_bias: Whether to use bias in output projection. + eps: Epsilon for normalization layers. + out_dim: Optional output dimension. + context_pre_only: Whether to use context pre-processing only. + pre_only: Whether to use pre-processing only. + elementwise_affine: Whether to use elementwise affine in normalization. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + self.dtype = dtype + self.device = device + + self.norm_q = RMSNorm( + dim_head, + dtype=self.dtype, + eps=eps, + multiply_before_cast=elementwise_affine, + ) + self.norm_k = RMSNorm( + dim_head, + dtype=self.dtype, + eps=eps, + multiply_before_cast=elementwise_affine, + ) + self.to_q = Linear( + query_dim, + self.inner_dim, + has_bias=bias, + dtype=self.dtype, + device=self.device, + ) + self.to_k = Linear( + query_dim, + self.inner_dim, + has_bias=bias, + dtype=self.dtype, + device=self.device, + ) + self.to_v = Linear( + query_dim, + self.inner_dim, + has_bias=bias, + dtype=self.dtype, + device=self.device, + ) + + if not self.pre_only: + layers = [] + layers.append( + Linear( + self.inner_dim, + self.out_dim, + has_bias=out_bias, + dtype=self.dtype, + device=self.device, + ) + ) + # layers.append(Dropout(dropout)) # There is no Dropout in MAX + self.to_out = nn.Sequential(layers) + + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, dtype=self.dtype, eps=eps) + self.norm_added_k = RMSNorm(dim_head, dtype=self.dtype, eps=eps) + self.add_q_proj = Linear( + added_kv_proj_dim, + self.inner_dim, + has_bias=added_proj_bias, + dtype=self.dtype, + device=self.device, + ) + self.add_k_proj = Linear( + added_kv_proj_dim, + self.inner_dim, + has_bias=added_proj_bias, + dtype=self.dtype, + device=self.device, + ) + self.add_v_proj = Linear( + added_kv_proj_dim, + self.inner_dim, + has_bias=added_proj_bias, + dtype=self.dtype, + device=self.device, + ) + self.to_add_out = Linear( + self.inner_dim, + query_dim, + has_bias=out_bias, + dtype=self.dtype, + device=self.device, + ) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue = None, + attention_mask: TensorValue | None = None, + image_rotary_emb: tuple[TensorValue, TensorValue] | None = None, + ) -> TensorValue: + """Apply Flux attention to hidden states. + + Args: + hidden_states: Input hidden states. + encoder_hidden_states: Optional encoder hidden states for cross-attention. + attention_mask: Optional attention mask. + image_rotary_emb: Optional rotary embeddings for position encoding. + + Returns: + Output hidden states after attention, or tuple of (hidden_states, encoder_hidden_states) if encoder states provided. + """ + batch_size = hidden_states.shape[0] + + # get qkv projections + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + seq_len = query.shape[1] + query = ops.reshape( + query, (batch_size, seq_len, self.heads, self.head_dim) + ) + key = ops.reshape(key, (batch_size, seq_len, self.heads, self.head_dim)) + value = ops.reshape( + value, (batch_size, seq_len, self.heads, self.head_dim) + ) + + query = self.norm_q(query) + key = self.norm_k(key) + + encoder_query = encoder_key = encoder_value = None + if ( + encoder_hidden_states is not None + and self.added_kv_proj_dim is not None + ): + encoder_query = self.add_q_proj(encoder_hidden_states) + encoder_key = self.add_k_proj(encoder_hidden_states) + encoder_value = self.add_v_proj(encoder_hidden_states) + + query = self.norm_q(query) + key = self.norm_k(key) + + if ( + encoder_hidden_states is not None + and self.added_kv_proj_dim is not None + ): + encoder_seq_len = encoder_query.shape[1] + encoder_query = ops.reshape( + encoder_query, + (batch_size, encoder_seq_len, self.heads, self.head_dim), + ) + encoder_key = ops.reshape( + encoder_key, + (batch_size, encoder_seq_len, self.heads, self.head_dim), + ) + encoder_value = ops.reshape( + encoder_value, + (batch_size, encoder_seq_len, self.heads, self.head_dim), + ) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = ops.concat([encoder_query, query], axis=1) + key = ops.concat([encoder_key, key], axis=1) + value = ops.concat([encoder_value, value], axis=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = flash_attention_gpu( + query, + key, + value, + mask_variant=MHAMaskVariant.NULL_MASK, + scale=math.sqrt(1.0 / self.head_dim), + ) + + total_seq_len = hidden_states.shape[1] + hidden_states = ops.reshape( + hidden_states, + (batch_size, total_seq_len, self.heads * self.head_dim), + ) + + if encoder_hidden_states is not None: + encoder_seq_len = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :encoder_seq_len, :] + hidden_states = hidden_states[:, encoder_seq_len:, :] + + hidden_states = self.to_out(hidden_states) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + return hidden_states + + +class FeedForward(Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim: int | None = None, + bias: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize FeedForward module. + + Args: + dim: Input dimension. + dim_out: Optional output dimension. Defaults to dim if None. + mult: Multiplier for hidden dimension. + dropout: Dropout probability. + activation_fn: Activation function to use ("gelu" or "gelu-approximate"). + final_dropout: Whether to apply dropout at the end. + inner_dim: Optional inner dimension. Computed as dim * mult if None. + bias: Whether to use bias in linear layers. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias, device=device, dtype=dtype) + if activation_fn == "gelu-approximate": + act_fn = GELU( + dim, + inner_dim, + approximate="tanh", + bias=bias, + device=device, + dtype=dtype, + ) + else: + raise NotImplementedError( + f"Activation function {activation_fn} is not implemented" + ) + + self.net = nn.Sequential( + [ + act_fn, + Linear( + inner_dim, + dim_out, + has_bias=bias, + dtype=dtype, + device=device, + ), + ] + ) + + def __call__( + self, hidden_states: TensorValue, *args, **kwargs + ) -> TensorValue: + """Apply feedforward network to hidden states. + + Args: + hidden_states: Input hidden states. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). + + Returns: + Output hidden states after feedforward network. + """ + return self.net(hidden_states) + + +class FluxPosEmbed(nn.Module): + """Flux Position Embedding module for 3D rotary position embeddings. + + This module computes separate rotary embeddings for each spatial dimension + (typically time, height, width) and concatenates them. + + Args: + theta: Base value for frequency computation (typically 10000) + axes_dim: List of dimensions for each axis (e.g., [16, 56, 56] for time, height, width) + """ + + def __init__( + self, theta: int = 10000, axes_dim: tuple[int, int, int] = (16, 56, 56) + ): + """Initialize Flux position embedding module. + + Args: + theta: Base value for frequency computation (typically 10000). + axes_dim: Dimensions for each axis (e.g., [16, 56, 56] for time, height, width). + """ + super().__init__() + self.theta = float(theta) + self.axes_dim = list(axes_dim) + + def _get_1d_rotary_pos_embed( + self, dim: int, pos: TensorValue, device: DeviceRef + ) -> tuple[TensorValue, TensorValue]: + """Compute 1D rotary position embeddings for a single axis. + + Args: + dim: Dimension of the embedding (should be even) + pos: Position indices, shape [batch_size] + device: Device to compute on + + Returns: + Tuple of (freqs_cos, freqs_sin), each with shape [batch_size, dim] + """ + # Ensure dim is even + assert dim % 2 == 0, f"dim must be even, got {dim}" + + # Cast position to float32 for computation + pos = ops.cast(pos, DType.float32) + + # Compute frequencies: 1.0 / (theta ** (arange(0, dim, 2) / dim)) + # Shape: [dim/2] + arange_vals = ops.range( + 0, dim, step=2, dtype=DType.float32, device=device + ) + exponents = arange_vals / float(dim) + + # theta ** exponents + theta_tensor = ops.constant(self.theta, DType.float32, device=device) + theta_powered = ops.pow(theta_tensor, exponents) + + # 1.0 / theta_powered + freqs = 1.0 / theta_powered # Shape: [dim/2] + + # Outer product: pos [batch_size] x freqs [dim/2] = [batch_size, dim/2] + freqs_outer = ops.outer(pos, freqs) + + # Compute cos and sin + freqs_cos_half = ops.cos(freqs_outer) # [batch_size, dim/2] + freqs_sin_half = ops.sin(freqs_outer) # [batch_size, dim/2] + + # Repeat interleave to get full dimension + # repeat_interleave(2, dim=1): [a, b, c] -> [a, a, b, b, c, c] + # Since repeat_interleave is not supported on GPU, we use reshape + tile + + # 1. Unsqueeze: [batch_size, dim/2] -> [batch_size, dim/2, 1] + freqs_cos_expanded = ops.unsqueeze(freqs_cos_half, axis=2) + freqs_sin_expanded = ops.unsqueeze(freqs_sin_half, axis=2) + + # 2. Concat to duplicate: [batch_size, dim/2, 1] -> [batch_size, dim/2, 2] + freqs_cos_tiled = ops.concat( + [freqs_cos_expanded, freqs_cos_expanded], axis=2 + ) + freqs_sin_tiled = ops.concat( + [freqs_sin_expanded, freqs_sin_expanded], axis=2 + ) + + # 3. Reshape to flatten: [batch_size, dim/2, 2] -> [batch_size, dim] + flattened_dim = freqs_cos_tiled.shape[1] * freqs_cos_tiled.shape[2] + freqs_cos = ops.reshape( + freqs_cos_tiled, (freqs_cos_tiled.shape[0], flattened_dim) + ) + freqs_sin = ops.reshape( + freqs_sin_tiled, (freqs_sin_tiled.shape[0], flattened_dim) + ) + + return freqs_cos, freqs_sin + + def __call__(self, ids: TensorValue) -> tuple[TensorValue, TensorValue]: + """Forward pass to compute rotary position embeddings. + + Args: + ids: Position indices tensor with shape [batch_size, n_axes] + where n_axes is the number of spatial dimensions (e.g., 3 for time/height/width) + + Returns: + Tuple of (freqs_cos, freqs_sin) with concatenated embeddings from all axes + """ + # Get number of axes from the last dimension + n_axes = ids.shape[-1] + device = ids.device + + cos_out = [] + sin_out = [] + + # Compute embeddings for each axis + for i in range(int(n_axes)): + # Extract position for this axis: ids[:, i] + pos = ids[:, i] + + # Compute 1D rotary embeddings for this axis + cos_embed, sin_embed = self._get_1d_rotary_pos_embed( + dim=self.axes_dim[i], pos=pos, device=device + ) + + cos_out.append(cos_embed) + sin_out.append(sin_embed) + + # Concatenate embeddings from all axes along the last dimension + freqs_cos = ops.concat(cos_out, axis=-1) # [batch_size, sum(axes_dim)] + freqs_sin = ops.concat(sin_out, axis=-1) # [batch_size, sum(axes_dim)] + + return freqs_cos, freqs_sin diff --git a/max/python/max/pipelines/architectures/flux1/layers/normalizations.py b/max/python/max/pipelines/architectures/flux1/layers/normalizations.py new file mode 100644 index 00000000000..6755d56125d --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/normalizations.py @@ -0,0 +1,254 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + + +import max.nn as nn +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops +from max.nn import LayerNorm, RMSNorm + + +class AdaLayerNormZeroSingle(nn.Module): + def __init__( + self, + embedding_dim: int, + norm_type: str = "layer_norm", + bias: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize adaptive layer normalization zero single module. + + Args: + embedding_dim: Size of each embedding vector. + norm_type: Type of normalization to use ("layer_norm"). + bias: Whether to use bias in linear projection. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + self.linear = nn.Linear( + embedding_dim, + 3 * embedding_dim, + has_bias=bias, + device=device, + dtype=dtype, + ) + if norm_type == "layer_norm": + self.norm = LayerNorm( + embedding_dim, + use_bias=False, + eps=1e-6, + devices=[device], + dtype=dtype, + keep_dtype=True, + elementwise_affine=False, + ) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def __call__( + self, x: TensorValue, emb: TensorValue | None = None + ) -> TensorValue: + """Apply adaptive layer normalization. + + Args: + x: Input tensor. + emb: Optional embedding tensor for conditioning. + + Returns: + Tuple of normalized tensor and gate values. + """ + emb = self.linear(ops.silu(emb)) + shift_msa, scale_msa, gate_msa = ops.chunk(emb, 3, axis=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa + + +class AdaLayerNormZero(nn.Module): + r"""Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__( + self, + embedding_dim: int, + num_embeddings: int | None = None, + norm_type: str = "layer_norm", + bias: bool = True, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize adaptive layer normalization zero module. + + Args: + embedding_dim: Size of each embedding vector. + num_embeddings: Optional size of the embeddings dictionary. + norm_type: Type of normalization to use ("layer_norm" or "fp32_layer_norm"). + bias: Whether to use bias in linear projection. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + if num_embeddings is not None: + # self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + raise NotImplementedError( + "CombinedTimestepLabelEmbeddings is not implemented" + ) + else: + self.emb = None + + self.linear = nn.Linear( + embedding_dim, + 6 * embedding_dim, + has_bias=bias, + dtype=dtype, + device=device, + ) + if norm_type == "layer_norm": + self.norm = LayerNorm( + embedding_dim, + use_bias=False, + eps=1e-6, + devices=[device], + dtype=dtype, + keep_dtype=True, + elementwise_affine=False, + ) + elif norm_type == "fp32_layer_norm": + # self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) + raise NotImplementedError("FP32LayerNorm is not implemented") + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def __call__( + self, + x: TensorValue, + timestep: TensorValue | None = None, + class_labels: TensorValue | None = None, + hidden_dtype: DType | None = None, + emb: TensorValue | None = None, + ) -> tuple[TensorValue, TensorValue, TensorValue, TensorValue, TensorValue]: + """Apply adaptive layer normalization with gate values for attention and MLP. + + Args: + x: Input tensor. + timestep: Optional timestep tensor. + class_labels: Optional class label tensor. + hidden_dtype: Optional hidden data type. + emb: Optional embedding tensor for conditioning. + + Returns: + Tuple of (normalized tensor, gate_msa, shift_mlp, scale_mlp, gate_mlp). + """ + if self.emb is not None: + emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) + emb = self.linear(ops.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + ops.chunk(emb, 6, axis=1) + ) + x = self.norm(x) + x = x * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormContinuous(nn.Module): + r"""Adaptive normalization layer with a norm layer (layer_norm or rms_norm). + + Args: + embedding_dim (`int`): Embedding dimension to use during projection. + conditioning_embedding_dim (`int`): Dimension of the input condition. + elementwise_affine (`bool`, defaults to `True`): + Boolean flag to denote if affine transformation should be applied. + eps (`float`, defaults to 1e-5): Epsilon factor. + bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use. + norm_type (`str`, defaults to `"layer_norm"`): + Normalization layer to use. Values supported: "layer_norm", "rms_norm". + """ + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + # elementwise_affine=True, + eps: float = 1e-5, + bias: bool = True, + norm_type: str = "layer_norm", + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize adaptive layer normalization continuous module. + + Args: + embedding_dim: Embedding dimension to use during projection. + conditioning_embedding_dim: Dimension of the input condition. + eps: Epsilon factor for normalization. + bias: Whether to use bias in linear projection. + norm_type: Type of normalization to use ("layer_norm" or "rms_norm"). + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + self.silu = ops.silu + self.linear = nn.Linear( + conditioning_embedding_dim, + embedding_dim * 2, + has_bias=bias, + device=device, + dtype=dtype, + ) + if norm_type == "layer_norm": + self.norm = LayerNorm( + embedding_dim, + eps=eps, + devices=[device], + dtype=dtype, + keep_dtype=True, + elementwise_affine=False, + ) + elif norm_type == "rms_norm": + self.norm = RMSNorm( + embedding_dim, eps=eps, device=device, dtype=dtype + ) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def __call__( + self, x: TensorValue, conditioning_embedding: TensorValue + ) -> TensorValue: + """Apply adaptive layer normalization with conditioning. + + Args: + x: Input tensor. + conditioning_embedding: Conditioning embedding tensor. + + Returns: + Normalized and conditioned tensor. + """ + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).cast(x.dtype)) + scale, shift = ops.chunk(emb, 2, axis=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x diff --git a/max/python/max/pipelines/architectures/flux1/model.py b/max/python/max/pipelines/architectures/flux1/model.py new file mode 100644 index 00000000000..8b476f1ab72 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/model.py @@ -0,0 +1,77 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.driver import CPU, Accelerator, Device +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.max_model import MaxModel + +from .flux1 import FluxTransformer2DModel +from .model_config import FluxConfig +from .weight_adapters import convert_safetensor_state_dict + + +class Flux1Model(MaxModel): + config_name = FluxConfig.config_name + + def __init__( + self, + config: dict, + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + ) -> None: + super().__init__( + config, + encoding, + devices, + weights, + ) + self.config = FluxConfig.generate( + config, + encoding, + devices, + ) + self.load_model() + + def load_model(self) -> Model: + flux = FluxTransformer2DModel(self.config) + + if self.config.device.is_cpu(): + session = InferenceSession([CPU()]) + else: + session = InferenceSession([Accelerator()]) + state_dict = {key: value.data() for key, value in self.weights.items()} + state_dict = convert_safetensor_state_dict(state_dict) + flux.load_state_dict(state_dict) + with Graph( + "flux_transformer_2d_model", input_types=flux.input_types() + ) as graph: + outputs = flux( + *graph.inputs, + joint_attention_kwargs={}, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict=False, + controlnet_blocks_repeat=False, + ) + graph.output(*outputs) + compiled_graph = graph + self.session = session.load( + compiled_graph, weights_registry=flux.state_dict() + ) + + def __call__(self, *args, **kwargs): + return self.session.execute(*args, **kwargs) diff --git a/max/python/max/pipelines/architectures/flux1/model_config.py b/max/python/max/pipelines/architectures/flux1/model_config.py new file mode 100644 index 00000000000..c9292030000 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/model_config.py @@ -0,0 +1,59 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from typing import ClassVar + +from max.driver import Device +from max.dtype import DType +from max.graph import DeviceRef +from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding +from pydantic import Field + + +class FluxConfigBase(MAXModelConfigBase): + patch_size: int = 1 + in_channels: int = 64 + out_channels: int | None = None + num_layers: int = 19 + num_single_layers: int = 38 + attention_head_dim: int = 128 + num_attention_heads: int = 24 + joint_attention_dim: int = 4096 + pooled_projection_dim: int = 768 + guidance_embeds: bool = False + axes_dims_rope: tuple[int, int, int] = (16, 56, 56) + dtype: DType = DType.bfloat16 + device: DeviceRef = Field(default_factory=DeviceRef.GPU) + + +class FluxConfig(FluxConfigBase): + config_name: ClassVar[str] = "config.json" + + @staticmethod + def generate( + config_dict: dict, + encoding: SupportedEncoding, + devices: list[Device], + ) -> FluxConfigBase: + init_dict = { + key: value + for key, value in config_dict.items() + if key in FluxConfigBase.__annotations__ + } + init_dict.update( + { + "dtype": encoding.dtype, + "device": DeviceRef.from_device(devices[0]), + } + ) + return FluxConfigBase(**init_dict) diff --git a/max/python/max/pipelines/architectures/flux1/pipeline_flux.py b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py new file mode 100644 index 00000000000..85298caf20e --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/pipeline_flux.py @@ -0,0 +1,774 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import inspect +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import numpy as np +import PIL.Image +from max.driver import Tensor +from max.dtype import DType +from max.experimental import Tensor as Tensor_v3 +from max.experimental import functional as F +from max.experimental import random +from max.graph import DeviceRef +from max.pipelines.lib.diffusion_schedulers import ( + FlowMatchEulerDiscreteScheduler, +) +from max.pipelines.lib.image_processor import ( + PipelineImageInput, + VaeImageProcessor, +) +from max.pipelines.lib.interfaces.diffusion_pipeline import ( + DiffusionPipeline, +) +from tqdm import tqdm +from transformers import ( + CLIPTokenizer, + T5TokenizerFast, +) + +from ..autoencoder_kl import AutoencoderKLModel +from ..clip import ClipModel +from ..t5 import T5Model +from .model import Flux1Model + + +def retrieve_timesteps( + scheduler: Any, + num_inference_steps: int | None = None, + device: str | DeviceRef | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs: Any, +) -> tuple[np.ndarray, int]: + r"""Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. + + Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `DeviceRef`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + **kwargs (`Any`, *optional*): + Additional arguments to pass to the scheduler's `set_timesteps` method. + + Returns: + `tuple[Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = int(timesteps.shape[0]) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = int(timesteps.shape[0]) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def calculate_shift( + image_seq_len: int, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +) -> float: + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +@dataclass +class FluxPipelineOutput: + """Output class for Flux image generation pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray` or `Tensor`) + List of denoised PIL images of length `batch_size` or numpy array or Max tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Max tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: list[PIL.Image.Image] | np.ndarray | Tensor + + +class FluxPipeline(DiffusionPipeline): + config_name = "model_index.json" + + components = { + "scheduler": FlowMatchEulerDiscreteScheduler, + "vae": AutoencoderKLModel, + "text_encoder": ClipModel, + "tokenizer": CLIPTokenizer, + "text_encoder_2": T5Model, + "tokenizer_2": T5TokenizerFast, + "transformer": Flux1Model, + } + + def init_remaining_components(self) -> None: + image_processor_class = self.components.get( + "image_processor", VaeImageProcessor + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) + if getattr(self, "vae", None) + else 8 + ) + image_processor = image_processor_class( + vae_scale_factor=self.vae_scale_factor * 2 + ) + self.image_processor = image_processor + + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + device: DeviceRef | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: Tensor | None = None, + pooled_prompt_embeds: Tensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ) -> tuple[Tensor, Tensor, Tensor]: + r"""Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`DeviceRef`): + Max device + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + prompt_embeds (`Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + if lora_scale is not None and isinstance(self, FluxPipeline): + self._lora_scale = lora_scale + + if self.text_encoder is not None and hasattr( + self.text_encoder, "set_lora_scale" + ): + self.text_encoder.set_lora_scale(lora_scale) + if self.text_encoder_2 is not None and hasattr( + self.text_encoder_2, "set_lora_scale" + ): + self.text_encoder_2.set_lora_scale(lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=min( + max_sequence_length, self.tokenizer.model_max_length + ), + truncation=True, + return_length=False, + return_overflowing_tokens=False, + ) + text_input_ids = Tensor_v3.constant( + text_inputs.input_ids, device=device, dtype=DType.int64 + ) + + text_encoder_outputs = self.text_encoder(text_input_ids) + prompt_embeds = text_encoder_outputs[0] + pooled_prompt_embeds = text_encoder_outputs[1] + + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + if self.text_encoder_2 is not None: + text_inputs_2 = self.tokenizer_2( + prompt_2, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + ) + text_input_ids_2 = Tensor_v3.constant( + text_inputs_2.input_ids, device=device, dtype=DType.int64 + ) + + prompt_embeds_2 = self.text_encoder_2(text_input_ids_2)[0] + else: + prompt_embeds_2 = None + + if prompt_embeds_2 is not None: + prompt_embeds = prompt_embeds_2 + + text_ids = Tensor_v3.zeros( + (prompt_embeds.shape[1], 3), + device=device, + dtype=prompt_embeds.dtype, + ) + + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = Tensor_v3.from_dlpack( + prompt_embeds + ) # V2 Tensor to V3 Tensor + pooled_prompt_embeds = Tensor_v3.from_dlpack( + pooled_prompt_embeds + ) # V2 Tensor to V3 Tensor + + prompt_embeds = F.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = prompt_embeds.reshape( + (bs_embed * num_images_per_prompt, seq_len, -1) + ) + + pooled_prompt_embeds = F.tile( + pooled_prompt_embeds, (1, num_images_per_prompt) + ) + pooled_prompt_embeds = pooled_prompt_embeds.reshape( + (bs_embed * num_images_per_prompt, -1) + ) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + @staticmethod + def _prepare_latent_image_ids( + batch_size: int, + height: int, + width: int, + device: DeviceRef, + dtype: DType, + ) -> Tensor_v3: + latent_image_ids = np.stack( + [ + np.zeros((height, width)), + np.broadcast_to(np.arange(height)[:, None], (height, width)), + np.broadcast_to(np.arange(width)[None, :], (height, width)), + ], + axis=-1, + ) + + ( + latent_image_id_height, + latent_image_id_width, + latent_image_id_channels, + ) = latent_image_ids.shape + + latent_image_ids = np.reshape( + latent_image_ids, + ( + latent_image_id_height * latent_image_id_width, + latent_image_id_channels, + ), + ) + latent_image_ids = ( + Tensor_v3.from_dlpack(latent_image_ids).to(device).cast(dtype) + ) + + return latent_image_ids + + @staticmethod + def _pack_latents( + latents: Tensor_v3, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + ) -> Tensor_v3: + latents = F.reshape( + latents, + (batch_size, num_channels_latents, height // 2, 2, width // 2, 2), + ) + latents = F.permute(latents, (0, 2, 4, 1, 3, 5)) + latents = F.reshape( + latents, + ( + batch_size, + (height // 2) * (width // 2), + num_channels_latents * 4, + ), + ) + + return latents + + @staticmethod + def _unpack_latents( + latents: Tensor_v3, + height: int, + width: int, + vae_scale_factor: int, + ) -> Tensor_v3: + # TODO: should compile this function for speed up. + batch_size, _, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (height // (vae_scale_factor * 2)) + width = 2 * (width // (vae_scale_factor * 2)) + + latents = F.reshape( + latents, + (batch_size.dim, height // 2, width // 2, channels.dim // 4, 2, 2), + ) + latents = F.permute(latents, (0, 3, 1, 4, 2, 5)) + + latents = F.reshape( + latents, (batch_size.dim, channels.dim // (2 * 2), height, width) + ) + + return latents + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: DType, + device: DeviceRef, + latents: Tensor_v3 | None = None, + ) -> tuple[Tensor_v3, Tensor_v3]: + """Prepare latents for the Flux pipeline. + + Args: + batch_size: The number of images to generate. + num_channels_latents: The number of latent channels. + height: The height of the generated image. + width: The width of the generated image. + dtype: The data type for the latents. + device: The device to run on. + latents: Pre-generated latents. + + Returns: + Tuple of latents and latent image ids. + """ + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids( + batch_size, height // 2, width // 2, device, dtype + ) + return latents.to(device).cast(dtype), latent_image_ids + + latents = random.normal(shape, device=device, dtype=dtype) + latents = self._pack_latents( + latents, batch_size, num_channels_latents, height, width + ) + + latent_image_ids = self._prepare_latent_image_ids( + batch_size, height // 2, width // 2, device, dtype + ) + + return latents, latent_image_ids + + def __call__( + self, + prompt: str | list[str] | None = None, + prompt_2: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + true_cfg_scale: float = 1.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 3.5, + num_images_per_prompt: int | None = 1, + latents: Tensor | None = None, + prompt_embeds: Tensor | None = None, + pooled_prompt_embeds: Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[Tensor] | None = None, + negative_ip_adapter_image: PipelineImageInput | None = None, + negative_ip_adapter_image_embeds: list[Tensor] | None = None, + negative_prompt_embeds: Tensor | None = None, + negative_pooled_prompt_embeds: Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + max_sequence_length: int = 512, + ): + r"""Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + latents (`Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device() + + lora_scale = ( + self._joint_attention_kwargs.get("scale", None) + if self._joint_attention_kwargs is not None + else None + ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None + and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + latents, + ) + + # 5. Prepare timesteps + sigmas = ( + np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if sigmas is None + else sigmas + ) + if ( + hasattr(self.scheduler, "use_flow_sigmas") + and self.scheduler.use_flow_sigmas + ): + sigmas = None + image_seq_len = latents.shape[1].dim + mu = calculate_shift( + image_seq_len, + self.scheduler.base_image_seq_len, + self.scheduler.max_image_seq_len, + self.scheduler.base_shift, + self.scheduler.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + self._num_timesteps = timesteps.shape[0] + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = Tensor_v3.full( + [latents.shape[0].dim], + guidance_scale, + device=device, + dtype=prompt_embeds.dtype, + ) + else: + guidance = Tensor_v3.zeros( + [latents.shape[0].dim], + device=device, + dtype=prompt_embeds.dtype, + ) + + if ( + ip_adapter_image is not None + or ip_adapter_image_embeds is not None + or negative_ip_adapter_image is not None + or negative_ip_adapter_image_embeds is not None + ): + raise NotImplementedError( + "IP adapter is not supported for Max yet." + ) + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + batch_size = latents.shape[0].dim + for i in tqdm(range(self._num_timesteps), desc="Denoising"): + if self._interrupt: + continue + + t = timesteps[i] + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = ( + image_embeds + ) + + # NOTE: Convert timesteps to a Max Tensor before denoising loop, + # as in the original implementation, results in a significant slow down. + # As a workaround, we keep timesteps as a numpy array and convert it + # to a Max Tensor here. This might require a more efficient way to handle this. + # Converting to a Max module V3 Tensor also results in a significant slow down. + timestep = np.full((batch_size,), t) / 1000.0 + timestep = Tensor.from_dlpack(timestep).to(prompt_embeds.device) + + noise_pred = self.transformer( + latents, + prompt_embeds, + pooled_prompt_embeds, + timestep, + latent_image_ids, + text_ids, + guidance, + )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = ( + negative_image_embeds + ) + + neg_noise_pred = self.transformer( + latents, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + timestep, + latent_image_ids, + negative_text_ids, + guidance, + )[0] + # TODO: negative prompt path is very slow, need to optimize. + noise_pred = Tensor_v3.from_dlpack(noise_pred) + neg_noise_pred = Tensor_v3.from_dlpack(neg_noise_pred) + noise_pred = neg_noise_pred + true_cfg_scale * ( + noise_pred - neg_noise_pred + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False + )[0] + + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = Tensor_v3.from_dlpack(latents) # V2 Tensor to V3 Tensor + latents = self._unpack_latents( + latents, height, width, self.vae_scale_factor + ) + latents = ( + latents / self.vae.config.scaling_factor + ) + self.vae.config.shift_factor + image = self.vae.decode(latents)[0] + + image = Tensor_v3.from_dlpack(image) # V2 Tensor to V3 Tensor + image = self.image_processor.postprocess( + image, output_type=output_type + ) + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/max/python/max/pipelines/architectures/flux1/weight_adapters.py b/max/python/max/pipelines/architectures/flux1/weight_adapters.py new file mode 100644 index 00000000000..6ef149c0976 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/weight_adapters.py @@ -0,0 +1,30 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import re + +from max.graph.weights import WeightData + + +def convert_safetensor_state_dict( + state_dict: dict[str, WeightData], +) -> dict[str, WeightData]: + keys = list(state_dict.keys()) + for key in keys: + # Remap net.2 to net.1: Diffusers uses [GELU, Dropout, Linear], while MAX uses [GELU, Linear]. + if re.match( + r"transformer_blocks\.\d+\.(ff|ff_context)\.net\.2\.(weight|bias)", + key, + ): + state_dict[key.replace("net.2.", "net.1.")] = state_dict.pop(key) + return state_dict diff --git a/max/python/max/pipelines/architectures/t5/__init__.py b/max/python/max/pipelines/architectures/t5/__init__.py new file mode 100644 index 00000000000..ad108912d49 --- /dev/null +++ b/max/python/max/pipelines/architectures/t5/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .model import T5Model diff --git a/max/python/max/pipelines/architectures/t5/model.py b/max/python/max/pipelines/architectures/t5/model.py new file mode 100644 index 00000000000..f0c53c2fcdc --- /dev/null +++ b/max/python/max/pipelines/architectures/t5/model.py @@ -0,0 +1,64 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.driver import CPU, Accelerator, Device +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.max_model import MaxModel + +from .model_config import T5Config +from .t5 import T5EncoderModel + + +class T5Model(MaxModel): + config_name = T5Config.config_name + + def __init__( + self, + config: dict, + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + ) -> None: + super().__init__(config, encoding, devices, weights) + self.config = T5Config.generate( + config, + encoding, + devices, + ) + self.load_model() + + def load_model(self) -> Model: + t5 = T5EncoderModel(self.config) + + if self.config.device.is_cpu(): + session = InferenceSession([CPU()]) + else: + session = InferenceSession([Accelerator()]) + state_dict = {key: value.data() for key, value in self.weights.items()} + t5.load_state_dict(state_dict) + with Graph("t5_encoder_model", input_types=t5.input_types()) as graph: + outputs = t5( + input_ids=graph.inputs[0], + attention_mask=None, + ) + graph.output(outputs) + compiled_graph = graph + self.session = session.load( + compiled_graph, weights_registry=t5.state_dict() + ) + + def __call__(self, *args, **kwargs): + return self.session.execute(*args, **kwargs) diff --git a/max/python/max/pipelines/architectures/t5/model_config.py b/max/python/max/pipelines/architectures/t5/model_config.py new file mode 100644 index 00000000000..ab28ce4b2a8 --- /dev/null +++ b/max/python/max/pipelines/architectures/t5/model_config.py @@ -0,0 +1,69 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from typing import ClassVar + +from max.driver import Device +from max.dtype import DType +from max.graph import DeviceRef +from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding +from pydantic import Field + + +class T5ConfigBase(MAXModelConfigBase): + vocab_size: int = 32128 + d_model: int = 512 + d_kv: int = 64 + d_ff: int = 2048 + num_layers: int = 6 + num_decoder_layers: int | None = None + num_heads: int = 8 + relative_attention_num_buckets: int = 32 + relative_attention_max_distance: int = 128 + dropout_rate: float = 0.1 + layer_norm_epsilon: float = 1e-6 + initializer_factor: float = 1.0 + feed_forward_proj: str = "relu" + dense_act_fn: str | None = Field(default=None, exclude=True) + is_gated_act: bool = Field(default=False, exclude=True) + is_decoder: bool = Field(default=False, exclude=True) + is_encoder_decoder: bool = True + use_cache: bool = True + pad_token_id: int = 0 + eos_token_id: int = 1 + classifier_dropout: float = 0.0 + device: DeviceRef = Field(default_factory=DeviceRef.GPU) + dtype: DType = DType.bfloat16 + + +class T5Config(T5ConfigBase): + config_name: ClassVar[str] = "config.json" + + @staticmethod + def generate( + config_dict: dict, + encoding: SupportedEncoding, + devices: list[Device], + ) -> T5ConfigBase: + init_dict = { + key: value + for key, value in config_dict.items() + if key in T5ConfigBase.__annotations__ + } + init_dict.update( + { + "dtype": encoding.dtype, + "device": DeviceRef.from_device(devices[0]), + } + ) + return T5ConfigBase(**init_dict) diff --git a/max/python/max/pipelines/architectures/t5/t5.py b/max/python/max/pipelines/architectures/t5/t5.py new file mode 100644 index 00000000000..1cd8665a3bc --- /dev/null +++ b/max/python/max/pipelines/architectures/t5/t5.py @@ -0,0 +1,823 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import math + +import max.nn as nn +from max.dtype import DType +from max.graph import DeviceRef, TensorType, TensorValue, Weight, ops +from max.nn import Module + +from .model_config import T5Config + + +class T5LayerNorm(Module): + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.float32, + ): + """Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + + Args: + hidden_size: Hidden size. + eps: Epsilon. + device: Device to place the module on. + dtype: Data type for the module. + """ + super().__init__() + self.weight = Weight("weight", dtype, (hidden_size,), device=device) + self.variance_epsilon = eps + self.dtype = dtype + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Process hidden states through the T5 layer norm. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + hidden_states_f32 = ops.cast(hidden_states, DType.float32) + variance = ops.mean(ops.pow(hidden_states_f32, 2), axis=-1) + hidden_states = hidden_states * ops.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.dtype in [DType.float16, DType.bfloat16]: + hidden_states = ops.cast(hidden_states, self.dtype) + + return self.weight * hidden_states + + +class T5DenseActDense(Module): + def __init__( + self, + config: T5Config, + ): + """Construct a dense-activation-dense module. + + Args: + config: T5 configuration for feed-forward dimensions and dtype. + """ + super().__init__() + self.wi = nn.Linear( + config.d_model, + config.d_ff, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.wo = nn.Linear( + config.d_ff, + config.d_model, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.act_fn = ( + lambda x: 0.5 + * x + * ( + 1.0 + + ops.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * ops.pow(x, 3.0)) + ) + ) + ) + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Process hidden states through the dense-activation-dense block. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + hidden_states = self.wi(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(Module): + def __init__( + self, + config: T5Config, + ): + """Construct a dense-gated-activation-dense module. + + Args: + config: T5 configuration for feed-forward dimensions and dtype. + """ + super().__init__() + self.wi_0 = nn.Linear( + config.d_model, + config.d_ff, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.wi_1 = nn.Linear( + config.d_model, + config.d_ff, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.wo = nn.Linear( + config.d_ff, + config.d_model, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.act_fn = ( + lambda x: 0.5 + * x + * ( + 1.0 + + ops.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * ops.pow(x, 3.0)) + ) + ) + ) + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Process hidden states through the dense-gated-activation-dense block. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + hidden_gelu = self.act_fn(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(Module): + def __init__( + self, + config: T5Config, + ): + """Construct a feed-forward layer. + + Args: + config: T5 configuration for gating, dimensions, and dtype. + """ + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config) + else: + self.DenseReluDense = T5DenseActDense(config) + + self.layer_norm = T5LayerNorm( + config.d_model, + eps=config.layer_norm_epsilon, + device=config.device, + dtype=config.dtype, + ) + + def __call__(self, hidden_states: TensorValue) -> TensorValue: + """Process hidden states through the feed-forward layer. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +class T5Attention(Module): + def __init__( + self, + config: T5Config, + has_relative_attention_bias: bool = False, + layer_idx: int | None = None, + ): + """Construct an attention layer. + + Args: + config: T5 configuration. + has_relative_attention_bias: Whether to use relative attention bias. + layer_idx: Index of the layer. + """ + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = ( + config.relative_attention_num_buckets + ) + self.relative_attention_max_distance = ( + config.relative_attention_max_distance + ) + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.device = config.device + self.dtype = config.dtype + + self.q = nn.Linear( + self.d_model, + self.inner_dim, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.k = nn.Linear( + self.d_model, + self.inner_dim, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.v = nn.Linear( + self.d_model, + self.inner_dim, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + self.o = nn.Linear( + self.inner_dim, + self.d_model, + has_bias=False, + device=config.device, + dtype=config.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, + self.n_heads, + device=config.device, + dtype=config.dtype, + ) + + def _relative_position_bucket( + self, + relative_position: TensorValue, + bidirectional: bool = True, + num_buckets: int = 32, + max_distance: int = 128, + ) -> TensorValue: + """Compute relative position bucket. + + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Args: + relative_position: Tensor with relative positions. + bidirectional: Whether the attention is bidirectional. + num_buckets: Number of buckets. + max_distance: Maximum distance for relative positions. + + Returns: + TensorValue: Relative position buckets. + """ + relative_buckets = ops.constant(0, DType.int32, self.device) + + if bidirectional: + num_buckets = num_buckets // 2 + is_positive = ops.greater(relative_position, 0) + relative_buckets = relative_buckets + ( + is_positive.cast(DType.int32) * num_buckets + ) + relative_position = ops.abs(relative_position) + else: + relative_position = -ops.min(relative_position, 0) + + max_exact = num_buckets // 2 + is_small = ops.greater(max_exact, relative_position) + + scale = (num_buckets - max_exact) / math.log(max_distance / max_exact) + rel_pos_float = relative_position.cast(DType.float32) + val_log = ops.log(rel_pos_float / float(max_exact)) + relative_position_if_large = max_exact + (val_log * scale).cast( + DType.int32 + ) + relative_position_if_large = ops.min( + relative_position_if_large, num_buckets - 1 + ) + return relative_buckets + ops.where( + is_small, relative_position, relative_position_if_large + ) + + def compute_bias(self, query_length: int, key_length: int) -> TensorValue: + """Compute relative attention bias. + + Args: + query_length: Length of the query sequence. + key_length: Length of the key sequence. + + Returns: + TensorValue: Relative attention bias tensor. + """ + context_position = ops.range( + 0, query_length, step=1, dtype=DType.int32, device=self.device + ) + context_position = ops.unsqueeze(context_position, 1) + + memory_position = ops.range( + 0, key_length, step=1, dtype=DType.int32, device=self.device + ) + memory_position = ops.unsqueeze(memory_position, 0) + + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) + values = ops.permute(values, (2, 0, 1)) + values = ops.unsqueeze(values, 0) + return values + + def __call__( + self, + hidden_states: TensorValue, + mask: TensorValue | None = None, + key_value_states: TensorValue | None = None, + position_bias: TensorValue | None = None, + past_key_values: TensorValue | None = None, + layer_head_mask: TensorValue | None = None, + query_length: int | None = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: TensorValue | None = None, + ) -> tuple[TensorValue, TensorValue]: + """Process hidden states through the attention layer. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + mask: Attention mask. + key_value_states: Key-value states for cross-attention. + position_bias: Position bias tensor. + past_key_values: Past key values for caching (not implemented). + layer_head_mask: Mask for attention heads. + query_length: Length of the query sequence. + use_cache: Whether to use cache (not implemented). + output_attentions: Whether to return attention weights. + cache_position: Cache position. + + Returns: + Tuple[TensorValue, TensorValue]: Output tensor and position bias. + """ + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + if is_cross_attention: + raise NotImplementedError( + "T5 CrossAttention is not implemented yet." + ) + if past_key_values is not None: + raise NotImplementedError( + "T5 auto regressive model is not implemented yet." + ) + + query = self.q(hidden_states) + key = self.k(hidden_states) + value = self.v(hidden_states) + + # Reshape to (batch, seq, heads, head_dim) + query = ops.reshape( + query, + (batch_size, seq_length, self.n_heads, self.key_value_proj_dim), + ) + key = ops.reshape( + key, (batch_size, seq_length, self.n_heads, self.key_value_proj_dim) + ) + value = ops.reshape( + value, + (batch_size, seq_length, self.n_heads, self.key_value_proj_dim), + ) + + # Transpose to (batch, heads, seq, head_dim) + query = ops.permute(query, (0, 2, 1, 3)) + key = ops.permute(key, (0, 2, 1, 3)) + value = ops.permute(value, (0, 2, 1, 3)) + + scores = ops.matmul(query, ops.permute(key, (0, 1, 3, 2))) + + if position_bias is None and self.has_relative_attention_bias: + position_bias = self.compute_bias(seq_length, seq_length) + + if position_bias is not None: + scores = scores + position_bias + + if mask is not None: + scores = scores + mask + + attn_weights = ops.softmax(ops.cast(scores, DType.float32), axis=-1) + attn_weights = ops.cast(attn_weights, self.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = ops.matmul(attn_weights, value) + attn_output = ops.permute(attn_output, (0, 2, 1, 3)) + attn_output = ops.reshape( + attn_output, (batch_size, seq_length, self.inner_dim) + ) + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(Module): + def __init__( + self, + config: T5Config, + has_relative_attention_bias: bool = False, + layer_idx: int | None = None, + ): + """Construct a self-attention layer. + + Args: + config: T5 configuration. + has_relative_attention_bias: Whether to use relative attention bias. + layer_idx: Index of the layer. + """ + super().__init__() + self.SelfAttention = T5Attention( + config, + has_relative_attention_bias=has_relative_attention_bias, + layer_idx=layer_idx, + ) + self.layer_norm = T5LayerNorm( + config.d_model, + eps=config.layer_norm_epsilon, + device=config.device, + dtype=config.dtype, + ) + + def __call__( + self, + hidden_states: TensorValue, + attention_mask: TensorValue | None = None, + position_bias: TensorValue | None = None, + layer_head_mask: TensorValue | None = None, + past_key_values: TensorValue | None = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: TensorValue | None = None, + ) -> TensorValue: + """Process hidden states through the self-attention layer. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + attention_mask: Attention mask. + position_bias: Position bias tensor. + layer_head_mask: Mask for attention heads. + past_key_values: Past key values for caching (not implemented). + use_cache: Whether to use cache (not implemented). + output_attentions: Whether to return attention weights. + cache_position: Cache position. + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + attention_output[0] + outputs = (hidden_states,) + attention_output[1:] + return outputs + + +class T5Block(Module): + def __init__( + self, + config: T5Config, + has_relative_attention_bias: bool = False, + layer_idx: int | None = None, + ): + """Construct a T5 block. + + Args: + config: T5 configuration. + has_relative_attention_bias: Whether to use relative attention bias. + layer_idx: Index of the layer. + """ + super().__init__() + layers = list() + self.is_decoder = config.is_decoder + if self.is_decoder: + raise NotImplementedError( + "T5 LayerCrossAttention is not implemented yet." + ) + + layers.append( + T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + layer_idx=layer_idx, + ) + ) + layers.append(T5LayerFF(config)) + self.layer = nn.LayerList(layers) + + def __call__( + self, + hidden_states: TensorValue, + attention_mask: TensorValue | None = None, + position_bias: TensorValue | None = None, + encoder_hidden_states: TensorValue | None = None, + encoder_attention_mask: TensorValue | None = None, + encoder_decoder_position_bias: TensorValue | None = None, + cross_attn_layer_head_mask: TensorValue | None = None, + layer_head_mask: TensorValue | None = None, + past_key_values: TensorValue | None = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: TensorValue | None = None, + ) -> tuple[TensorValue, TensorValue]: + """Process hidden states through the T5 block. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + attention_mask: Attention mask. + position_bias: Position bias tensor. + encoder_hidden_states: Encoder hidden states (not implemented). + encoder_attention_mask: Encoder attention mask (not implemented). + encoder_decoder_position_bias: Encoder-decoder position bias (not implemented). + cross_attn_layer_head_mask: Cross attention layer head mask (not implemented). + layer_head_mask: Mask for attention heads. + past_key_values: Past key values for caching (not implemented). + use_cache: Whether to use cache (not implemented). + output_attentions: Whether to return attention weights. + cache_position: Cache position. + + Returns: + Tuple[TensorValue, TensorValue]: Output tensor and position bias. + """ + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] + + if hidden_states.dtype == DType.float16: + clamp_value = DType.finfo(hidden_states.dtype).max - 1000 + hidden_states = nn.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + do_cross_attention = ( + self.is_decoder and encoder_hidden_states is not None + ) + if do_cross_attention: + raise NotImplementedError( + "T5 CrossAttention is not implemented yet." + ) + + hidden_states = self.layer[-1](hidden_states) + if hidden_states.dtype == DType.float16: + clamp_value = DType.finfo(hidden_states.dtype).max - 1000 + hidden_states = nn.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + outputs = (hidden_states,) + return outputs + attention_outputs + + +class T5Stack(Module): + def __init__( + self, + config: T5Config, + embed_tokens: nn.Embedding | None = None, + ): + """Construct a T5 stack. + + Args: + config: T5 configuration. + embed_tokens: Embedding module. + """ + super().__init__() + self.config = config + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.LayerList( + [ + T5Block( + config, + has_relative_attention_bias=bool(i == 0), + layer_idx=i, + ) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm( + config.d_model, + eps=config.layer_norm_epsilon, + device=config.device, + dtype=config.dtype, + ) + self.dropout = config.dropout_rate + self.device = config.device + self.dtype = config.dtype + + def __call__( + self, + input_ids: TensorValue | None = None, + attention_mask: TensorValue | None = None, + inputs_embeds: TensorValue | None = None, + encoder_hidden_states: TensorValue | None = None, + encoder_attention_mask: TensorValue | None = None, + encoder_decoder_position_bias: TensorValue | None = None, + cross_attn_layer_head_mask: TensorValue | None = None, + layer_head_mask: TensorValue | None = None, + past_key_values: TensorValue | None = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: TensorValue | None = None, + ) -> TensorValue: + """Process input through the T5 stack. + + Args: + input_ids: Input IDs tensor of shape (batch_size, seq_length). + attention_mask: Attention mask tensor of shape (batch_size, seq_length). + inputs_embeds: Input embeddings tensor of shape (batch_size, seq_length, hidden_size). + encoder_hidden_states: Encoder hidden states (not implemented). + encoder_attention_mask: Encoder attention mask (not implemented). + encoder_decoder_position_bias: Encoder-decoder position bias (not implemented). + cross_attn_layer_head_mask: Cross attention layer head mask (not implemented). + layer_head_mask: Mask for attention heads. + past_key_values: Past key values for caching (not implemented). + use_cache: Whether to use cache (not implemented). + output_attentions: Whether to return attention weights. + cache_position: Cache position. + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + if input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids) + elif inputs_embeds is None: + raise ValueError( + "You have to specify either input_ids or inputs_embeds" + ) + + if self.is_decoder or use_cache: + raise NotImplementedError("T5 decoder is not implemented yet.") + + hidden_states = inputs_embeds + + if attention_mask is not None: + causal_mask = ( + 1.0 - ops.cast(attention_mask, hidden_states.dtype) + ) * DType.finfo(hidden_states.dtype).min + causal_mask = ops.unsqueeze(causal_mask, 1) + causal_mask = ops.unsqueeze(causal_mask, 1) + else: + causal_mask = None + encoder_extended_attention_mask = None + + position_bias = None + for layer_module in self.block: + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + layer_head_mask=layer_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + position_bias = layer_outputs[1] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class T5EncoderModel(Module): + def __init__( + self, + config: T5Config, + ): + """Construct a T5 encoder model. + + Args: + config: T5 configuration for vocabulary size, layer counts, and + device/dtype settings. + """ + super().__init__() + act_info = config.feed_forward_proj.split("-") + config.dense_act_fn = act_info[-1] + config.is_gated_act = act_info[0] == "gated" + + self.shared = nn.Embedding( + config.vocab_size, + config.d_model, + device=config.device, + dtype=config.dtype, + ) + + encoder_config = config + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + + self.encoder = T5Stack(encoder_config, self.shared) + self.device = config.device + self.dtype = config.dtype + + def input_types(self) -> tuple[TensorType, ...]: + """Get input types for the model. + + Returns: + tuple[TensorType, ...]: Input types. + """ + return ( + TensorType( + DType.int64, + shape=["batch_size", "sequence_length"], + device=self.device, + ), + ) + + def __call__( + self, + input_ids: TensorValue | None = None, + attention_mask: TensorValue | None = None, + ) -> TensorValue: + """Process input through the T5 encoder model. + + Args: + input_ids: Input IDs tensor of shape (batch_size, seq_length). + attention_mask: Attention mask tensor of shape (batch_size, seq_length). + + Returns: + TensorValue: Output tensor of shape (batch_size, seq_length, hidden_size). + """ + return self.encoder(input_ids=input_ids, attention_mask=attention_mask) diff --git a/max/python/max/pipelines/lib/config.py b/max/python/max/pipelines/lib/config.py index d3150c90728..098fae555f0 100644 --- a/max/python/max/pipelines/lib/config.py +++ b/max/python/max/pipelines/lib/config.py @@ -28,6 +28,7 @@ from max.driver import DeviceSpec, load_devices from max.engine import InferenceSession from max.graph.quantization import QuantizationEncoding +from max.interfaces import PipelineTask from max.serve.queue.zmq_queue import generate_zmq_ipc_path from pydantic import ( Field, @@ -910,6 +911,11 @@ def _validate_and_resolve_remaining_pipeline_config( # memory estimations. arch.pipeline_model.finalize_pipeline_config(self) + if arch.task == PipelineTask.IMAGE_GENERATION: + # diffusion pipeline does not use KV cache, + # so we can skip profile run. + return + MemoryEstimator.estimate_memory_footprint( self, arch.pipeline_model, @@ -1140,12 +1146,14 @@ def log_basic_config(self) -> None: pipeline_class = get_pipeline_for_task(task, self) # Get reserved memory info from KVCache config - kv_config = self.model._kv_cache - if kv_config._available_cache_memory is None: - raise ValueError( - "KVCache config is not available after config resolution." - ) - memory_str = to_human_readable_bytes(kv_config._available_cache_memory) + memory_str = None + if task != PipelineTask.IMAGE_GENERATION: + kv_config = self.model._kv_cache + if kv_config._available_cache_memory is None: + raise ValueError( + "KVCache config is not available after config resolution." + ) + memory_str = to_human_readable_bytes(kv_config._available_cache_memory) devices_str = ", ".join( f"{d.device_type}[{d.id}]" for d in self.model.device_specs @@ -1434,3 +1442,25 @@ def from_flags( prometheus_metrics_mode=prometheus_metrics_mode, **config_flags, ) + + +class ImageGenerationConfig(PipelineConfig): + model_path: str = Field(default="") + """The repo id of the image generation model.""" + + @staticmethod + def help() -> dict[str, str]: + return { + "model": "The path to the image generation model.", + } + + @classmethod + def from_flags( + cls, flags: dict[str, str], **config_flags: Any + ) -> ImageGenerationConfig: + merged_configs = {**flags, **config_flags} + return cls(**merged_configs) + + def log_basic_config(self) -> None: + logger.info(f"model_info: {self.model}") + logger.info("") \ No newline at end of file diff --git a/max/python/max/pipelines/lib/diffusion_schedulers/__init__.py b/max/python/max/pipelines/lib/diffusion_schedulers/__init__.py new file mode 100644 index 00000000000..df278d92447 --- /dev/null +++ b/max/python/max/pipelines/lib/diffusion_schedulers/__init__.py @@ -0,0 +1,16 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) diff --git a/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py new file mode 100644 index 00000000000..1eb39c90b33 --- /dev/null +++ b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py @@ -0,0 +1,852 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import logging +import math +from dataclasses import dataclass + +import numpy as np +from max.driver import CPU, Accelerator, Device +from max.dtype import DType +from max.engine import InferenceSession +from max.experimental import Tensor, random +from max.graph import DeviceRef, Graph, TensorType + +try: + import scipy.stats + + is_scipy_available = True +except ImportError: + is_scipy_available = False + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class FlowMatchEulerDiscreteSchedulerOutput: + """Output class for the scheduler's `step` function output. + + Args: + prev_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: Tensor + + +class FlowMatchEulerDiscreteScheduler: + """Euler scheduler. + + Native Modular implementation (ported from diffusers). + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + use_dynamic_shifting (`bool`, defaults to False): + Whether to apply timestep shifting on-the-fly based on the image resolution. + base_shift (`float`, defaults to 0.5): + Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent + with desired output. + max_shift (`float`, defaults to 1.15): + Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be + more exaggerated or stylized. + base_image_seq_len (`int`, defaults to 256): + The base image sequence length. + max_image_seq_len (`int`, defaults to 4096): + The maximum image sequence length. + invert_sigmas (`bool`, defaults to False): + Whether to invert the sigmas. + shift_terminal (`float`, defaults to None): + The end value of the shifted timestep schedule. + use_karras_sigmas (`bool`, defaults to False): + Whether to use Karras sigmas for step sizes in the noise schedule during sampling. + use_exponential_sigmas (`bool`, defaults to False): + Whether to use exponential sigmas for step sizes in the noise schedule during sampling. + use_beta_sigmas (`bool`, defaults to False): + Whether to use beta sigmas for step sizes in the noise schedule during sampling. + time_shift_type (`str`, defaults to "exponential"): + The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". + stochastic_sampling (`bool`, defaults to False): + Whether to use stochastic sampling. + """ + + config_name = "scheduler_config.json" + + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: float | None = 0.5, + max_shift: float | None = 1.15, + base_image_seq_len: int | None = 256, + max_image_seq_len: int | None = 4096, + invert_sigmas: bool = False, + shift_terminal: float | None = None, + use_karras_sigmas: bool | None = False, + use_exponential_sigmas: bool | None = False, + use_beta_sigmas: bool | None = False, + time_shift_type: str = "exponential", + stochastic_sampling: bool = False, + device: DeviceRef = DeviceRef.CPU(), + dtype: DType = DType.float32, + **kwargs, + ): + """Initialize the scheduler. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + use_dynamic_shifting (`bool`, defaults to False): + Whether to apply timestep shifting on-the-fly based on the image resolution. + base_shift (`float`, defaults to 0.5): + Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent + with desired output. + max_shift (`float`, defaults to 1.15): + Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be + more exaggerated or stylized. + base_image_seq_len (`int`, defaults to 256): + The base image sequence length. + max_image_seq_len (`int`, defaults to 4096): + The maximum image sequence length. + invert_sigmas (`bool`, defaults to False): + Whether to invert the sigmas. + shift_terminal (`float`, defaults to None): + The end value of the shifted timestep schedule. + use_karras_sigmas (`bool`, defaults to False): + Whether to use Karras sigmas for step sizes in the noise schedule during sampling. + use_exponential_sigmas (`bool`, defaults to False): + Whether to use exponential sigmas for step sizes in the noise schedule during sampling. + use_beta_sigmas (`bool`, defaults to False): + Whether to use beta sigmas for step sizes in the noise schedule during sampling. + time_shift_type (`str`, defaults to "exponential"): + The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". + stochastic_sampling (`bool`, defaults to False): + Whether to use stochastic sampling. + device (`DeviceRef`, defaults to `DeviceRef.CPU()`): + The device to use. + dtype (`DType`, defaults to `DType.float32`): + The dtype to use. + """ + self.num_train_timesteps = num_train_timesteps + self._shift = shift + self.use_dynamic_shifting = use_dynamic_shifting + self.base_shift = base_shift + self.max_shift = max_shift + self.base_image_seq_len = base_image_seq_len + self.max_image_seq_len = max_image_seq_len + self.invert_sigmas = invert_sigmas + self.shift_terminal = shift_terminal + self.use_karras_sigmas = use_karras_sigmas + self.use_exponential_sigmas = use_exponential_sigmas + self.use_beta_sigmas = use_beta_sigmas + self.time_shift_type = time_shift_type + self.stochastic_sampling = stochastic_sampling + self.device = device + self.dtype = dtype + + if self.use_beta_sigmas and not is_scipy_available: + raise ImportError( + "Make sure to install scipy if you want to use beta sigmas." + ) + if ( + sum( + [ + self.use_beta_sigmas, + self.use_exponential_sigmas, + self.use_karras_sigmas, + ] + ) + > 1 + ): + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + if time_shift_type not in {"exponential", "linear"}: + raise ValueError( + "`time_shift_type` must either be 'exponential' or 'linear'." + ) + + timesteps = np.linspace( + 1, num_train_timesteps, num_train_timesteps, dtype=np.float32 + )[::-1].copy() + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self._shift = shift + + self.sigmas = sigmas + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + self.load_model() + + @property + def shift(self) -> float: + """The value used for shifting.""" + return self._shift + + @property + def step_index(self) -> int: + """The index counter for current timestep. It will increase 1 after each scheduler step.""" + return self._step_index + + @property + def begin_index(self) -> int: + """The index for the first timestep. It should be set from pipeline with `set_begin_index` method.""" + return self._begin_index + + def set_begin_index(self, begin_index: int = 0) -> None: + """Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`, defaults to `0`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_shift(self, shift: float) -> None: + """Set the shift value.""" + self._shift = shift + + def scale_noise( + self, + sample: Tensor, + timestep: float | Tensor, + noise: Tensor | None = None, + ) -> Tensor: + """Forward process in flow-matching. + + Args: + sample (`Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + noise (`Tensor`, *optional*): + The noise tensor. + + Returns: + `Tensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device).cast(sample.dtype) + + if sample.device.type == "mps": + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device).cast( + DType.float32 + ) + timestep = timestep.to(sample.device).cast(DType.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) for t in timestep + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma: float) -> float: + """Converts sigma to timestep.""" + return sigma * self.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: Tensor) -> Tensor: + """Apply time shifting to the timesteps. + + Args: + mu (`float`): + The mu parameter for time shifting. + sigma (`float`): + The sigma parameter for time shifting. + t (`Tensor`): + The timesteps to shift. + + Returns: + `Tensor`: + The shifted timesteps. + """ + if self.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + + def stretch_shift_to_terminal(self, t: Tensor) -> Tensor: + r"""Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal`. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | Device | None = None, + sigmas: list[float] | None = None, + mu: float | None = None, + timesteps: list[float] | None = None, + ) -> None: + """Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `Device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. + mu (`float`, *optional*): + Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep + shifting. + timesteps (`List[float]`, *optional*): + Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed + automatically. + """ + if self.use_dynamic_shifting and mu is None: + raise ValueError( + "`mu` must be passed when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is not None and timesteps is not None: + if len(sigmas) != len(timesteps): + raise ValueError( + "`sigmas` and `timesteps` should have the same length" + ) + + if num_inference_steps is not None: + if (sigmas is not None and len(sigmas) != num_inference_steps) or ( + timesteps is not None and len(timesteps) != num_inference_steps + ): + raise ValueError( + "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided" + ) + else: + num_inference_steps = ( + len(sigmas) if sigmas is not None else len(timesteps) + ) + + self.num_inference_steps = num_inference_steps + + # 1. Prepare default sigmas + is_timesteps_provided = timesteps is not None + + if is_timesteps_provided: + timesteps = np.array(timesteps).astype(np.float32) + + if sigmas is None: + if timesteps is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), + self._sigma_to_t(self.sigma_min), + num_inference_steps, + ) + sigmas = timesteps / self.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) + + # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of + # "exponential" or "linear" type is applied + if self.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value + if self.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules + if self.use_karras_sigmas: + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + elif self.use_exponential_sigmas: + sigmas = self._convert_to_exponential( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + elif self.use_beta_sigmas: + sigmas = self._convert_to_beta( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + + if not is_timesteps_provided: + timesteps = sigmas * self.num_train_timesteps + + # 5. Append the terminal sigma value. + # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the + # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi + if self.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.num_train_timesteps + sigmas = np.concatenate([sigmas, np.ones((1,), dtype=sigmas.dtype)]) + else: + sigmas = np.concatenate( + [ + sigmas, + np.zeros((1,), dtype=sigmas.dtype), + ] + ) + + # 6. Convert sigmas and timesteps to tensors and move to specified device + sigmas = ( + Tensor.from_dlpack(sigmas).to(device=device).cast(DType.float32) + ) + + self.timesteps = timesteps + self.sigmas = sigmas + self._step_index = None + self._begin_index = None + + def index_for_timestep( + self, timestep: Tensor, schedule_timesteps: Tensor | None = None + ) -> int: + """Returns the index for a given timestep. + + Args: + timestep (`Tensor`): + The timestep to find the index for. + schedule_timesteps (`Tensor`, *optional*): + The schedule timesteps to search in. If `None`, defaults to `self.timesteps`. + + Returns: + `int`: + The index of the timestep. + """ + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep: Tensor) -> None: + """Initialize the step index based on the given timestep. + + Args: + timestep (`Tensor`): + The current timestep. + """ + if self.begin_index is None: + if isinstance(timestep, Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def _step( + self, + model_output: Tensor, + timestep: float | Tensor, + sample: Tensor, + sigmas: Tensor | None = None, + step_index: Tensor | None = None, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + per_token_timesteps: Tensor | None = None, + return_dict: bool = True, + ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple: + """Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`Tensor`): + The direct output from learned diffusion model. + timestep (`float` or `Tensor`): + The current discrete timestep in the diffusion chain. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + sigmas (`Tensor`, *optional*): + The sigmas tensor. + step_index (`Tensor`, *optional*): + The step index. + s_churn (`float`): + Churn parameter. + s_tmin (`float`): + Min churn timestep. + s_tmax (`float`): + Max churn timestep. + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + per_token_timesteps (`Tensor`, *optional*): + The timesteps for each token in the sample. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.cast(DType.float32) + + if per_token_timesteps is not None: + per_token_sigmas = per_token_timesteps / self.num_train_timesteps + + sigmas = sigmas[:, None, None] + lower_mask = sigmas < per_token_sigmas[None] - 1e-6 + lower_sigmas = lower_mask * sigmas + lower_sigmas, _ = lower_sigmas.max(axis=0) + + current_sigma = per_token_sigmas[..., None] + next_sigma = lower_sigmas[..., None] + dt = current_sigma - next_sigma + else: + sigma = sigmas[step_index] + sigma_next = sigmas[step_index + 1] + + current_sigma = sigma + next_sigma = sigma_next + dt = sigma_next - sigma + + if self.stochastic_sampling: + x0 = sample - current_sigma * model_output + noise = random.normal(sample) + prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise + else: + prev_sample = sample + dt * model_output + + # upon completion increase step index by one + self._step_index += 1 + if per_token_timesteps is None: + # Cast sample back to model compatible dtype + prev_sample = prev_sample.cast(model_output.dtype) + + if not return_dict: + return prev_sample + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def _convert_to_karras( + self, in_sigmas: Tensor, num_inference_steps: int + ) -> Tensor: + """Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364). + + Args: + in_sigmas (`Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `Tensor`: + The converted sigma values following the Karras noise schedule. + """ + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self, "sigma_min"): + sigma_min = self.sigma_min + else: + sigma_min = None + + if hasattr(self, "sigma_max"): + sigma_max = self.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_exponential( + self, in_sigmas: Tensor, num_inference_steps: int + ) -> Tensor: + """Construct an exponential noise schedule. + + Args: + in_sigmas (`Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `Tensor`: + The converted sigma values following an exponential schedule. + """ + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self, "sigma_min"): + sigma_min = self.sigma_min + else: + sigma_min = None + + if hasattr(self, "sigma_max"): + sigma_max = self.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp( + np.linspace( + math.log(sigma_max), math.log(sigma_min), num_inference_steps + ) + ) + return sigmas + + def _convert_to_beta( + self, + in_sigmas: Tensor, + num_inference_steps: int, + alpha: float = 0.6, + beta: float = 0.6, + ) -> Tensor: + """Construct a beta noise schedule as proposed in [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173). + + Args: + in_sigmas (`Tensor`): + The input sigma values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + alpha (`float`, *optional*, defaults to `0.6`): + The alpha parameter for the beta distribution. + beta (`float`, *optional*, defaults to `0.6`): + The beta parameter for the beta distribution. + + Returns: + `Tensor`: + The converted sigma values following a beta distribution schedule. + """ + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self, "sigma_min"): + sigma_min = self.sigma_min + else: + sigma_min = None + + if hasattr(self, "sigma_max"): + sigma_max = self.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def _time_shift_exponential( + self, mu: float, sigma: float, t: Tensor + ) -> Tensor: + """Apply exponential time shifting. + + Args: + mu (`float`): + The mu parameter. + sigma (`float`): + The sigma parameter. + t (`Tensor`): + The timesteps. + + Returns: + `Tensor`: + The shifted timesteps. + """ + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def _time_shift_linear(self, mu: float, sigma: float, t: Tensor) -> Tensor: + """Apply linear time shifting. + + Args: + mu (`float`): + The mu parameter. + sigma (`float`): + The sigma parameter. + t (`Tensor`): + The timesteps. + + Returns: + `Tensor`: + The shifted timesteps. + """ + return mu / (mu + (1 / t - 1) ** sigma) + + def __len__(self) -> int: + """Returns the number of train timesteps.""" + return self.num_train_timesteps + + def step_input_types(self) -> tuple[TensorType, ...]: + """Return the input types for the step function.""" + return ( + TensorType( + self.dtype, + shape=["batch_size", "image_seq_len", "channel"], + device=self.device, + ), + TensorType( + DType.float32, + shape=[], + device=self.device, + ), + TensorType( + self.dtype, + shape=["batch_size", "image_seq_len", "channel"], + device=self.device, + ), + TensorType( + DType.float32, + shape=["num_inference_steps"], + device=self.device, + ), + TensorType( + DType.int64, + shape=[], + device=DeviceRef.CPU(), + ), + ) + + def load_model(self) -> None: + """Load the model.""" + if self.device.is_cpu(): + session = InferenceSession([CPU()]) + else: + session = InferenceSession([Accelerator()]) + + self.set_begin_index(0) + with Graph( + "scheduler_step", input_types=self.step_input_types() + ) as graph: + outputs = self._step( + *graph.inputs, + return_dict=False, + ) + graph.output(outputs) + compiled_graph = graph + self.session = session.load(compiled_graph) + + def step( + self, + model_output: Tensor, + timestep: float | Tensor, + sample: Tensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + per_token_timesteps: Tensor | None = None, + return_dict: bool = True, + ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple: + """Predict the sample from the previous timestep by reversing the SDE. + + Args: + model_output (`Tensor`): + The direct output from learned diffusion model. + timestep (`float` or `Tensor`): + The current discrete timestep in the diffusion chain. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + Churn parameter. + s_tmin (`float`): + Min churn timestep. + s_tmax (`float`): + Max churn timestep. + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + per_token_timesteps (`Tensor`, *optional*): + The timesteps for each token in the sample. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + if self.step_index is None: + self._init_step_index(timestep) + schedule_output = self.session.execute( + model_output, + timestep, + sample, + self.sigmas, + self.step_index, + )[0] + self._step_index += 1 + + if not return_dict: + return (schedule_output,) + return FlowMatchEulerDiscreteSchedulerOutput( + prev_sample=schedule_output + ) diff --git a/max/python/max/pipelines/lib/hf_utils.py b/max/python/max/pipelines/lib/hf_utils.py index 8d8c6a7f80f..0a80d2ca461 100644 --- a/max/python/max/pipelines/lib/hf_utils.py +++ b/max/python/max/pipelines/lib/hf_utils.py @@ -304,7 +304,7 @@ def _repo_exists_with_retry(repo_id: str, revision: str) -> bool: ) time.sleep(delay_in_seconds) - assert False, ( # noqa: B011 + raise AssertionError( "This should never be reached due to the raise in the last attempt" ) @@ -372,20 +372,18 @@ def info(self) -> huggingface_hub.ModelInfo: @cached_property def weight_files(self) -> dict[WeightsFormat, list[str]]: - safetensor_search_pattern = "*.safetensors" - gguf_search_pattern = "*.gguf" - pytorch_search_pattern = "*.bin" + safetensor_search_pattern = "**/*.safetensors" + gguf_search_pattern = "**/*.gguf" weight_files = {} if self.repo_type == RepoType.local: safetensor_paths = glob.glob( - os.path.join(self.repo_id, safetensor_search_pattern) + os.path.join(self.repo_id, safetensor_search_pattern), + recursive=True, ) gguf_paths = glob.glob( - os.path.join(self.repo_id, gguf_search_pattern) - ) - pytorch_paths = glob.glob( - os.path.join(self.repo_id, pytorch_search_pattern) + os.path.join(self.repo_id, gguf_search_pattern), + recursive=True, ) elif self.repo_type == RepoType.online: fs = huggingface_hub.HfFileSystem() @@ -396,9 +394,6 @@ def weight_files(self) -> dict[WeightsFormat, list[str]]: gguf_paths = cast( list[str], fs.glob(f"{self.repo_id}/{gguf_search_pattern}") ) - pytorch_paths = cast( - list[str], fs.glob(f"{self.repo_id}/{pytorch_search_pattern}") - ) else: raise ValueError(f"Unsupported repo type: {self.repo_type}") @@ -626,3 +621,33 @@ def generate_local_model_path(repo_id: str, revision: str) -> str: if not path.is_dir(): raise FileNotFoundError(f"Model path does not exist: {path}") return str(path) + + +def get_model_index_path_for_diffusers( + huggingface_repo: HuggingFaceRepo, +) -> str | None: + model_index_path: str | None = None + + if huggingface_repo.repo_type == RepoType.local: + local_index = Path(huggingface_repo.repo_id) / "model_index.json" + if local_index.exists(): + model_index_path = str(local_index) + else: + raise ValueError( + f"Failed to find model_index.json in {huggingface_repo.repo_id}." + ) + else: + try: + if huggingface_hub.file_exists( + huggingface_repo.repo_id, + "model_index.json", + revision=huggingface_repo.revision, + ): + model_index_path = huggingface_hub.hf_hub_download( + huggingface_repo.repo_id, + "model_index.json", + revision=huggingface_repo.revision, + ) + except Exception: + model_index_path = None + return model_index_path diff --git a/max/python/max/pipelines/lib/image_processor.py b/max/python/max/pipelines/lib/image_processor.py new file mode 100644 index 00000000000..c6897c095bd --- /dev/null +++ b/max/python/max/pipelines/lib/image_processor.py @@ -0,0 +1,226 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import logging + +import numpy as np +import PIL.Image +from max.driver import Tensor as DTensor +from max.dtype import DType +from max.experimental import Tensor +from max.experimental import functional as F +from max.experimental.compile_utils import max_compile +from max.graph import DeviceRef, TensorType, TensorValue, ops +from PIL import Image + +logger = logging.getLogger(__name__) + + +PipelineImageInput = ( + PIL.Image.Image + | np.ndarray + | Tensor + | list[PIL.Image.Image] + | list[np.ndarray] + | list[Tensor] +) + + +class VaeImageProcessor: + config_name = "config.json" + + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + vae_latent_channels: int = 4, + resample: str = "lanczos", + reducing_gap: int | None = None, + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_rgb: bool = False, + do_convert_grayscale: bool = False, + device: DeviceRef = DeviceRef.GPU(), + dtype: DType = DType.bfloat16, + ): + """Initialize the VaeImageProcessor. + + Args: + do_resize (bool, optional): Whether to resize images. Defaults to True. + vae_scale_factor (int, optional): The VAE scale factor. Defaults to 8. + vae_latent_channels (int, optional): The number of latent channels for the VAE. Defaults to 4. + resample (str, optional): The resampling mode for resizing. Defaults to "lanczos". + reducing_gap (int, optional): A reduction gap parameter for resampling. Defaults to None. + do_normalize (bool, optional): Whether to normalize images to [-1, 1]. Defaults to True. + do_binarize (bool, optional): Whether to binarize images. Defaults to False. + do_convert_rgb (bool, optional): Whether to convert images to RGB. Defaults to False. + do_convert_grayscale (bool, optional): Whether to convert images to grayscale. Defaults to False. + device (DeviceRef, optional): The device to use for the image processor. Defaults to DeviceRef.GPU(). + dtype (DType, optional): The data type to use for the image processor. Defaults to DType.bfloat16. + + Raises: + ValueError: If both do_convert_rgb and do_convert_grayscale are set to True. + """ + super().__init__() + if do_convert_rgb and do_convert_grayscale: + raise ValueError( + "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`," + " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.", + " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", + ) + + self.do_normalize = do_normalize + self.device = device + self.dtype = dtype + self._denormalize_conditionally = max_compile( + self._denormalize_conditionally, + input_types=self._denormalize_conditionally_input_types(), + ) + + @staticmethod + def denormalize(images: np.ndarray | Tensor) -> np.ndarray | Tensor: + r"""Denormalize an image array to [0,1]. + + Args: + images (`np.ndarray` or `Tensor`): + The image array to denormalize. + + Returns: + `np.ndarray` or `Tensor`: + The denormalized image array. + """ + if isinstance(images, (Tensor, TensorValue)): + images = images * 0.5 + 0.5 + images = F.min( + images, + Tensor.constant(1.0, dtype=images.dtype, device=images.device), + ) + images = F.max( + images, + Tensor.constant(0.0, dtype=images.dtype, device=images.device), + ) + return images + return np.clip(images * 0.5 + 0.5, 0, 1) + + def _denormalize_conditionally( + self, + images: np.ndarray | Tensor, + ) -> np.ndarray: + r"""Denormalize a batch of images based on a condition list. + + Args: + images (`np.ndarray` or `Tensor`): + The input image tensor. + """ + images = self.denormalize(images) if self.do_normalize else images + images = ops.cast(images, DType.float32) + return images + + @staticmethod + def max_to_numpy(images: Tensor) -> np.ndarray: + r"""Convert a Max tensor to a NumPy image. + + Args: + images (`Tensor`): + The Max tensor to convert to NumPy format. + + Returns: + `np.ndarray`: + A NumPy array representation of the images. + """ + images = DTensor.to_numpy(images) + images = np.transpose(images, (0, 2, 3, 1)) + return images + + @staticmethod + def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]: + r"""Convert a numpy image or a batch of images to a PIL image. + + Args: + images (`np.ndarray`): + The image array to convert to PIL format. + + Returns: + `list[PIL.Image.Image]`: + A list of PIL images. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [ + Image.fromarray(image.squeeze(), mode="L") for image in images + ] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + def _denormalize_conditionally_input_types(self) -> list[TensorType]: + return [ + TensorType( + shape=("batch_size", "num_channels", "height", "width"), + device=self.device, + dtype=self.dtype, + ), + ] + + def postprocess( + self, + image: Tensor, + output_type: str = "pil", + do_denormalize: list[bool] | None = None, + ) -> PIL.Image.Image | np.ndarray | Tensor: + """Postprocess the image output from tensor to `output_type`. + + Args: + image (`Tensor`): + The image input, should be a Max tensor with shape `B x C x H x W`. + output_type (`str`, *optional*, defaults to `pil`): + The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. + do_denormalize (`list[bool]`, *optional*, defaults to `None`): + Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the + `VaeImageProcessor` config. + + Returns: + `PIL.Image.Image`, `np.ndarray` or `Tensor`: + The postprocessed image. + """ + if not isinstance(image, Tensor) and not isinstance(image, TensorValue): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. We only support Max tensor" + ) + if output_type not in ["latent", "max", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `max`, `latent`" + ) + logger.warning(deprecation_message) + output_type = "np" + + if output_type == "latent": + return image + + image = self._denormalize_conditionally(image) + + if output_type == "max": + return image[0] + + image = self.max_to_numpy(image[0]) + + if output_type == "np": + return image + + if output_type == "pil": + return self.numpy_to_pil(image) diff --git a/max/python/max/pipelines/lib/interfaces/__init__.py b/max/python/max/pipelines/lib/interfaces/__init__.py index db7ab9885c5..a6592f356d3 100644 --- a/max/python/max/pipelines/lib/interfaces/__init__.py +++ b/max/python/max/pipelines/lib/interfaces/__init__.py @@ -12,6 +12,7 @@ # ===----------------------------------------------------------------------=== # """Interfaces for MAX pipelines.""" +from .diffusion_pipeline import DiffusionPipeline from .generate import GenerateMixin from .kv_cache import KVCacheMixin, get_paged_manager from .pipeline_model import ( @@ -23,6 +24,7 @@ __all__ = [ "AlwaysSignalBuffersMixin", + "DiffusionPipeline", "GenerateMixin", "KVCacheMixin", "ModelInputs", diff --git a/max/python/max/pipelines/lib/interfaces/configuration_utils.py b/max/python/max/pipelines/lib/interfaces/configuration_utils.py new file mode 100644 index 00000000000..5c0c5d0e439 --- /dev/null +++ b/max/python/max/pipelines/lib/interfaces/configuration_utils.py @@ -0,0 +1,244 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import functools +import inspect +import json +import os +from collections import OrderedDict +from collections.abc import Callable +from typing import Any, NoReturn + + +class ConfigDict(OrderedDict): + def __init__(self, *args, **kwargs): + """Initialize ConfigDict.""" + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) + + def setdefault(self, *args, **kwargs) -> NoReturn: + """Set default value.""" + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) + + def pop(self, *args, **kwargs) -> NoReturn: + """Pop item.""" + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance." + ) + + def update(self, *args, **kwargs) -> NoReturn: + """Update dictionary.""" + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." + ) + + def __setattr__(self, name: str, value: Any): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception( + f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." + ) + super().__setattr__(name, value) + + def __setitem__(self, name: str, value: Any): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception( + f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance." + ) + super().__setitem__(name, value) + + +class ConfigMixin: + config_name = None + + @classmethod + def load_config( + cls, + pretrained_model_name_or_path: str, + **kwargs, + ) -> dict: + """Load configuration from a pretrained model directory. + + Args: + pretrained_model_name_or_path: Path to pretrained model or model identifier. + **kwargs: Additional arguments. + + Returns: + Dictionary containing the configuration. + """ + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + subfolder = kwargs.pop("subfolder", None) + + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if subfolder is not None and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, cls.config_name + ) + ): + config_file = os.path.join( + pretrained_model_name_or_path, subfolder, cls.config_name + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, cls.config_name) + ): + # Load from a pretrained checkpoint + config_file = os.path.join( + pretrained_model_name_or_path, cls.config_name + ) + else: + raise OSError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + raise ValueError( + f"The provided pretrained_model_name_or_path '{pretrained_model_name_or_path}'" + " is neither a valid local path nor downloaded properly from Hugging Face Hub." + ) + + config_dict = cls._dict_from_json_file(config_file) + return config_dict + + @property + def config(self) -> ConfigDict: + """Returns the config of the class as a dictionary. + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + @classmethod + def _dict_from_json_file(cls, json_file: str | os.PathLike) -> dict: + with open(json_file, encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + @staticmethod + def _get_init_keys(input_class: Any) -> set: + if hasattr(input_class, "components"): + return set(input_class.components.keys()) + return set( + dict(inspect.signature(input_class.__init__).parameters).keys() + ) + + @classmethod + def extract_init_dict(cls, config_dict: dict) -> dict: + """Extract init dictionary from config dictionary. + + Args: + config_dict: Configuration dictionary. + + Returns: + Dictionary containing the init parameters. + """ + expected_keys = cls._get_init_keys(cls) + + init_dict = { + k: config_dict[k] for k in config_dict if k in expected_keys + } + return init_dict + + def register_to_config(self, **kwargs) -> None: + """Register arguments to the config. + + Args: + **kwargs: Arguments to register. + """ + if self.config_name is None: + raise NotImplementedError( + f"Make sure that {self.__class__} has defined a class name `config_name`" + ) + # Special case for `kwargs` used in deprecation warning added to schedulers + # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, + # or solve in a more general way. + kwargs.pop("kwargs", None) + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + internal_dict = {**self._internal_dict, **kwargs} + + self._internal_dict = ConfigDict(internal_dict) + + +def register_to_config(init: Callable) -> Callable: + """Register arguments to the config. + + Args: + init: Initialization function of a class. + + Returns: + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable. + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self: Any, *args, **kwargs) -> None: + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = { + k: v for k, v in kwargs.items() if k.startswith("_") + } + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default + for i, (name, p) in enumerate(signature.parameters.items()) + if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys(), strict=False): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = list( + set(new_kwargs.keys()) - set(init_kwargs) + ) + + new_kwargs = {**config_init_kwargs, **new_kwargs} + self.register_to_config(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init diff --git a/max/python/max/pipelines/lib/interfaces/diffusion_pipeline.py b/max/python/max/pipelines/lib/interfaces/diffusion_pipeline.py new file mode 100644 index 00000000000..64835ee3c71 --- /dev/null +++ b/max/python/max/pipelines/lib/interfaces/diffusion_pipeline.py @@ -0,0 +1,177 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Pipeline utilities for MAX-optimized pipelines.""" + +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from max.config import load_config +from max.driver import load_devices +from max.graph import DeviceRef +from max.graph.weights import load_weights +from max.pipelines.lib.interfaces.max_model import MaxModel +from tqdm import tqdm + +if TYPE_CHECKING: + from ..config import PipelineConfig + from ..diffusion_schedulers import FlowMatchEulerDiscreteScheduler + + +class DiffusionPipeline(ABC): + config_name: str | None = None + """ + The name of the config file of the pipeline. + + It can be found in the downloaded path or HuggingFace hub. + It's usually "model_index.json" or "config.json" for Diffusion models. + """ + + components: ( + dict[str, type[MaxModel] | type[FlowMatchEulerDiscreteScheduler]] | None + ) = None + """The components of the pipeline. + + It can be found in the downloaded path or HuggingFace hub. + It's usually contains text_encoder, tokenizer, transformer, vae, etc. + """ + + def __init__( + self, + pipeline_config: PipelineConfig, + cached_folder: str, + **kwargs: Any, + ) -> DiffusionPipeline: + """Load a pipeline from a pretrained model. + + Args: + pipeline_config: Pipeline configuration for model and runtime setup. + cached_folder: Local path to the downloaded model snapshot. + **kwargs: Additional pipeline-specific arguments. + """ + self.pipeline_config = pipeline_config + self.devices = load_devices(pipeline_config.model.device_specs) + + # Load sub models + loaded_sub_models = self.load_sub_models(cached_folder) + for name, model in loaded_sub_models.items(): + setattr(self, name, model) + + self.init_remaining_components() + + @abstractmethod + def init_remaining_components(self) -> None: + pass + + def load_sub_models( + self, + pretrained_model_name_or_path: str | os.PathLike, + ) -> dict: + """Load sub-models for the pipeline. + + Args: + pretrained_model_name_or_path: Path to pretrained model. + + Returns: + Dictionary containing the loaded sub-models. + """ + loaded_sub_models = {} + if self.components is None: + raise ValueError( + f"`components` for {self.__class__.__name__} pipeline is not set. " + "Please set proper components based on its sub-directories in the downloaded path." + ) + for name, component_class in tqdm( + self.components.items(), desc="Loading sub models" + ): + component_path = os.path.join(pretrained_model_name_or_path, name) + if "tokenizer" in name: + # NOTE: Currently, we are using tokenizers from transformers. + # TODO(minkyu): Check if we can use Tokenizer in Max, + # and remove this conditional path. + loaded_sub_models[name] = component_class.from_pretrained( + component_path + ) + continue + + if ( + not hasattr(component_class, "config_name") + or component_class.config_name is None + ): + raise ValueError( + f"`config_name` for {component_class.__name__} is not set. " + "Please set proper config file name in the downloaded path." + ) + config = load_config( + f"{component_path}/{component_class.config_name}" + ) + if issubclass(component_class, MaxModel): + weight_paths = [ + Path(pretrained_model_name_or_path) / weight_path + for weight_path in self.pipeline_config.model.weight_path + if weight_path.split("/")[0] == name + ] + loaded_sub_models[name] = component_class( + config=config, + encoding=self.pipeline_config.model.quantization_encoding, + devices=self.devices, + weights=load_weights(weight_paths), + ) + else: + loaded_sub_models[name] = component_class( + **config, + device=DeviceRef.from_device(self.devices[0]), + dtype=self.pipeline_config.model.quantization_encoding.dtype, + ) + + return loaded_sub_models + + def finalize_pipeline_config(self) -> None: + return + + def _execution_device(self) -> DeviceRef: + r"""Returns the device on which the pipeline's models will be executed. + + This property checks pipeline components to determine the execution device. + It supports MAX models (with DeviceRef device attribute). + Similar structure to diffusers' _execution_device but returns DeviceRef instead of DeviceRef. + + Returns: + DeviceRef: The execution device (GPU if available, otherwise CPU). + """ + # Check MAX models - prioritize GPU + # Similar to diffusers' _execution_device but for MAX models (not torch.nn.Module) + sub_models = {k: getattr(self, k) for k in self.components} + for name, model in sub_models.items(): + exclude_from_cpu_offload = getattr( + self, "_exclude_from_cpu_offload", set() + ) + if name in exclude_from_cpu_offload: + continue + + if hasattr(model, "device") and isinstance(model.device, DeviceRef): + return model.device + + if hasattr(self, "device"): + try: + device = self.device + if isinstance(device, DeviceRef): + return device + except Exception: + pass + + return DeviceRef.CPU() diff --git a/max/python/max/pipelines/lib/interfaces/max_model.py b/max/python/max/pipelines/lib/interfaces/max_model.py new file mode 100644 index 00000000000..3f323d81ad8 --- /dev/null +++ b/max/python/max/pipelines/lib/interfaces/max_model.py @@ -0,0 +1,45 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from max.driver import Device +from max.engine import Model +from max.graph.weights import Weights + +if TYPE_CHECKING: + from max.pipelines.lib import SupportedEncoding + + +class MaxModel(ABC): + """Base interface for pipeline models with weight-backed execution.""" + + def __init__( + self, + config: dict, + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + ) -> None: + self.config = config + self.encoding = encoding + self.devices = devices + self.weights = weights + + @abstractmethod + def load_model(self) -> Model: + """Load and return a runtime model instance.""" + ... diff --git a/max/python/max/pipelines/lib/model_config.py b/max/python/max/pipelines/lib/model_config.py index 0b37899b012..71e2691f33b 100644 --- a/max/python/max/pipelines/lib/model_config.py +++ b/max/python/max/pipelines/lib/model_config.py @@ -509,7 +509,6 @@ def validate_and_resolve_quantization_encoding_weight_path( are consistent. Args: - weight_path: The path to the weight file. default_encoding: The default encoding to use if no encoding is provided. """ @@ -582,6 +581,7 @@ def _validate_and_resolve_dtype_casting( Note: We currently only support float32 to bfloat16 weight type casting. Args: + from_encoding: The source encoding to cast from. to_encoding: The desired encoding to cast to. Raises: @@ -806,6 +806,9 @@ def _resolve_weight_path( encoding=self._applied_dtype_cast_from ) + if not weight_files: + weight_files = self.huggingface_weight_repo.weight_files + if default_weight_files := weight_files.get( default_weights_format, [] ): @@ -945,6 +948,7 @@ def _local_weight_path(self, relative_path: Path) -> str | None: # NOTE(bduke): do this even for online repositories, because upstream # code originating from `huggingface_hub.hf_hub_download` returns # absolute paths for cached files. + relative_path = Path(relative_path) if relative_path.exists() and relative_path.is_file(): return str(relative_path.resolve()) diff --git a/max/python/max/pipelines/lib/pipeline_variants/__init__.py b/max/python/max/pipelines/lib/pipeline_variants/__init__.py index 991ed73eea4..2c04acdbcdd 100644 --- a/max/python/max/pipelines/lib/pipeline_variants/__init__.py +++ b/max/python/max/pipelines/lib/pipeline_variants/__init__.py @@ -11,4 +11,5 @@ # limitations under the License. # ===----------------------------------------------------------------------=== # +from .image_generation import ImageGenerationPipeline from .text_generation import TextGenerationPipeline diff --git a/max/python/max/pipelines/lib/pipeline_variants/image_generation.py b/max/python/max/pipelines/lib/pipeline_variants/image_generation.py new file mode 100644 index 00000000000..5841263e36f --- /dev/null +++ b/max/python/max/pipelines/lib/pipeline_variants/image_generation.py @@ -0,0 +1,219 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +import fnmatch +import logging +import os +import re +from pathlib import Path +from typing import TYPE_CHECKING + +import huggingface_hub +import requests +from huggingface_hub.utils import EntryNotFoundError, OfflineModeIsEnabled +from max.config import load_config +from max.interfaces import ( + ImageGenerationInputs, + ImageGenerationOutput, + Pipeline, + RequestID, +) +from requests.exceptions import HTTPError + +from ..config_enums import RepoType +from ..interfaces import DiffusionPipeline + +if TYPE_CHECKING: + from ..config import PipelineConfig + +logger = logging.getLogger(__name__) + + +class ImageGenerationPipeline( + Pipeline[ImageGenerationInputs, ImageGenerationOutput], +): + """Pipeline wrapper for diffusion image generation.""" + + def __init__( + self, + pipeline_config: PipelineConfig, + diffusion_pipeline: type[DiffusionPipeline], + ) -> None: + # Download checkpoints if required + # NOTE: Unlike TextGenerationPipeline where each file, + # such as configs and weights, are downloaded individually, + # DiffusionPipeline downloads the entire snapshot at once, + # since it normally contains multiple components. + pretrained_model_name_or_path = ( + pipeline_config.model.huggingface_model_repo.repo_id + ) + if ( + pipeline_config.model.huggingface_model_repo.repo_type + == RepoType.online + ): + cached_folder = self.download( + pretrained_model_name_or_path, + config_name=diffusion_pipeline.config_name, + force_download=pipeline_config.model.force_download, + revision=pipeline_config.model.huggingface_model_revision, + ) + else: + cached_folder = pretrained_model_name_or_path + + self._diffusion_pipeline = diffusion_pipeline( + pipeline_config, cached_folder + ) + + def download( + self, + pretrained_model_name: str | os.PathLike, + config_name: str | None, + force_download: bool = False, + revision: str | None = None, + ) -> str: + """Download the pipeline components from the Hugging Face Hub. + + Args: + pretrained_model_name: Model identifier. + config_name: Pipeline config filename in the repo. + force_download: Whether to force download. + revision: Model revision. + + Returns: + Path to the downloaded model folder. + """ + try: + info = huggingface_hub.model_info( + pretrained_model_name, revision=revision + ) + except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e: + logger.warning( + f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache." + ) + model_info_call_error = ( + e # save error to reraise it if model is not cached locally + ) + + if config_name is None: + raise ValueError( + f"config_name for {pretrained_model_name} pipeline is not set. " + "Please set proper config file name from huggingface hub." + ) + try: + config_file = huggingface_hub.hf_hub_download( + pretrained_model_name, + config_name, + revision=revision, + force_download=force_download, + ) + except EntryNotFoundError as e: + raise ValueError( + f"config file {config_name} not found for {pretrained_model_name} pipeline. " + "Please check if the config file name is correct." + ) from e + + config_dict = load_config(config_file) + ignore_filenames = config_dict.pop("_ignore_files", []) + + filenames = {sibling.rfilename for sibling in info.siblings} + filenames = set(filenames) - set(ignore_filenames) + + ignore_patterns = [ + "*.bin", + "*.msgpack", + "*.onnx", + "*.pb", + "*.bin.index.*json", + "*.msgpack.index.*json", + "*.onnx.index.*json", + "*.pb.index.*json", + ] + + allow_patterns = ["*/*"] + allow_patterns += [ + "scheduler_config.json", + "config.json", + config_name, + ] + re_ignore_pattern = [ + re.compile(fnmatch.translate(p)) for p in ignore_patterns + ] + re_allow_pattern = [ + re.compile(fnmatch.translate(p)) for p in allow_patterns + ] + + expected_files = [ + f + for f in filenames + if not any(p.match(f) for p in re_ignore_pattern) + ] + expected_files = [ + f + for f in expected_files + if any(p.match(f) for p in re_allow_pattern) + ] + + snapshot_folder = Path(config_file).parent + pipeline_is_cached = all( + (snapshot_folder / f).is_file() for f in expected_files + ) + + if pipeline_is_cached and not force_download: + # if the pipeline is cached, we can directly return it + # else call snapshot_download + return snapshot_folder + + # download all allow_patterns - ignore_patterns + try: + cached_folder = huggingface_hub.snapshot_download( + pretrained_model_name, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + return cached_folder + + except FileNotFoundError: + # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache. + # This can happen in two cases: + # 1. If the user passed `local_files_only=True` => we raise the error directly + # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error + if model_info_call_error is None: + # 1. user passed `local_files_only=True` + raise + else: + # 2. we forced `local_files_only=True` when `model_info` failed + raise OSError( + f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred" + " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace" + " above." + ) from model_info_call_error + + def execute(self, inputs: ImageGenerationInputs) -> ImageGenerationOutput: + outputs = self._diffusion_pipeline( + prompt=inputs.prompt, + negative_prompt=inputs.negative_prompt, + true_cfg_scale=inputs.true_cfg_scale, + height=inputs.height, + width=inputs.width, + num_inference_steps=inputs.num_inference_steps, + guidance_scale=inputs.guidance_scale, + num_images_per_prompt=inputs.num_images_per_prompt, + ) + return ImageGenerationOutput(images=outputs.images) + + def release(self, request_id: RequestID) -> None: + pass diff --git a/max/python/max/pipelines/lib/registry.py b/max/python/max/pipelines/lib/registry.py index 566bcb6d5a0..3445e372971 100644 --- a/max/python/max/pipelines/lib/registry.py +++ b/max/python/max/pipelines/lib/registry.py @@ -16,6 +16,7 @@ from __future__ import annotations import functools +import json import logging from collections.abc import Callable from dataclasses import dataclass, field @@ -38,6 +39,7 @@ from transformers import ( AutoConfig, AutoTokenizer, + PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, ) @@ -49,8 +51,9 @@ from .audio_generator_pipeline import AudioGeneratorPipeline from .config_enums import RopeType, SupportedEncoding from .embeddings_pipeline import EmbeddingsPipeline -from .hf_utils import HuggingFaceRepo +from .hf_utils import HuggingFaceRepo, get_model_index_path_for_diffusers from .interfaces import PipelineModel +from .pipeline_variants.image_generation import ImageGenerationPipeline from .pipeline_variants.text_generation import TextGenerationPipeline from .speculative_decoding import ( EAGLESpeculativeDecodingPipeline, @@ -74,6 +77,7 @@ def get_pipeline_for_task( | type[StandaloneSpeculativeDecodingPipeline] | type[SpeechTokenGenerationPipeline] | type[EAGLESpeculativeDecodingPipeline] + | type[ImageGenerationPipeline] ): if task == PipelineTask.TEXT_GENERATION: if pipeline_config._speculative is not None: @@ -100,6 +104,8 @@ def get_pipeline_for_task( return AudioGeneratorPipeline elif task == PipelineTask.SPEECH_TOKEN_GENERATION: return SpeechTokenGenerationPipeline + elif task == PipelineTask.IMAGE_GENERATION: + return ImageGenerationPipeline @dataclass(frozen=False) @@ -290,7 +296,7 @@ def retrieve_architecture( def get_active_huggingface_config( self, huggingface_repo: HuggingFaceRepo - ) -> AutoConfig: + ) -> AutoConfig | PretrainedConfig: """Retrieves or creates a cached HuggingFace AutoConfig for the given model configuration. @@ -311,7 +317,22 @@ def get_active_huggingface_config( Returns: AutoConfig: The HuggingFace configuration object for the model. """ - if huggingface_repo not in self._cached_huggingface_configs: + model_index_path = get_model_index_path_for_diffusers(huggingface_repo) + + if model_index_path is not None: + with open(model_index_path, encoding="utf-8") as f: + model_index = json.load(f) + + class_name = model_index.get("_class_name") + if not class_name or not isinstance(class_name, str): + raise ValueError( + f"Diffusers-style repository '{huggingface_repo.repo_id}' is missing a valid '_class_name' in model_index.json" + ) + + self._cached_huggingface_configs[huggingface_repo] = ( + PretrainedConfig(architectures=[class_name]) + ) + else: self._cached_huggingface_configs[huggingface_repo] = ( AutoConfig.from_pretrained( huggingface_repo.repo_id, @@ -444,87 +465,104 @@ def retrieve_factory( assert arch is not None devices = load_devices(pipeline_config.model.device_specs) - max_length = arch.pipeline_model.calculate_max_seq_len( - pipeline_config, huggingface_config=huggingface_config - ) - - # Old Mistral model like Mistral-7B-Instruct-v0.3 uses LlamaTokenizer - # and suffers from the whitespace decoding bug. So, we enable the fix - # for only MistralModel in order to avoid any issues with performance - # for rest of the models. This can be applied more generically once - # we have more time verifying this for all the models. - # More information: - # https://linear.app/modularml/issue/AIPIPE-197/add-support-for-mistral-7b-instruct-v03 - # TODO: remove this pipeline_model.__name__ check - if ( - arch.pipeline_model.__name__ in ("MistralModel", "Phi3Model") - and arch.tokenizer is TextTokenizer - ): - text_tokenizer = cast(type[TextTokenizer], arch.tokenizer) - tokenizer = text_tokenizer( - pipeline_config.model.model_path, - pipeline_config=pipeline_config, - revision=pipeline_config.model.huggingface_model_revision, - max_length=max_length, - trust_remote_code=pipeline_config.model.trust_remote_code, - enable_llama_whitespace_fix=True, - chat_template=pipeline_config.retrieve_chat_template(), - context_validators=arch.context_validators, + if task != PipelineTask.IMAGE_GENERATION: + max_length = arch.pipeline_model.calculate_max_seq_len( + pipeline_config, huggingface_config=huggingface_config ) - else: - tokenizer = arch.tokenizer( - model_path=pipeline_config.model.model_path, - pipeline_config=pipeline_config, - revision=pipeline_config.model.huggingface_model_revision, - max_length=max_length, - trust_remote_code=pipeline_config.model.trust_remote_code, - chat_template=pipeline_config.retrieve_chat_template(), - context_validators=arch.context_validators, + + # Old Mistral model like Mistral-7B-Instruct-v0.3 uses LlamaTokenizer + # and suffers from the whitespace decoding bug. So, we enable the fix + # for only MistralModel in order to avoid any issues with performance + # for rest of the models. This can be applied more generically once + # we have more time verifying this for all the models. + # More information: + # https://linear.app/modularml/issue/AIPIPE-197/add-support-for-mistral-7b-instruct-v03 + # TODO: remove this pipeline_model.__name__ check + if ( + arch.pipeline_model.__name__ in ("MistralModel", "Phi3Model") + and arch.tokenizer is TextTokenizer + ): + text_tokenizer = cast(type[TextTokenizer], arch.tokenizer) + tokenizer = text_tokenizer( + pipeline_config.model.model_path, + pipeline_config=pipeline_config, + revision=pipeline_config.model.huggingface_model_revision, + max_length=max_length, + trust_remote_code=pipeline_config.model.trust_remote_code, + enable_llama_whitespace_fix=True, + chat_template=pipeline_config.retrieve_chat_template(), + context_validators=arch.context_validators, + ) + else: + tokenizer = arch.tokenizer( + model_path=pipeline_config.model.model_path, + pipeline_config=pipeline_config, + revision=pipeline_config.model.huggingface_model_revision, + max_length=max_length, + trust_remote_code=pipeline_config.model.trust_remote_code, + chat_template=pipeline_config.retrieve_chat_template(), + context_validators=arch.context_validators, + ) + # Cast tokenizer to the proper type for text generation pipeline compatibility + typed_tokenizer = cast( + PipelineTokenizer[ + Any, npt.NDArray[np.integer[Any]], TextGenerationRequest + ], + tokenizer, ) - # Cast tokenizer to the proper type for text generation pipeline compatibility - typed_tokenizer = cast( - PipelineTokenizer[ - Any, npt.NDArray[np.integer[Any]], TextGenerationRequest - ], - tokenizer, - ) - # For speculative decoding, retrieve draft model's architecture - factory_kwargs: dict[str, Any] = { - "pipeline_config": pipeline_config, - "pipeline_model": arch.pipeline_model, - "eos_token_id": tokenizer.eos, - "weight_adapters": arch.weight_adapters, - "tokenizer": typed_tokenizer, - } - - # If using speculative decoding, add draft model-specific parameters - if pipeline_config.draft_model_config is not None: - draft_arch = self.retrieve_architecture( - huggingface_repo=pipeline_config.draft_model_config.huggingface_weight_repo, - use_module_v3=pipeline_config.use_module_v3, + # For speculative decoding, retrieve draft model's architecture + factory_kwargs: dict[str, Any] = { + "pipeline_config": pipeline_config, + "pipeline_model": arch.pipeline_model, + "eos_token_id": tokenizer.eos, + "weight_adapters": arch.weight_adapters, + "tokenizer": typed_tokenizer, + } + + # If using speculative decoding, add draft model-specific parameters + if pipeline_config.draft_model_config is not None: + draft_arch = self.retrieve_architecture( + huggingface_repo=pipeline_config.draft_model_config.huggingface_weight_repo, + use_module_v3=pipeline_config.use_module_v3, + ) + if draft_arch is None: + raise ValueError( + f"MAX-Optimized architecture not found for draft model " + f"'{pipeline_config.draft_model_config.model_path}'" + ) + factory_kwargs["draft_pipeline_model"] = ( + draft_arch.pipeline_model + ) + factory_kwargs["draft_weight_adapters"] = ( + draft_arch.weight_adapters + ) + + pipeline_factory = cast( + Callable[[], PipelineTypes], + functools.partial( # type: ignore + pipeline_class, **factory_kwargs + ), ) - if draft_arch is None: + + if tokenizer.eos is None: raise ValueError( - f"MAX-Optimized architecture not found for draft model " - f"'{pipeline_config.draft_model_config.model_path}'" + "tokenizer.eos value is None, tokenizer configuration is incomplete." ) - factory_kwargs["draft_pipeline_model"] = draft_arch.pipeline_model - factory_kwargs["draft_weight_adapters"] = draft_arch.weight_adapters - - pipeline_factory = cast( - Callable[[], PipelineTypes], - functools.partial( # type: ignore - pipeline_class, **factory_kwargs - ), - ) - if tokenizer.eos is None: - raise ValueError( - "tokenizer.eos value is None, tokenizer configuration is incomplete." + return tokenizer, pipeline_factory + else: + factory_kwargs = { + "pipeline_config": pipeline_config, + "diffusion_pipeline": arch.pipeline_model, + } + pipeline_factory = cast( + Callable[[], PipelineTypes], + functools.partial( # type: ignore + pipeline_class, **factory_kwargs + ), ) - - return tokenizer, pipeline_factory + return None, pipeline_factory def retrieve_context_type( self, pipeline_config: PipelineConfig diff --git a/max/python/max/serve/api_server.py b/max/python/max/serve/api_server.py index 340e5961a25..c175ecc5edb 100644 --- a/max/python/max/serve/api_server.py +++ b/max/python/max/serve/api_server.py @@ -29,6 +29,7 @@ from max.interfaces import PipelinesFactory, PipelineTask, PipelineTokenizer from max.pipelines.lib import PIPELINE_REGISTRY, PipelineConfig from max.serve.config import APIType, MetricRecordingMethod, Settings +from max.serve.pipelines.diffusion import ImageGeneratorPipeline from max.serve.pipelines.llm import ( AudioGeneratorPipeline, TokenGeneratorPipeline, @@ -106,35 +107,44 @@ async def lifespan( ) METRICS.configure(client=metric_client) - # start model worker - scheduler_zmq_configs = SchedulerZmqConfigs( - serving_settings.pipeline_task, - context_type=PIPELINE_REGISTRY.retrieve_context_type( - serving_settings.pipeline_config - ), - ) - worker_monitor = await exit_stack.enter_async_context( - start_model_worker( - serving_settings.model_factory, - serving_settings.pipeline_config, - settings, - metric_client, - scheduler_zmq_configs=scheduler_zmq_configs, + # Image generation uses a direct pipeline without model worker + if serving_settings.pipeline_task == PipelineTask.IMAGE_GENERATION: + scheduler_zmq_configs = None + lora_queue = None + else: + # start model worker + scheduler_zmq_configs = SchedulerZmqConfigs( + serving_settings.pipeline_task, + context_type=PIPELINE_REGISTRY.retrieve_context_type( + serving_settings.pipeline_config + ), + ) + worker_monitor = await exit_stack.enter_async_context( + start_model_worker( + serving_settings.model_factory, + serving_settings.pipeline_config, + settings, + metric_client, + scheduler_zmq_configs=scheduler_zmq_configs, + ) ) - ) - lora_queue: LoRAQueue | None = ( - LoRAQueue( - serving_settings.pipeline_config.zmq_endpoint_base, - serving_settings.pipeline_config.lora_config.lora_paths, + lora_queue = ( + LoRAQueue( + serving_settings.pipeline_config.zmq_endpoint_base, + serving_settings.pipeline_config.lora_config.lora_paths, + ) + if serving_settings.pipeline_config.lora_config + else None ) - if serving_settings.pipeline_config.lora_config - else None - ) METRICS.pipeline_load(serving_settings.pipeline_config.model.model_name) - pipeline: TokenGeneratorPipeline | AudioGeneratorPipeline + pipeline: ( + TokenGeneratorPipeline + | AudioGeneratorPipeline + | ImageGeneratorPipeline + ) if serving_settings.pipeline_task in ( PipelineTask.TEXT_GENERATION, PipelineTask.EMBEDDINGS_GENERATION, @@ -152,6 +162,12 @@ async def lifespan( lora_queue=lora_queue, scheduler_zmq_configs=scheduler_zmq_configs, ) + elif serving_settings.pipeline_task == PipelineTask.IMAGE_GENERATION: + # Image generation uses a simpler pipeline without scheduler + pipeline = ImageGeneratorPipeline( + model_name=serving_settings.pipeline_config.model.model_name, + pipeline_config=serving_settings.pipeline_config, + ) else: raise ValueError( f"Unsupported pipeline task: {serving_settings.pipeline_task}" diff --git a/max/python/max/serve/pipelines/diffusion.py b/max/python/max/serve/pipelines/diffusion.py new file mode 100644 index 00000000000..76601d0a226 --- /dev/null +++ b/max/python/max/serve/pipelines/diffusion.py @@ -0,0 +1,71 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Image generation pipeline for serving diffusion models.""" + +from __future__ import annotations + +import logging +from typing import Any + +from typing_extensions import Self + + +class ImageGeneratorPipeline: + """Pipeline wrapper for image generation. + + This is a simplified pipeline for image generation that doesn't use + the same streaming/batching infrastructure as text generation. + """ + + def __init__( + self, + model_name: str, + pipeline_config: Any, + ) -> None: + self.model_name = model_name + self.pipeline_config = pipeline_config + self._generator: Any | None = None + self.logger = logging.getLogger( + self.__class__.__module__ + "." + self.__class__.__qualname__ + ) + + async def __aenter__(self) -> Self: + """Initialize the image generator.""" + from max.entrypoints.diffusion import ImageGenerator + + self.logger.info( + "Loading image generator for model: %s", self.model_name + ) + self._generator = ImageGenerator(self.pipeline_config) + self.logger.info("Image generator loaded successfully") + return self + + async def __aexit__( + self, et: type[BaseException] | None, exc: BaseException | None, tb: Any + ) -> bool | None: + """Clean up resources.""" + if self._generator is not None: + del self._generator + self._generator = None + self.logger.info("Image generator pipeline closed: %s", self.model_name) + return None + + @property + def generator(self) -> Any: + """Get the underlying image generator.""" + if self._generator is None: + raise RuntimeError( + "Image generator not initialized. Use async with." + ) + return self._generator diff --git a/max/python/max/serve/router/openai_routes.py b/max/python/max/serve/router/openai_routes.py index 8eb356e2620..61076cbf306 100644 --- a/max/python/max/serve/router/openai_routes.py +++ b/max/python/max/serve/router/openai_routes.py @@ -43,6 +43,7 @@ from max.interfaces import ( AudioGenerationRequest, GenerationStatus, + ImageGenerationRequest, LoRAOperation, LoRARequest, LoRAStatus, @@ -61,6 +62,7 @@ from max.profiler import traced from max.serve.config import Settings from max.serve.parser import LlamaToolParser, parse_json_from_text +from max.serve.pipelines.diffusion import ImageGeneratorPipeline from max.serve.pipelines.llm import ( AudioGeneratorPipeline, TokenGeneratorOutput, @@ -1343,7 +1345,7 @@ async def openai_get_models(request: Request) -> ListModelsResponse: Model(id=pipeline.model_name, object="model", created=None, owned_by="") ] - if lora_queue := request.app.state.pipeline.lora_queue: + if lora_queue := getattr(request.app.state.pipeline, "lora_queue", None): model_list += [ Model(id=lora, object="model", created=None, owned_by="") for lora in lora_queue.list_loras() @@ -1429,6 +1431,103 @@ async def create_streaming_audio_speech( raise HTTPException(status_code=400, detail="Value error.") from e +@router.post("/images/generations", response_model=None) +async def create_image_generation( + request: Request, +) -> JSONResponse: + """Image generation endpoint following OpenAI /v1/images/generations API. + + This endpoint generates images from text prompts using diffusion models. + """ + request_id = request.state.request_id + record_request_start() + stopwatch = StopWatch() + + try: + # Parse request body + body = await request.body() + import json as json_module + + body_dict = json_module.loads(body) + + # Create ImageGenerationRequest from body + image_request = ImageGenerationRequest( + prompt=body_dict.get("prompt", ""), + model=body_dict.get("model"), + n=body_dict.get("n", 1), + size=body_dict.get("size", "1024x1024"), + quality=body_dict.get("quality", "standard"), + response_format=body_dict.get("response_format", "b64_json"), + style=body_dict.get("style"), + num_inference_steps=body_dict.get("num_inference_steps", 50), + guidance_scale=body_dict.get("guidance_scale", 3.5), + seed=body_dict.get("seed"), + ) + + # Get the pipeline + pipeline = request.app.state.pipeline + if not isinstance(pipeline, ImageGeneratorPipeline): + raise HTTPException( + status_code=400, + detail="This server is not configured for image generation. " + "Please start with --task image_generation.", + ) + + # Generate images using the pipeline's generator + response = pipeline.generator.create(image_request) + + record_request_end( + status_code=200, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + + return JSONResponse(content=response.to_dict()) + + except JSONDecodeError as e: + logger.exception("JSONDecodeError in request %s", request_id) + record_request_end( + status_code=400, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + raise HTTPException(status_code=400, detail="Missing JSON.") from e + except (TypeError, ValidationError) as e: + logger.exception("TypeError in request %s", request_id) + record_request_end( + status_code=400, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + raise HTTPException(status_code=400, detail="Invalid JSON.") from e + except InputError as e: + logger.warning( + "Input validation error in request %s: %s", request_id, str(e) + ) + record_request_end( + status_code=400, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + raise HTTPException(status_code=400, detail=str(e)) from e + except ValueError as e: + logger.exception("ValueError in request %s", request_id) + record_request_end( + status_code=400, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + raise HTTPException(status_code=400, detail="Value error.") from e + except Exception as e: + logger.exception("Error in image generation request %s", request_id) + record_request_end( + status_code=500, + request_path="/v1/images/generations", + elapsed_ms=stopwatch.elapsed_ms, + ) + raise HTTPException(status_code=500, detail=str(e)) from e + + @router.post("/load_lora_adapter", response_model=None) async def load_lora_adapter( request: Request, diff --git a/max/python/max/serve/scheduler/__init__.py b/max/python/max/serve/scheduler/__init__.py index fbbb04cd511..5fc4071a08e 100644 --- a/max/python/max/serve/scheduler/__init__.py +++ b/max/python/max/serve/scheduler/__init__.py @@ -16,6 +16,7 @@ from max.interfaces import ( EmbeddingsContext, + ImageGenerationContext, MAXPullQueue, Pipeline, PipelineInputsType, @@ -45,6 +46,10 @@ from .config import TokenGenerationSchedulerConfig from .decode_scheduler import load_decode_scheduler from .embeddings_scheduler import EmbeddingsScheduler, EmbeddingsSchedulerConfig +from .image_generation_scheduler import ( + ImageGenerationScheduler, + ImageGenerationSchedulerConfig, +) from .prefill_scheduler import load_prefill_scheduler from .text_generation_scheduler import load_text_generation_scheduler @@ -54,6 +59,8 @@ "CancelRequest", "EmbeddingsScheduler", "EmbeddingsSchedulerConfig", + "ImageGenerationScheduler", + "ImageGenerationSchedulerConfig", "PrefillRequest", "PrefillResponse", "TokenGenerationSchedulerConfig", @@ -124,6 +131,28 @@ def load_scheduler( paged_manager=paged_manager, offload_queue_draining=pipeline_config.experimental_background_queue, ) + elif pipeline.__class__.__name__ == "ImageGenerationPipeline": + from max.pipelines.lib.pipeline_variants.image_generation import ( + ImageGenerationPipeline, + ) + + image_scheduler_config = ImageGenerationSchedulerConfig( + max_batch_size=pipeline_config.max_batch_size + if pipeline_config.max_batch_size is not None + else 1 + ) + image_pipeline = cast(ImageGenerationPipeline, pipeline) + return ImageGenerationScheduler( + scheduler_config=image_scheduler_config, + pipeline=image_pipeline, + request_queue=cast( + MAXPullQueue[ImageGenerationContext], + request_queue, + ), + response_queue=response_queue, + cancel_queue=cancel_queue, + offload_queue_draining=pipeline_config.experimental_background_queue, + ) elif pipeline_config.pipeline_role == PipelineRole.PrefillAndDecode: assert isinstance(pipeline, Pipeline) text_pipeline = cast( diff --git a/max/python/max/serve/scheduler/image_generation_scheduler.py b/max/python/max/serve/scheduler/image_generation_scheduler.py new file mode 100644 index 00000000000..004dfc6f47b --- /dev/null +++ b/max/python/max/serve/scheduler/image_generation_scheduler.py @@ -0,0 +1,123 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Scheduler for image generation pipelines.""" + +import logging +import queue +from dataclasses import dataclass + +from max.interfaces import ( + ImageGenerationContext, + ImageGenerationInputs, + ImageGenerationOutput, + MAXPullQueue, + MAXPushQueue, + RequestID, + Scheduler, + SchedulerResult, +) +from max.pipelines.lib.pipeline_variants.image_generation import ( + ImageGenerationPipeline, +) +from max.profiler import traced + +from .base import SchedulerProgress + +logger = logging.getLogger("max.serve") + + +@dataclass +class ImageGenerationSchedulerConfig: + """Image generation scheduler configuration.""" + + # The maximum number of requests that can be in the batch. + # For image generation, typically 1 since it's memory intensive. + max_batch_size: int = 1 + + +class ImageGenerationScheduler(Scheduler): + """Scheduler for image generation requests. + + This scheduler handles image generation requests one at a time, + as diffusion models are typically memory-intensive and don't + benefit from batching in the same way as text generation. + """ + + def __init__( + self, + scheduler_config: ImageGenerationSchedulerConfig, + pipeline: ImageGenerationPipeline, + request_queue: MAXPullQueue[ImageGenerationContext], + response_queue: MAXPushQueue[ + dict[RequestID, SchedulerResult[ImageGenerationOutput]] + ], + cancel_queue: MAXPullQueue[list[RequestID]], + offload_queue_draining: bool = False, + ) -> None: + self.scheduler_config = scheduler_config + self.pipeline = pipeline + self.request_queue = request_queue + self.response_queue = response_queue + self.cancel_queue = cancel_queue + # Note: offload_queue_draining is accepted for API compatibility + # but not used since image generation is sequential. + + @traced + def _get_next_request(self) -> ImageGenerationContext | None: + """Get the next request from the queue.""" + try: + return self.request_queue.get_nowait() + except queue.Empty: + return None + + def run_iteration(self) -> SchedulerProgress: + """Process one image generation request. + + Returns: + SchedulerProgress: Indicates whether work was performed. + """ + request = self._get_next_request() + if request is None: + return SchedulerProgress.NO_PROGRESS + + self._execute_request(request) + return SchedulerProgress.MADE_PROGRESS + + @traced + def _execute_request(self, request: ImageGenerationContext) -> None: + """Execute a single image generation request.""" + try: + # Create inputs from context + inputs = ImageGenerationInputs( + prompt=request.prompt, + height=request.height, + width=request.width, + num_inference_steps=request.num_inference_steps, + guidance_scale=request.guidance_scale, + num_images_per_prompt=request.num_images_per_prompt, + ) + + # Execute the pipeline + output: ImageGenerationOutput = self.pipeline.execute(inputs) + + # Send the response + self.response_queue.put_nowait( + {request.request_id: SchedulerResult.create(output)} + ) + except Exception as e: + logger.error(f"Error executing image generation: {e}") + # Send cancelled response on error + self.response_queue.put_nowait( + {request.request_id: SchedulerResult.cancelled()} + )