From 022fa3b35dfe8075d52296dd6d712a5aac5c4481 Mon Sep 17 00:00:00 2001 From: yyt Date: Tue, 4 Nov 2025 07:24:39 +0000 Subject: [PATCH 01/17] vae encode dp for wan --- .../models/autoencoders/autoencoder_kl_wan.py | 158 ++++++++++++++++-- 1 file changed, 144 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 0542dd49ae40..2291499043cd 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1129,6 +1129,9 @@ def enable_dp( world_size = dist.get_world_size() if world_size <= 1 or world_size > dist.get_world_size(): + logger.warning( + f"Supported world_size for vae dp is between 2 - {dist.get_world_size}, but got {world_size}. " \ + f"Fall back to normal vae") return if hw_splits is None: @@ -1180,6 +1183,9 @@ def _encode(self, x: torch.Tensor): if self.config.patch_size is not None: x = patchify(x, patch_size=self.config.patch_size) + if self.use_dp: + return self.tiled_encode_with_dp(x) + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): return self.tiled_encode(x) @@ -1446,6 +1452,134 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return (dec,) return DecoderOutput(sample=dec) + def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, height, width = x.shape + device = x.device + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * latent_height) + overlap_latent_width = int(self.overlap_ratio * latent_width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] + num_tile_rows = self.h_split + num_tile_cols = self.w_split + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + local_tiles = [] + local_hw_shapes = [] + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + self.clear_cache() + patch_height_start = h_idx * tile_sample_stride_height + patch_height_end = patch_height_start + tile_sample_min_height + patch_width_start = w_idx * tile_sample_stride_width + patch_width_end = patch_width_start + tile_sample_min_width + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + patch_height_start : patch_height_end, + patch_width_start : patch_width_end, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + time = torch.cat(time, dim=2) + local_tiles.append(time.flatten(3, 4)) + local_hw_shapes.append(torch.Tensor([*time.shape[3:5]]).to(device).int()) + self.clear_cache() + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=3) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + b, c, n = local_tiles.shape[:3] + gathered_tiles = [ + torch.empty( + (b, c, n, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten( + 3, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. @@ -1490,32 +1624,26 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio - self.tile_sample_min_height = tile_sample_min_height - self.tile_sample_min_width = tile_sample_min_width - self.tile_sample_stride_height = tile_sample_stride_height - self.tile_sample_stride_width = tile_sample_stride_width - if self.config.patch_size is not None: sample_height = sample_height // self.config.patch_size sample_width = sample_width // self.config.patch_size tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size - blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height - blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + blend_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width else: - blend_height = self.tile_sample_min_height - tile_sample_stride_height - blend_width = self.tile_sample_min_width - tile_sample_stride_width + blend_height = tile_sample_min_height - tile_sample_stride_height + blend_width = tile_sample_min_width - tile_sample_stride_width - # Split z into overlapping tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] num_tile_rows = self.h_split num_tile_cols = self.w_split + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. # Each rank computes only tiles assigned to it based on tile_idxs_per_rank - # local_tiles = [] # List to store tiles computed by this rank - local_tiles = [] - local_hw_shapes = [] + local_tiles = [] # List to store tiles computed by this rank + local_hw_shapes = [] # List to store shapes of tiles by this rank for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: self.clear_cache() @@ -1558,6 +1686,8 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni # put tiles in rows based on tile_idxs_per_rank rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue rank_tile_hw_shapes = gathered_shape_list[rank_idx] hw_start_idx = 0 # perhaps has more than one tile in each rank, get each by hw_shapes From d5c634f05ebe5bb29d430b192f83220d3942b84e Mon Sep 17 00:00:00 2001 From: yyt Date: Tue, 4 Nov 2025 07:55:58 +0000 Subject: [PATCH 02/17] move redundant enable_dp --- .../models/autoencoders/autoencoder_kl_wan.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 8aabf56d7e3e..e2fa742e6066 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1169,53 +1169,6 @@ def enable_dp( self.num_tiles_per_rank[rank_idx] += 1 rank_idx += 1 - def enable_dp( - self, - world_size: Optional[int] = None, - hw_splits: Optional[Tuple[int, int]] = None, - overlap_ratio: Optional[float] = None, - overlap_pixels: Optional[int] = None - ) -> None: - r""" - """ - if world_size is None: - world_size = dist.get_world_size() - - if world_size <= 1 or world_size > dist.get_world_size(): - return - - if hw_splits is None: - hw_splits = (1, int(world_size)) - - assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 int, but got length {len(hw_splits)}" - - h_split, w_split = map(int, hw_splits) - num_tiles = h_split * w_split - - # assert h_split * w_split == world_size, \ - # (f"world_size must be {w_split} * {h_split} = {w_split * h_split}, but got {world_size}") - - self.use_dp = True - self.h_split, self.w_split = h_split, w_split - self.world_size = world_size - self.overlap_ratio = overlap_ratio - self.overlap_pixels = overlap_pixels - - dp_ranks = list(range(0, world_size)) - self.vae_dp_group = dist.new_group(ranks=dp_ranks) - self.rank = dist.get_rank() - # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)] - # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split) - self.tile_idxs_per_rank = [[] for _ in range(self.world_size)] - self.num_tiles_per_rank = [0] * self.world_size - rank_idx = 0 - for h_idx in range(self.h_split): - for w_idx in range(self.w_split): - rank_idx %= self.world_size - self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx)) - self.num_tiles_per_rank[rank_idx] += 1 - rank_idx += 1 - def clear_cache(self): # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call self._conv_num = self._cached_conv_counts["decoder"] From 7d5790fffba4e7e6b2396f1e536e1405a3b441f5 Mon Sep 17 00:00:00 2001 From: yyt Date: Tue, 4 Nov 2025 09:35:22 +0000 Subject: [PATCH 03/17] implement vae dp for AutoencoderKL and AutoencoderKLWan --- .../models/autoencoders/autoencoder_kl.py | 300 ++++++++++++++++- .../models/autoencoders/autoencoder_kl_wan.py | 316 ++++++++++++++++++ 2 files changed, 615 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 1a72aa3cfeb3..e6655a908860 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -15,11 +15,12 @@ import torch import torch.nn as nn +import torch.distributed as dist from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import deprecate +from ...utils import deprecate, logging from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -35,6 +36,9 @@ from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. @@ -127,6 +131,7 @@ def __init__( self.use_slicing = False self.use_tiling = False + self.use_dp = False # only relevant if vae tiling is enabled self.tile_sample_min_size = self.config.sample_size @@ -214,9 +219,58 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) + def enable_dp( + self, + world_size: Optional[int] = None, + hw_splits: Optional[Tuple[int, int]] = None, + overlap_ratio: Optional[float] = None, + overlap_pixels: Optional[int] = None + ) -> None: + r""" + """ + if world_size is None: + world_size = dist.get_world_size() + + if world_size <= 1 or world_size > dist.get_world_size(): + logger.warning( + f"Supported world_size for vae dp is between 2 - {dist.get_world_size}, but got {world_size}. " \ + f"Fall back to normal vae") + return + + if hw_splits is None: + hw_splits = (1, int(world_size)) + + assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 int, but got length {len(hw_splits)}" + + h_split, w_split = map(int, hw_splits) + + self.use_dp = True + self.h_split, self.w_split = h_split, w_split + self.world_size = world_size + self.overlap_ratio = overlap_ratio + self.overlap_pixels = overlap_pixels + self.spatial_compression_ratio = 2 ** (len(self.config.block_out_channels) - 1) + + dp_ranks = list(range(0, world_size)) + self.vae_dp_group = dist.new_group(ranks=dp_ranks) + self.rank = dist.get_rank() + # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)] + # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split) + self.tile_idxs_per_rank = [[] for _ in range(self.world_size)] + self.num_tiles_per_rank = [0] * self.world_size + rank_idx = 0 + for h_idx in range(self.h_split): + for w_idx in range(self.w_split): + rank_idx %= self.world_size + self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx)) + self.num_tiles_per_rank[rank_idx] += 1 + rank_idx += 1 + def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = x.shape + if self.use_dp: + return self._tiled_encode(x) if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): return self._tiled_encode(x) @@ -256,6 +310,8 @@ def encode( return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_dp: + return self.tiled_decode_with_dp(z, return_dict=return_dict) if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): return self.tiled_decode(z, return_dict=return_dict) @@ -497,6 +553,248 @@ def forward( return DecoderOutput(sample=dec) + def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + _, _, height, width = x.shape + device = x.device + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * latent_height) + overlap_latent_width = int(self.overlap_ratio * latent_width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] + num_tile_rows = self.h_split + num_tile_cols = self.w_split + + local_tiles = [] + local_hw_shapes = [] + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + patch_height_start = h_idx * tile_sample_stride_height + patch_height_end = patch_height_start + tile_sample_min_height + patch_width_start = w_idx * tile_sample_stride_width + patch_width_end = patch_width_start + tile_sample_min_width + + tile = x[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + + local_tiles.append(tile.flatten(-2, -1)) + local_hw_shapes.append(torch.Tensor([*tile.shape[-2:]]).to(device).int()) + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=-1) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + bc_ = local_tiles.shape[:-1] + gathered_tiles = [ + torch.empty( + (*bc_, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( + -1, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width] + return enc + + def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, height, width = z.shape + device = z.device + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * height) + overlap_latent_width = int(self.overlap_ratio * width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + # Convert min/stride to sample space + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + blend_height = tile_sample_min_height - tile_sample_stride_height + blend_width = tile_sample_min_width - tile_sample_stride_width + + # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] + num_tile_rows = self.h_split + num_tile_cols = self.w_split + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + # Each rank computes only tiles assigned to it based on tile_idxs_per_rank + local_tiles = [] # List to store tiles computed by this rank + local_hw_shapes = [] # List to store shapes of tiles by this rank + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + patch_height_start = h_idx * tile_latent_stride_height + patch_height_end = patch_height_start + tile_latent_min_height + patch_width_start = w_idx * tile_latent_stride_width + patch_width_end = patch_width_start + tile_latent_min_width + + tile = z[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + + local_tiles.append(decoded.flatten(-2, -1)) # flatten h,w dim for concate all tiles in one rank + local_hw_shapes.append(torch.Tensor([*decoded.shape[-2:]]).to(device).int()) # record hw for futher unflatten + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=-1) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + bcn_ = local_tiles.shape[:-1] + gathered_tiles = [ + torch.empty( + (*bcn_, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( + -1, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :tile_sample_stride_height, :tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2)[:, :, :sample_height, :sample_width] + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections def fuse_qkv_projections(self): """ diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index f8bdfeb75524..3fd96a755a7f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1113,6 +1113,52 @@ def enable_tiling( self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + def enable_dp( + self, + world_size: Optional[int] = None, + hw_splits: Optional[Tuple[int, int]] = None, + overlap_ratio: Optional[float] = None, + overlap_pixels: Optional[int] = None + ) -> None: + r""" + """ + if world_size is None: + world_size = dist.get_world_size() + + if world_size <= 1 or world_size > dist.get_world_size(): + logger.warning( + f"Supported world_size for vae dp is between 2 - {dist.get_world_size}, but got {world_size}. " \ + f"Fall back to normal vae") + return + + if hw_splits is None: + hw_splits = (1, int(world_size)) + + assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 int, but got length {len(hw_splits)}" + + h_split, w_split = map(int, hw_splits) + + self.use_dp = True + self.h_split, self.w_split = h_split, w_split + self.world_size = world_size + self.overlap_ratio = overlap_ratio + self.overlap_pixels = overlap_pixels + + dp_ranks = list(range(0, world_size)) + self.vae_dp_group = dist.new_group(ranks=dp_ranks) + self.rank = dist.get_rank() + # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)] + # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split) + self.tile_idxs_per_rank = [[] for _ in range(self.world_size)] + self.num_tiles_per_rank = [0] * self.world_size + rank_idx = 0 + for h_idx in range(self.h_split): + for w_idx in range(self.w_split): + rank_idx %= self.world_size + self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx)) + self.num_tiles_per_rank[rank_idx] += 1 + rank_idx += 1 + def clear_cache(self): # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call self._conv_num = self._cached_conv_counts["decoder"] @@ -1393,6 +1439,276 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return (dec,) return DecoderOutput(sample=dec) + def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, height, width = x.shape + device = x.device + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * latent_height) + overlap_latent_width = int(self.overlap_ratio * latent_width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] + num_tile_rows = self.h_split + num_tile_cols = self.w_split + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + local_tiles = [] + local_hw_shapes = [] + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + self.clear_cache() + patch_height_start = h_idx * tile_sample_stride_height + patch_height_end = patch_height_start + tile_sample_min_height + patch_width_start = w_idx * tile_sample_stride_width + patch_width_end = patch_width_start + tile_sample_min_width + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + patch_height_start : patch_height_end, + patch_width_start : patch_width_end, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + time = torch.cat(time, dim=2) + local_tiles.append(time.flatten(-2, -1)) + local_hw_shapes.append(torch.Tensor([*time.shape[-2:]]).to(device).int()) + self.clear_cache() + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=-1) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + bcn_ = local_tiles.shape[:-1] + gathered_tiles = [ + torch.empty( + (*bcn_, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten( + -1, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + device = z.device + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * height) + overlap_latent_width = int(self.overlap_ratio * width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + # Convert min/stride to sample space + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + if self.config.patch_size is not None: + sample_height = sample_height // self.config.patch_size + sample_width = sample_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + blend_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + else: + blend_height = tile_sample_min_height - tile_sample_stride_height + blend_width = tile_sample_min_width - tile_sample_stride_width + + # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] + num_tile_rows = self.h_split + num_tile_cols = self.w_split + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + # Each rank computes only tiles assigned to it based on tile_idxs_per_rank + local_tiles = [] # List to store tiles computed by this rank + local_hw_shapes = [] # List to store shapes of tiles by this rank + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + self.clear_cache() + patch_height_start = h_idx * tile_latent_stride_height + patch_height_end = patch_height_start + tile_latent_min_height + patch_width_start = w_idx * tile_latent_stride_width + patch_width_end = patch_width_start + tile_latent_min_width + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + tile = self.post_quant_conv(tile) + decoded = self.decoder( + tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0) + ) + time.append(decoded) + time = torch.cat(time, dim=2) + local_tiles.append(time.flatten(-2, -1)) # flatten h,w dim for concate all tiles in one rank + local_hw_shapes.append(torch.Tensor([*time.shape[-2:]]).to(device).int()) # record hw for futher unflatten + self.clear_cache() + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=-1) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + bcn_ = local_tiles.shape[:-1] + gathered_tiles = [ + torch.empty( + (*bcn_, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( + 3, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + # combine all tiles, same as tiled decode + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if self.config.patch_size is not None: + dec = unpatchify(dec, patch_size=self.config.patch_size) + + dec = torch.clamp(dec, min=-1.0, max=1.0) + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + def forward( self, sample: torch.Tensor, From 6c61cd0e355c830f352a04100581e9afcc259856 Mon Sep 17 00:00:00 2001 From: yyt Date: Wed, 5 Nov 2025 03:25:45 +0000 Subject: [PATCH 04/17] extract same code in vae dp func --- .../models/autoencoders/autoencoder_kl.py | 275 ++++++------------ .../models/autoencoders/autoencoder_kl_wan.py | 223 +++++--------- src/diffusers/models/autoencoders/vae.py | 78 ++++- 3 files changed, 238 insertions(+), 338 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index e6655a908860..33841b2dae06 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -523,33 +523,41 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod if not return_dict: return (dec,) - return DecoderOutput(sample=dec) + def calculate_tiled_parallel_size(self, latent_height, latent_width): + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) - def forward( - self, - sample: torch.Tensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.Tensor]: - r""" - Args: - sample (`torch.Tensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z).sample + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * latent_height) + overlap_latent_width = int(self.overlap_ratio * latent_width) - if not return_dict: - return (dec,) + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + blend_latent_height = tile_latent_min_height - tile_latent_stride_height + blend_latent_width = tile_latent_min_width - tile_latent_stride_width + + blend_sample_height = tile_sample_min_height - tile_sample_stride_height + blend_sample_width = tile_sample_min_width - tile_sample_stride_width + + return \ + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width return DecoderOutput(sample=dec) @@ -575,86 +583,24 @@ def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor: latent_height = height // self.spatial_compression_ratio latent_width = width // self.spatial_compression_ratio - # Calculate stride based on h_split and w_split - tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) - tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) - - # Calculate overlap in latent space - overlap_latent_height = 3 - overlap_latent_width = 3 - if self.overlap_pixels is not None: - overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio - overlap_latent_height = overlap_latent - overlap_latent_width = overlap_latent - elif self.overlap_ratio is not None: - overlap_latent_height = int(self.overlap_ratio * latent_height) - overlap_latent_width = int(self.overlap_ratio * latent_width) - - # Calculate minimum tile size in latent space - tile_latent_min_height = tile_latent_stride_height + overlap_latent_height - tile_latent_min_width = tile_latent_stride_width + overlap_latent_width - - blend_height = tile_latent_min_height - tile_latent_stride_height - blend_width = tile_latent_min_width - tile_latent_stride_width - - tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio - tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio - tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio - tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio - - # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] - num_tile_rows = self.h_split - num_tile_cols = self.w_split - - local_tiles = [] - local_hw_shapes = [] - - for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: - patch_height_start = h_idx * tile_sample_stride_height - patch_height_end = patch_height_start + tile_sample_min_height - patch_width_start = w_idx * tile_sample_stride_width - patch_width_end = patch_width_start + tile_sample_min_width + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + def vae_encode_op( + x, patch_height_start, patch_height_end, patch_width_start, patch_width_end + ) -> torch.Tensor: tile = x[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] tile = self.encoder(tile) if self.config.use_quant_conv: tile = self.quant_conv(tile) + return tile - local_tiles.append(tile.flatten(-2, -1)) - local_hw_shapes.append(torch.Tensor([*tile.shape[-2:]]).to(device).int()) - - # concat all tiles on local rank - local_tiles = torch.cat(local_tiles, dim=-1) - local_hw_shapes = torch.stack(local_hw_shapes) - - # get all hw shapes for each rank (perhaps has different shapes for last tile) - gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) - for num_tiles in self.num_tiles_per_rank] - dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) - - # gather tiles on all ranks - bc_ = local_tiles.shape[:-1] - gathered_tiles = [ - torch.empty( - (*bc_, tiles_shape.prod(dim=1).sum().item()), - dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list - ] - dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) - - # put tiles in rows based on tile_idxs_per_rank - rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] - for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): - if not tile_idxs: - continue - rank_tile_hw_shapes = gathered_shape_list[rank_idx] - hw_start_idx = 0 - # perhaps has more than one tile in each rank, get each by hw_shapes - for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): - rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] - hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw - rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( - -1, rank_tile_hw_shape.tolist()) # unflatten hw dim - hw_start_idx = hw_end_idx + rows = self.run_vae_tile_parallel( + x, vae_encode_op, + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, device + ) result_rows = [] for i, row in enumerate(rows): @@ -663,9 +609,9 @@ def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor: # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_height) + tile = self.blend_v(rows[i - 1][j], tile, blend_latent_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_width) + tile = self.blend_h(row[j - 1], tile, blend_latent_width) result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) result_rows.append(torch.cat(result_row, dim=3)) @@ -686,95 +632,30 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. """ - _, _, height, width = z.shape + _, _, latent_height, latent_width = z.shape device = z.device - sample_height = height * self.spatial_compression_ratio - sample_width = width * self.spatial_compression_ratio - - # Calculate stride based on h_split and w_split - tile_latent_stride_height = int((height + self.h_split - 1) / self.h_split) - tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split) - - # Calculate overlap in latent space - overlap_latent_height = 3 - overlap_latent_width = 3 - if self.overlap_pixels is not None: - overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio - overlap_latent_height = overlap_latent - overlap_latent_width = overlap_latent - elif self.overlap_ratio is not None: - overlap_latent_height = int(self.overlap_ratio * height) - overlap_latent_width = int(self.overlap_ratio * width) - - # Calculate minimum tile size in latent space - tile_latent_min_height = tile_latent_stride_height + overlap_latent_height - tile_latent_min_width = tile_latent_stride_width + overlap_latent_width - - # Convert min/stride to sample space - tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio - tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio - tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio - tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio - - blend_height = tile_sample_min_height - tile_sample_stride_height - blend_width = tile_sample_min_width - tile_sample_stride_width + sample_height = latent_height * self.spatial_compression_ratio + sample_width = latent_width * self.spatial_compression_ratio - # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] - num_tile_rows = self.h_split - num_tile_cols = self.w_split + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) - # Split z into overlapping tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. - # Each rank computes only tiles assigned to it based on tile_idxs_per_rank - local_tiles = [] # List to store tiles computed by this rank - local_hw_shapes = [] # List to store shapes of tiles by this rank - - for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: - patch_height_start = h_idx * tile_latent_stride_height - patch_height_end = patch_height_start + tile_latent_min_height - patch_width_start = w_idx * tile_latent_stride_width - patch_width_end = patch_width_start + tile_latent_min_width + def vae_decode_op( + z, patch_height_start, patch_height_end, patch_width_start, patch_width_end + ) -> torch.Tensor: tile = z[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] if self.config.use_post_quant_conv: tile = self.post_quant_conv(tile) decoded = self.decoder(tile) + return decoded - local_tiles.append(decoded.flatten(-2, -1)) # flatten h,w dim for concate all tiles in one rank - local_hw_shapes.append(torch.Tensor([*decoded.shape[-2:]]).to(device).int()) # record hw for futher unflatten - - # concat all tiles on local rank - local_tiles = torch.cat(local_tiles, dim=-1) - local_hw_shapes = torch.stack(local_hw_shapes) - - # get all hw shapes for each rank (perhaps has different shapes for last tile) - gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) - for num_tiles in self.num_tiles_per_rank] - dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) - - # gather tiles on all ranks - bcn_ = local_tiles.shape[:-1] - gathered_tiles = [ - torch.empty( - (*bcn_, tiles_shape.prod(dim=1).sum().item()), - dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list - ] - dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) - - # put tiles in rows based on tile_idxs_per_rank - rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] - for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): - if not tile_idxs: - continue - rank_tile_hw_shapes = gathered_shape_list[rank_idx] - hw_start_idx = 0 - # perhaps has more than one tile in each rank, get each by hw_shapes - for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): - rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] - hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw - rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( - -1, rank_tile_hw_shape.tolist()) # unflatten hw dim - hw_start_idx = hw_end_idx + rows = self.run_vae_tile_parallel( + z, vae_decode_op, + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, device + ) result_rows = [] for i, row in enumerate(rows): @@ -783,9 +664,9 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_height) + tile = self.blend_v(rows[i - 1][j], tile, blend_sample_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_width) + tile = self.blend_h(row[j - 1], tile, blend_sample_width) result_row.append(tile[:, :, :tile_sample_stride_height, :tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=3)) @@ -795,6 +676,34 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni return DecoderOutput(sample=dec) + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections def fuse_qkv_projections(self): """ diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 3fd96a755a7f..960c3a9d87f2 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1439,21 +1439,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return (dec,) return DecoderOutput(sample=dec) - def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: - r"""Encode a batch of images using a tiled encoder. - - Args: - x (`torch.Tensor`): Input batch of videos. - - Returns: - `torch.Tensor`: - The latent representation of the encoded videos. - """ - _, _, num_frames, height, width = x.shape - device = x.device - latent_height = height // self.spatial_compression_ratio - latent_width = width // self.spatial_compression_ratio - + def calculate_tiled_parallel_size(self, latent_height, latent_width): # Calculate stride based on h_split and w_split tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) @@ -1473,29 +1459,55 @@ def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: tile_latent_min_height = tile_latent_stride_height + overlap_latent_height tile_latent_min_width = tile_latent_stride_width + overlap_latent_width - blend_height = tile_latent_min_height - tile_latent_stride_height - blend_width = tile_latent_min_width - tile_latent_stride_width - tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio - # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] - num_tile_rows = self.h_split - num_tile_cols = self.w_split + blend_latent_height = tile_latent_min_height - tile_latent_stride_height + blend_latent_width = tile_latent_min_width - tile_latent_stride_width - # Split x into overlapping tiles and encode them separately. - # The tiles have an overlap to avoid seams between tiles. - local_tiles = [] - local_hw_shapes = [] + if self.config.patch_size is not None: + sample_height = sample_height // self.config.patch_size + sample_width = sample_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + blend_sample_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_sample_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + else: + blend_sample_height = tile_sample_min_height - tile_sample_stride_height + blend_sample_width = tile_sample_min_width - tile_sample_stride_width + + return \ + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width + + def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, sample_height, sample_width = x.shape + device = x.device + latent_height = sample_height // self.spatial_compression_ratio + latent_width = sample_width // self.spatial_compression_ratio + + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + + def vae_encode_op( + x, patch_height_start, patch_height_end, patch_width_start, patch_width_end, num_frames + ) -> torch.Tensor: - for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: self.clear_cache() - patch_height_start = h_idx * tile_sample_stride_height - patch_height_end = patch_height_start + tile_sample_min_height - patch_width_start = w_idx * tile_sample_stride_width - patch_width_end = patch_width_start + tile_sample_min_width time = [] frame_range = 1 + (num_frames - 1) // 4 for k in range(frame_range): @@ -1514,42 +1526,14 @@ def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: tile = self.quant_conv(tile) time.append(tile) time = torch.cat(time, dim=2) - local_tiles.append(time.flatten(-2, -1)) - local_hw_shapes.append(torch.Tensor([*time.shape[-2:]]).to(device).int()) self.clear_cache() + return time - # concat all tiles on local rank - local_tiles = torch.cat(local_tiles, dim=-1) - local_hw_shapes = torch.stack(local_hw_shapes) - - # get all hw shapes for each rank (perhaps has different shapes for last tile) - gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) - for num_tiles in self.num_tiles_per_rank] - dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) - - # gather tiles on all ranks - bcn_ = local_tiles.shape[:-1] - gathered_tiles = [ - torch.empty( - (*bcn_, tiles_shape.prod(dim=1).sum().item()), - dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list - ] - dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) - - # put tiles in rows based on tile_idxs_per_rank - rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] - for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): - if not tile_idxs: - continue - rank_tile_hw_shapes = gathered_shape_list[rank_idx] - hw_start_idx = 0 - # perhaps has more than one tile in each rank, get each by hw_shapes - for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): - rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] - hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw - rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten( - -1, rank_tile_hw_shape.tolist()) # unflatten hw dim - hw_start_idx = hw_end_idx + rows = self.run_vae_tile_parallel( + x, vae_encode_op, + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, device, + num_frames=num_frames + ) result_rows = [] for i, row in enumerate(rows): @@ -1558,9 +1542,9 @@ def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_height) + tile = self.blend_v(rows[i - 1][j], tile, blend_latent_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_width) + tile = self.blend_h(row[j - 1], tile, blend_latent_width) result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) @@ -1581,63 +1565,22 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. """ - _, _, num_frames, height, width = z.shape + _, _, num_frames, latent_height, latent_width = z.shape device = z.device - sample_height = height * self.spatial_compression_ratio - sample_width = width * self.spatial_compression_ratio + sample_height = latent_height * self.spatial_compression_ratio + sample_width = latent_width * self.spatial_compression_ratio - # Calculate stride based on h_split and w_split - tile_latent_stride_height = int((height + self.h_split - 1) / self.h_split) - tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split) - - # Calculate overlap in latent space - overlap_latent_height = 3 - overlap_latent_width = 3 - if self.overlap_pixels is not None: - overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio - overlap_latent_height = overlap_latent - overlap_latent_width = overlap_latent - elif self.overlap_ratio is not None: - overlap_latent_height = int(self.overlap_ratio * height) - overlap_latent_width = int(self.overlap_ratio * width) + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) - # Calculate minimum tile size in latent space - tile_latent_min_height = tile_latent_stride_height + overlap_latent_height - tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + def vae_decode_op( + z, patch_height_start, patch_height_end, patch_width_start, patch_width_end, num_frames + ) -> torch.Tensor: - # Convert min/stride to sample space - tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio - tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio - tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio - tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio - - if self.config.patch_size is not None: - sample_height = sample_height // self.config.patch_size - sample_width = sample_width // self.config.patch_size - tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size - tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size - blend_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height - blend_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width - else: - blend_height = tile_sample_min_height - tile_sample_stride_height - blend_width = tile_sample_min_width - tile_sample_stride_width - - # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] - num_tile_rows = self.h_split - num_tile_cols = self.w_split - - # Split z into overlapping tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. - # Each rank computes only tiles assigned to it based on tile_idxs_per_rank - local_tiles = [] # List to store tiles computed by this rank - local_hw_shapes = [] # List to store shapes of tiles by this rank - - for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: self.clear_cache() - patch_height_start = h_idx * tile_latent_stride_height - patch_height_end = patch_height_start + tile_latent_min_height - patch_width_start = w_idx * tile_latent_stride_width - patch_width_end = patch_width_start + tile_latent_min_width + time = [] for k in range(num_frames): self._conv_idx = [0] @@ -1648,42 +1591,14 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni ) time.append(decoded) time = torch.cat(time, dim=2) - local_tiles.append(time.flatten(-2, -1)) # flatten h,w dim for concate all tiles in one rank - local_hw_shapes.append(torch.Tensor([*time.shape[-2:]]).to(device).int()) # record hw for futher unflatten self.clear_cache() + return time - # concat all tiles on local rank - local_tiles = torch.cat(local_tiles, dim=-1) - local_hw_shapes = torch.stack(local_hw_shapes) - - # get all hw shapes for each rank (perhaps has different shapes for last tile) - gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) - for num_tiles in self.num_tiles_per_rank] - dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) - - # gather tiles on all ranks - bcn_ = local_tiles.shape[:-1] - gathered_tiles = [ - torch.empty( - (*bcn_, tiles_shape.prod(dim=1).sum().item()), - dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list - ] - dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) - - # put tiles in rows based on tile_idxs_per_rank - rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] - for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): - if not tile_idxs: - continue - rank_tile_hw_shapes = gathered_shape_list[rank_idx] - hw_start_idx = 0 - # perhaps has more than one tile in each rank, get each by hw_shapes - for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): - rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] - hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw - rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( - 3, rank_tile_hw_shape.tolist()) # unflatten hw dim - hw_start_idx = hw_end_idx + rows = self.run_vae_tile_parallel( + z, vae_decode_op, + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, device, + num_frames=num_frames + ) # combine all tiles, same as tiled decode result_rows = [] @@ -1693,9 +1608,9 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_height) + tile = self.blend_v(rows[i - 1][j], tile, blend_sample_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_width) + tile = self.blend_h(row[j - 1], tile, blend_sample_width) result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 9c6031a988f9..d798711ec240 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, List import numpy as np import torch import torch.nn as nn +import torch.distributed as dist from ...utils import BaseOutput from ...utils.torch_utils import randn_tensor @@ -926,3 +927,78 @@ def disable_slicing(self): decoding in one step. """ self.use_slicing = False + + def enable_dp(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + if not hasattr(self, "use_tiling"): + raise NotImplementedError(f"Tiling Parallel doesn't seem to be implemented for {self.__class__.__name__}.") + self.use_dp = True + + def disable_dp(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_dp = False + + def run_vae_tile_parallel( + self, + input: torch.Tensor, + vae_op, + min_height, + min_width, + stride_height, + stride_width, + device, + **kwargs) -> List[List[torch.Tensor]]: + + local_tiles = [] + local_hw_shapes = [] + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + patch_height_start = h_idx * stride_height + patch_height_end = patch_height_start + min_height + patch_width_start = w_idx * stride_width + patch_width_end = patch_width_start + min_width + tile = vae_op(input, patch_height_start, patch_height_end, patch_width_start, patch_width_end, **kwargs) + local_tiles.append(tile.flatten(-2, -1)) + local_hw_shapes.append(torch.Tensor([*tile.shape[-2:]]).to(device).int()) + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=-1) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + tile_shape_first = local_tiles.shape[:-1] + gathered_tiles = [ + torch.empty( + (*tile_shape_first, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * self.w_split for _ in range(self.h_split)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( + -1, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + return rows \ No newline at end of file From 4aeeeb98698e7f51ddb96718cd5dad5d768a0e6b Mon Sep 17 00:00:00 2001 From: yyt Date: Wed, 5 Nov 2025 09:04:13 +0000 Subject: [PATCH 05/17] optimize blend method in tiled vae --- .../models/autoencoders/autoencoder_kl.py | 26 ++++++++++++++----- .../models/autoencoders/autoencoder_kl_wan.py | 22 +++++++++++++--- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 33841b2dae06..3517dde44f97 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -366,6 +366,20 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b + def blend_v_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + y = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (y / blend_extent)[None, None, :, None].to(a.dtype) + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - blend_ratio) + b[:, :, y, :] * blend_ratio + return b + + def blend_h_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + x = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (x / blend_extent)[None, None, None, :].to(a.dtype) + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - blend_ratio) + b[:, :, :, x] * blend_ratio + return b + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. @@ -523,6 +537,8 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod if not return_dict: return (dec,) + return DecoderOutput(sample=dec) + def calculate_tiled_parallel_size(self, latent_height, latent_width): # Calculate stride based on h_split and w_split tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) @@ -559,8 +575,6 @@ def calculate_tiled_parallel_size(self, latent_height, latent_width): tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width - return DecoderOutput(sample=dec) - def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. @@ -609,9 +623,9 @@ def vae_encode_op( # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_latent_height) + tile = self.blend_v_(rows[i - 1][j], tile, blend_latent_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_latent_width) + tile = self.blend_h_(row[j - 1], tile, blend_latent_width) result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) result_rows.append(torch.cat(result_row, dim=3)) @@ -664,9 +678,9 @@ def vae_decode_op( # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_sample_height) + tile = self.blend_v_(rows[i - 1][j], tile, blend_sample_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_sample_width) + tile = self.blend_h_(row[j - 1], tile, blend_sample_width) result_row.append(tile[:, :, :tile_sample_stride_height, :tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=3)) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 960c3a9d87f2..8bd8a12403eb 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1295,6 +1295,20 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. ) return b + def blend_v_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + y = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (y / blend_extent)[None, None, None, :, None].to(a.dtype) + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - blend_ratio) + b[:, :, :, y, :] * blend_ratio + return b + + def blend_h_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + x = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (x / blend_extent)[None, None, None, None, :].to(a.dtype) + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - blend_ratio) + b[:, :, :, :, x] * blend_ratio + return b + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. @@ -1542,9 +1556,9 @@ def vae_encode_op( # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_latent_height) + tile = self.blend_v_(rows[i - 1][j], tile, blend_latent_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_latent_width) + tile = self.blend_h_(row[j - 1], tile, blend_latent_width) result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) @@ -1608,9 +1622,9 @@ def vae_decode_op( # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_sample_height) + tile = self.blend_v_(rows[i - 1][j], tile, blend_sample_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_sample_width) + tile = self.blend_h_(row[j - 1], tile, blend_sample_width) result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] From 8cfad756b8f286b1faff50113d7ba4114526e913 Mon Sep 17 00:00:00 2001 From: yyt Date: Wed, 5 Nov 2025 12:49:56 +0000 Subject: [PATCH 06/17] fix world_size 1 bug when init parallel tiling --- src/diffusers/models/autoencoders/autoencoder_kl.py | 2 +- src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 3517dde44f97..f11c7db25386 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -229,7 +229,7 @@ def enable_dp( r""" """ if world_size is None: - world_size = dist.get_world_size() + world_size = dist.get_world_size() if dist.is_initialized() else 1 if world_size <= 1 or world_size > dist.get_world_size(): logger.warning( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 8bd8a12403eb..0252f1e22580 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torch.distributed as dist from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin @@ -1123,7 +1124,7 @@ def enable_dp( r""" """ if world_size is None: - world_size = dist.get_world_size() + world_size = dist.get_world_size() if dist.is_initialized() else 1 if world_size <= 1 or world_size > dist.get_world_size(): logger.warning( From 25115150a16ee85b31df44566ac31fbfab67e576 Mon Sep 17 00:00:00 2001 From: yyt Date: Wed, 5 Nov 2025 12:56:40 +0000 Subject: [PATCH 07/17] fix bug in vae_kl_wan --- src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 0252f1e22580..f034cebde12b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1074,6 +1074,8 @@ def __init__( self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 + self.use_dp = False + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup self._cached_conv_counts = { "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) @@ -1177,6 +1179,9 @@ def _encode(self, x: torch.Tensor): if self.config.patch_size is not None: x = patchify(x, patch_size=self.config.patch_size) + if self.use_dp: + return self.tiled_encode_with_dp(x) + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): return self.tiled_encode(x) @@ -1229,6 +1234,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + if self.use_dp: + return self.tiled_decode_with_dp(z, return_dict=return_dict) + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): return self.tiled_decode(z, return_dict=return_dict) From b690a6682b980bcd97713dd8606338cefb4c0a42 Mon Sep 17 00:00:00 2001 From: yyt Date: Thu, 6 Nov 2025 06:09:24 +0000 Subject: [PATCH 08/17] rmsnorm+rope compute in other stream --- .../models/transformers/transformer_flux.py | 337 +++++++++++++++++- 1 file changed, 336 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 16662e8d8fe8..bc350158eaf4 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -39,6 +39,9 @@ from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +STREAM_VECTOR = torch.npu.Stream() +STREAM_COMM = torch.npu.Stream() + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -131,7 +134,7 @@ def __call__( if hasattr(self._parallel_config, "context_parallel_config") and \ self._parallel_config.context_parallel_config is not None: - return self._context_parallel_forward( + return self._context_parallel_forward_vqk( attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, pre_query, pre_key, cal_q ) @@ -191,6 +194,338 @@ def __call__( else: return hidden_states + def _context_parallel_forward_cv( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + pre_query: Optional[torch.Tensor] = None, + pre_key: Optional[torch.Tensor] = None, + cal_q=True + ) -> torch.Tensor: + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + ev_q = torch.npu.Event() + ev_k = torch.npu.Event() + ev_v = torch.npu.Event() + + query = attn.to_q(hidden_states) + query = query.unflatten(-1, (attn.heads, -1)) + ev_q.record() + key = attn.to_k(hidden_states) + key = key.unflatten(-1, (attn.heads, -1)) + ev_k.record() + + value = attn.to_v(hidden_states) + value = value.unflatten(-1, (attn.heads, -1)) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + value = torch.cat([encoder_value, value], dim=1) + B, S_KV_LOCAL, H, D = value.shape + H_LOCAL = H // world_size + value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous().flatten() + v_size = value.numel() + ev_v.record() + + with torch.npu.stream(STREAM_VECTOR): + ev_q.wait() + query = attn.norm_q(query) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_query = attn.norm_added_q(encoder_query) + query = torch.cat([encoder_query, query], dim=1) + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + + ev_k.wait() + key = attn.norm_k(key) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_key = attn.norm_added_k(encoder_key) + key = torch.cat([encoder_key, key], dim=1) + if image_rotary_emb is not None: + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + B, S_Q_LOCAL, H, D = query.shape + _, S_KV_LOCAL, _, _ = key.shape + H_LOCAL = H // world_size + + query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous().flatten() + key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous().flatten() + q_size = query.numel() + k_size = query.numel() + + ev_v.wait() + qkv_all = funcol.all_to_all_single(torch.cat([query, key, value], dim=0), None, None, group) + + qkv_all = _wait_tensor(qkv_all) + query_all = qkv_all[:q_size].reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + key_all = qkv_all[q_size: q_size + k_size].reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + value_all = qkv_all[-v_size:].reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + # query_all = _wait_tensor(query_all) + # query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + # key_all = _wait_tensor(key_all) + # key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + # value_all = _wait_tensor(value_all) + # value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + out = npu_fusion_attention( + query_all, + key_all, + value_all, + H_LOCAL, # num_heads + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(D), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + out = out.transpose(1, 2).contiguous() + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + def _context_parallel_forward_vqk( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + pre_query: Optional[torch.Tensor] = None, + pre_key: Optional[torch.Tensor] = None, + cal_q=True + ) -> torch.Tensor: + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + ev_q = torch.npu.Event() + ev_k = torch.npu.Event() + + value = attn.to_v(hidden_states) + value = value.unflatten(-1, (attn.heads, -1)) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + value = torch.cat([encoder_value, value], dim=1) + + B, S_KV_LOCAL, H, D = value.shape + H_LOCAL = H // world_size + value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + query = attn.to_q(hidden_states) + query = query.unflatten(-1, (attn.heads, -1)) + ev_q.record() + key = attn.to_k(hidden_states) + key = key.unflatten(-1, (attn.heads, -1)) + ev_k.record() + + with torch.npu.stream(STREAM_VECTOR): + ev_q.wait() + query = attn.norm_q(query) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_query = attn.norm_added_q(encoder_query) + query = torch.cat([encoder_query, query], dim=1) + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + + _, S_Q_LOCAL, _, _ = query.shape + query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + + ev_k.wait() + key = attn.norm_k(key) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_key = attn.norm_added_k(encoder_key) + key = torch.cat([encoder_key, key], dim=1) + if image_rotary_emb is not None: + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + value_all = _wait_tensor(value_all) + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + query_all = _wait_tensor(query_all) + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + key_all = _wait_tensor(key_all) + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + out = npu_fusion_attention( + query_all, + key_all, + value_all, + H_LOCAL, # num_heads + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(D), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + out = out.transpose(1, 2).contiguous() + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + def _context_parallel_forward_qkv( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + pre_query: Optional[torch.Tensor] = None, + pre_key: Optional[torch.Tensor] = None, + cal_q=True + ) -> torch.Tensor: + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + ev_q = torch.npu.Event() + ev_k = torch.npu.Event() + + query = attn.to_q(hidden_states) + query = query.unflatten(-1, (attn.heads, -1)) + ev_q.record() + key = attn.to_k(hidden_states) + key = key.unflatten(-1, (attn.heads, -1)) + ev_k.record() + + value = attn.to_v(hidden_states) + value = value.unflatten(-1, (attn.heads, -1)) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + value = torch.cat([encoder_value, value], dim=1) + + with torch.npu.stream(STREAM_VECTOR): + ev_q.wait() + query = attn.norm_q(query) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_query = attn.norm_added_q(encoder_query) + query = torch.cat([encoder_query, query], dim=1) + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + + B, S_Q_LOCAL, H, D = query.shape + H_LOCAL = H // world_size + query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + + ev_k.wait() + key = attn.norm_k(key) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_key = attn.norm_added_k(encoder_key) + key = torch.cat([encoder_key, key], dim=1) + if image_rotary_emb is not None: + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + _, S_KV_LOCAL, _, _ = key.shape + key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + query_all = _wait_tensor(query_all) + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + key_all = _wait_tensor(key_all) + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + value_all = _wait_tensor(value_all) + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + out = npu_fusion_attention( + query_all, + key_all, + value_all, + H_LOCAL, # num_heads + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(D), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + out = out.transpose(1, 2).contiguous() + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + def _context_parallel_forward( self, attn: "FluxAttention", From 71e652571b1e85e366701450ebe85e31d1e1b63e Mon Sep 17 00:00:00 2001 From: yyt Date: Thu, 6 Nov 2025 06:10:05 +0000 Subject: [PATCH 09/17] prof / time --- src/diffusers/pipelines/flux/pipeline_flux.py | 37 +++++++++++++++++++ src/diffusers/pipelines/wan/pipeline_wan.py | 21 +++++++++++ .../pipelines/wan/pipeline_wan_i2v.py | 35 ++++++++++++++++++ 3 files changed, 93 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 732e98c048f0..ff27a272fb9f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -26,6 +26,8 @@ T5TokenizerFast, ) +import time + from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel @@ -680,6 +682,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + prof=None ): r""" Function invoked when calling the pipeline for generation. @@ -779,6 +782,10 @@ def __call__( is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ + torch.cuda.synchronize() + if prof: + prof.start() + t_start = time.time() height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor @@ -925,6 +932,10 @@ def __call__( batch_size * num_images_per_prompt, ) + torch.cuda.synchronize() + if prof: + prof.step() + t_preprocess = time.time() # 6. Denoising loop # We set the index here to remove DtoH sync, helpful especially during compilation. # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 @@ -996,6 +1007,12 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + torch.cuda.synchronize() + if prof: + prof.step() + + t_dit = time.time() + self._current_timestep = None if output_type == "latent": @@ -1006,10 +1023,30 @@ def __call__( image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) + torch.cuda.synchronize() + if prof: + prof.step() + + t_vae = time.time() + # Offload all models self.maybe_free_model_hooks() self.image_rotary_emb = None + is_print = torch.distributed.get_rank() == 0 if torch.distributed.is_initialized() else True + if is_print: + headers = ["Total", "Prepare", "DIT", "VAE", "DIT_PER_STEP"] + time_list = [t_vae - t_start, t_preprocess - t_start, t_dit - t_preprocess, t_vae - t_dit, (t_dit - t_start) / num_inference_steps] + time_list = [f"{t:.3f}" for t in time_list] + widths = [10, 10, 10, 10, 10] + def _fmt_row(values): + return " | ".join(str(values[i]).ljust(widths[i]) for i in range(len(headers))) + + + sep = "-+-".join("-" * w for w in widths) + print(_fmt_row(headers)) + print(_fmt_row(time_list)) + if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index e7226d336ac8..69c07cc715ef 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -14,6 +14,7 @@ import html from typing import Any, Callable, Dict, List, Optional, Union +import time import regex as re import torch @@ -473,6 +474,8 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + t_start = time.time() + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -564,6 +567,8 @@ def __call__( else: boundary_timestep = None + t_preprocess = time.time() + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -628,6 +633,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + + t_dit = time.time() self._current_timestep = None @@ -647,9 +654,23 @@ def __call__( else: video = latents + t_vae = time.time() + # Offload all models self.maybe_free_model_hooks() + if torch.distributed.get_rank() == 0: + headers = ["Total", "Prepare", "DIT", "VAE", "DIT_PER_STEP"] + time_list = [t_vae - t_start, t_preprocess - t_start, t_dit - t_preprocess, t_vae - t_dit, (t_dit - t_start) / num_inference_steps] + time_list = [f"{t:.3f}" for t in time_list] + widths = [10, 10, 10, 10, 10] + def _fmt_row(values): + return " | ".join(str(values[i]).ljust(widths[i]) for i in range(len(headers))) + + + sep = "-+-".join("-" * w for w in widths) + print(_fmt_row(headers)) + print(_fmt_row(time_list)) self.transformer.rotary_emb = None if self.transformer_2 is not None: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index b7fd0b05980f..05230f35aebd 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -14,6 +14,7 @@ import html from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import time import PIL import regex as re @@ -356,6 +357,15 @@ def check_inputs( if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + if getattr(getattr(self.transformer, "_parallel_config", None), "context_parallel_config", None) is not None: + mesh_size = self.transformer._parallel_config.context_parallel_config._flattened_mesh.size() + mod_size = 16 * mesh_size + if height % mod_size != 0 or width % mod_size != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 16 * {mesh_size} " \ + f"when enable context parallel, but are {height} and {width}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -613,6 +623,8 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + t_start = time.time() + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -724,6 +736,7 @@ def __call__( else: boundary_timestep = None + t_preprocess = time.time() with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -794,6 +807,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + t_dit = time.time() + self._current_timestep = None if self.config.expand_timesteps: @@ -815,9 +830,29 @@ def __call__( else: video = latents + t_vae = time.time() + # Offload all models self.maybe_free_model_hooks() + self.transformer.rotary_emb = None + if self.transformer_2 is not None: + self.transformer_2.rotary_emb = None + + is_print = torch.distributed.get_rank() == 0 if torch.distributed.is_initialized() else True + if is_print: + headers = ["Total", "Prepare", "DIT", "VAE", "DIT_PER_STEP"] + time_list = [t_vae - t_start, t_preprocess - t_start, t_dit - t_preprocess, t_vae - t_dit, (t_dit - t_start) / num_inference_steps] + time_list = [f"{t:.3f}" for t in time_list] + widths = [10, 10, 10, 10, 10] + def _fmt_row(values): + return " | ".join(str(values[i]).ljust(widths[i]) for i in range(len(headers))) + + + sep = "-+-".join("-" * w for w in widths) + print(_fmt_row(headers)) + print(_fmt_row(time_list)) + if not return_dict: return (video,) From 333caa4ef3a09fb0581b9ecaf00ba14e2be2aaa9 Mon Sep 17 00:00:00 2001 From: yyt Date: Thu, 6 Nov 2025 11:11:15 +0000 Subject: [PATCH 10/17] add mindie_sd laser attention backend --- src/diffusers/models/attention_dispatch.py | 121 ++++++++++++++++++--- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 4 + 3 files changed, 113 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 289c3e82955b..ebee0a2ed0bb 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -41,6 +41,7 @@ is_torch_xla_version, is_xformers_available, is_xformers_version, + is_mindie_sd_available, ) from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS @@ -63,6 +64,7 @@ _CAN_USE_NPU_ATTN = is_torch_npu_available() _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) +_CAN_USE_MINDIESD_ATTN = is_mindie_sd_available() if _CAN_USE_FLASH_ATTN: @@ -142,6 +144,13 @@ else: xops = None + +if _CAN_USE_MINDIESD_ATTN: + from mindiesd import attention_forward as mindie_sd_attn_forward +else: + mindie_sd_attn_forward = None + + # Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4 if torch.__version__ >= "2.4.0": _custom_op = torch.library.custom_op @@ -215,6 +224,9 @@ class AttentionBackendName(str, Enum): # `xformers` XFORMERS = "xformers" + # mindie_sd + _MINDIE_SD_LASER = "_mindie_sd_la" + class _AttentionBackendRegistry: _backends = {} @@ -470,6 +482,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`." ) + elif backend == AttentionBackendName._MINDIE_SD_LASER: + if not _CAN_USE_MINDIESD_ATTN: + raise RuntimeError( + f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`." + ) + @functools.lru_cache(maxsize=128) def _prepare_for_flash_attn_or_sage_varlen_without_mask( @@ -907,19 +925,14 @@ def _npu_attention_forward_op( _save_ctx: bool = True, _parallel_config: Optional["ParallelConfig"] = None, ): - # if enable_gqa: - # raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.") if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") - # tensors_to_save = () - # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results # if the input tensors are not contiguous. query = query.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() - # tensors_to_save += (query, key, value) out = npu_fusion_attention( query, @@ -936,14 +949,6 @@ def _npu_attention_forward_op( inner_precise=0, )[0] - # tensors_to_save += (out) - # if _save_ctx: - # ctx.save_for_backward(*tensors_to_save) - # ctx.dropout_p = dropout_p - # ctx.is_causal = is_causal - # ctx.scale = scale - # ctx.attn_mask = attn_mask - out = out.transpose(1, 2).contiguous() return out @@ -959,6 +964,52 @@ def _npu_attention_backward_op( raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.") +def _mindie_sd_laser_attn_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for MindIE SD Laser Attention.") + if return_lse: + raise ValueError("MindIE SD attention backend does not support setting `return_lse=True`.") + + # query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) + # print(f"[YYT DEBUG] >>>>>>>> {query.shape=}") + # print(f"[YYT DEBUG] >>>>>>>> {key.shape=}") + # print(f"[YYT DEBUG] >>>>>>>> {value.shape=}") + + out = mindie_sd_attn_forward( + query, + key, + value, + opt_mode="manual", + op_type="ascend_laser_attention", + layout="BNSD" + ) + + # out = out.transpose(1, 2).contiguous() + + return out + +def _mindie_sd_laser_attn_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + raise NotImplementedError("Backward pass is not implemented for MindIE SD Laser Attention.") + + # ===== Context parallel ===== @@ -2095,3 +2146,47 @@ def _xformers_attention( out = out.flatten(2, 3) return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._MINDIE_SD_LASER, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _mindie_sd_laser_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, +) -> torch.Tensor: + if return_lse: + raise ValueError("MINDIE SD attention backend does not support setting `return_lse=True`.") + if _parallel_config is None: + # query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) + out = mindie_sd_attn_forward( + query, + key, + value, + opt_mode="manual", + op_type="ascend_laser_attention", + layout="BNSD" + ) + # out = out.transpose(1, 2).contiguous() + else: + out = _templated_context_parallel_attention( + query, + key, + value, + None, + dropout_p, + None, + scale, + None, + return_lse, + forward_op=_mindie_sd_laser_attn_forward_op, + backward_op=_mindie_sd_laser_attn_backward_op, + _parallel_config=_parallel_config, + ) + return out \ No newline at end of file diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cf77aaee8205..49758d1b2454 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -122,6 +122,7 @@ is_wandb_available, is_xformers_available, is_xformers_version, + is_mindie_sd_available, requires_backends, ) from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index adf8ed8b0694..985bb896fa20 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -229,6 +229,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _aiter_available, _aiter_version = _is_package_available("aiter") _kornia_available, _kornia_version = _is_package_available("kornia") _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) +_mindie_sd_available, _mindie_sd_version = _is_package_available("mindiesd") def is_torch_available(): @@ -414,6 +415,9 @@ def is_aiter_available(): def is_kornia_available(): return _kornia_available +def is_mindie_sd_available(): + return _mindie_sd_available + # docstyle-ignore FLAX_IMPORT_ERROR = """ From 1fec0ea43e0bf708d8a0efdb80bf8868196dd17a Mon Sep 17 00:00:00 2001 From: yyt Date: Thu, 6 Nov 2025 11:21:07 +0000 Subject: [PATCH 11/17] add WanAttnProcessorNPU with fixed atten_backend --- .../models/transformers/transformer_wan.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 2e566f1daaaa..12d47dd2c80d 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -65,6 +65,101 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t return key_img, value_img +class WanAttnProcessorNPU: + _attention_backend_fix = "_native_npu" + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend_fix, + parallel_config=self._parallel_config, + ) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend_fix, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + class WanAttnProcessor: _attention_backend = None _parallel_config = None From 1a99c38492a6c457494e62b1279680e6ad6b85e8 Mon Sep 17 00:00:00 2001 From: yyt Date: Thu, 6 Nov 2025 11:30:03 +0000 Subject: [PATCH 12/17] prof --- src/diffusers/pipelines/flux/pipeline_flux.py | 4 +++- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index ff27a272fb9f..375a04ea1bd4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -28,6 +28,8 @@ import time +import time + from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel @@ -1007,10 +1009,10 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - torch.cuda.synchronize() if prof: prof.step() + torch.cuda.synchronize() t_dit = time.time() self._current_timestep = None diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 05230f35aebd..626bd4883e83 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -542,6 +542,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + prof=None, ): r""" The call function to the pipeline for generation. @@ -623,7 +624,10 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + torch.cuda.synchronize() t_start = time.time() + if prof: + prof.start() if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -736,7 +740,11 @@ def __call__( else: boundary_timestep = None + torch.cuda.synchronize() t_preprocess = time.time() + if prof: + prof.step() + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -807,6 +815,10 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + if prof: + prof.step() + + torch.cuda.synchronize() t_dit = time.time() self._current_timestep = None @@ -830,7 +842,11 @@ def __call__( else: video = latents + torch.cuda.synchronize() t_vae = time.time() + if prof: + prof.step() + prof.stop() # Offload all models self.maybe_free_model_hooks() From a257bdde4b6f8093a7a9f416db9da4a70d246985 Mon Sep 17 00:00:00 2001 From: yyt Date: Thu, 6 Nov 2025 11:49:33 +0000 Subject: [PATCH 13/17] prof --- src/diffusers/pipelines/wan/pipeline_wan.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 69c07cc715ef..5fdca1305f4f 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -403,6 +403,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + prof=None, ): r""" The call function to the pipeline for generation. @@ -474,6 +475,9 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + torch.cuda.synchronize() + if prof: + prof.start() t_start = time.time() if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): @@ -567,6 +571,9 @@ def __call__( else: boundary_timestep = None + torch.cuda.synchronize() + if prof: + prof.step() t_preprocess = time.time() with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -633,7 +640,11 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + + if prof: + prof.step() + torch.cuda.synchronize() t_dit = time.time() self._current_timestep = None @@ -654,6 +665,10 @@ def __call__( else: video = latents + torch.cuda.synchronize() + if prof: + prof.step() + prof.stop() t_vae = time.time() # Offload all models From 5b635a3617a2fa256e678c7265f16b49cbbd0bd7 Mon Sep 17 00:00:00 2001 From: yyt Date: Thu, 6 Nov 2025 15:58:57 +0000 Subject: [PATCH 14/17] laser attn --- src/diffusers/models/attention_dispatch.py | 7 --- .../models/transformers/transformer_flux.py | 44 ++++++++++++------- .../models/transformers/transformer_wan.py | 2 +- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index ebee0a2ed0bb..f6b542331a05 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -983,11 +983,6 @@ def _mindie_sd_laser_attn_forward_op( if return_lse: raise ValueError("MindIE SD attention backend does not support setting `return_lse=True`.") - # query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) - # print(f"[YYT DEBUG] >>>>>>>> {query.shape=}") - # print(f"[YYT DEBUG] >>>>>>>> {key.shape=}") - # print(f"[YYT DEBUG] >>>>>>>> {value.shape=}") - out = mindie_sd_attn_forward( query, key, @@ -997,8 +992,6 @@ def _mindie_sd_laser_attn_forward_op( layout="BNSD" ) - # out = out.transpose(1, 2).contiguous() - return out def _mindie_sd_laser_attn_backward_op( diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index bc350158eaf4..b9f1c81ee683 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -39,6 +39,8 @@ from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from mindiesd import attention_forward as mindie_sd_attn_forward + STREAM_VECTOR = torch.npu.Stream() STREAM_COMM = torch.npu.Stream() @@ -378,29 +380,39 @@ def _context_parallel_forward_vqk( key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) value_all = _wait_tensor(value_all) - value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() query_all = _wait_tensor(query_all) - query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() key_all = _wait_tensor(key_all) - key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() - - out = npu_fusion_attention( + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + # out = npu_fusion_attention( + # query_all, + # key_all, + # value_all, + # H_LOCAL, # num_heads + # input_layout="BNSD", + # pse=None, + # scale=1.0 / math.sqrt(D), + # pre_tockens=65536, + # next_tockens=65536, + # keep_prob=1.0, + # sync=False, + # inner_precise=0, + # )[0] + + out = mindie_sd_attn_forward( query_all, key_all, value_all, - H_LOCAL, # num_heads - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(D), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - out = out.transpose(1, 2).contiguous() + opt_mode="manual", + op_type="ascend_laser_attention", + layout="BNSD" + ) + + # out = out.transpose(1, 2).contiguous() out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() out = _all_to_all_single(out, group) hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 12d47dd2c80d..65539fdcb96d 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -538,7 +538,7 @@ def __init__( eps=eps, added_kv_proj_dim=added_kv_proj_dim, cross_attention_dim_head=dim // num_heads, - processor=WanAttnProcessor(), + processor=WanAttnProcessorNPU(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() From 6dbf2c2c66fc039a5a9dc6292c499e8141709a0d Mon Sep 17 00:00:00 2001 From: yyt Date: Sat, 8 Nov 2025 09:38:09 +0000 Subject: [PATCH 15/17] flux laser attn --- .../models/transformers/transformer_flux.py | 232 +++++++++--------- 1 file changed, 120 insertions(+), 112 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index b9f1c81ee683..6762c665e79e 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -42,7 +42,6 @@ from mindiesd import attention_forward as mindie_sd_attn_forward STREAM_VECTOR = torch.npu.Stream() -STREAM_COMM = torch.npu.Stream() logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -136,7 +135,7 @@ def __call__( if hasattr(self._parallel_config, "context_parallel_config") and \ self._parallel_config.context_parallel_config is not None: - return self._context_parallel_forward_vqk( + return self._context_parallel_forward_qkv( attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, pre_query, pre_key, cal_q ) @@ -196,6 +195,102 @@ def __call__( else: return hidden_states + def _context_parallel_forward( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + pre_query: Optional[torch.Tensor] = None, + pre_key: Optional[torch.Tensor] = None, + cal_q=True + ) -> torch.Tensor: + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + value = attn.to_v(hidden_states) + value = value.unflatten(-1, (attn.heads, -1)) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + value = torch.cat([encoder_value, value], dim=1) + + B, S_KV_LOCAL, H, D = value.shape + H_LOCAL = H // world_size + value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + query = attn.to_q(hidden_states) + query = query.unflatten(-1, (attn.heads, -1)) + query = attn.norm_q(query) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_query = attn.norm_added_q(encoder_query) + query = torch.cat([encoder_query, query], dim=1) + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + _, S_Q_LOCAL, _, _ = query.shape + query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + + key = attn.to_k(hidden_states) + key = key.unflatten(-1, (attn.heads, -1)) + key = attn.norm_k(key) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_key = attn.norm_added_k(encoder_key) + key = torch.cat([encoder_key, key], dim=1) + if image_rotary_emb is not None: + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + value_all = _wait_tensor(value_all) + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + query_all = _wait_tensor(query_all) + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + + key_all = _wait_tensor(key_all) + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + out = npu_fusion_attention( + query_all, + key_all, + value_all, + H_LOCAL, # num_heads + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(D), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + out = out.transpose(1, 2).contiguous() + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + def _context_parallel_forward_cv( self, attn: "FluxAttention", @@ -496,125 +591,38 @@ def _context_parallel_forward_qkv( value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) query_all = _wait_tensor(query_all) - query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() key_all = _wait_tensor(key_all) - key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() - - value_all = _wait_tensor(value_all) - value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() - - out = npu_fusion_attention( - query_all, - key_all, - value_all, - H_LOCAL, # num_heads - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(D), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - out = out.transpose(1, 2).contiguous() - out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() - out = _all_to_all_single(out, group) - hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() - - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( - [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 - ) - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - def _context_parallel_forward( - self, - attn: "FluxAttention", - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - pre_query: Optional[torch.Tensor] = None, - pre_key: Optional[torch.Tensor] = None, - cal_q=True - ) -> torch.Tensor: - - ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh - world_size = self._parallel_config.context_parallel_config.ulysses_degree - group = ulysses_mesh.get_group() - - value = attn.to_v(hidden_states) - value = value.unflatten(-1, (attn.heads, -1)) - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: - encoder_value = attn.add_v_proj(encoder_hidden_states) - encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) - value = torch.cat([encoder_value, value], dim=1) - - B, S_KV_LOCAL, H, D = value.shape - H_LOCAL = H // world_size - value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) - - query = attn.to_q(hidden_states) - query = query.unflatten(-1, (attn.heads, -1)) - query = attn.norm_q(query) - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) - encoder_query = attn.norm_added_q(encoder_query) - query = torch.cat([encoder_query, query], dim=1) - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) - _, S_Q_LOCAL, _, _ = query.shape - query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() - key = attn.to_k(hidden_states) - key = key.unflatten(-1, (attn.heads, -1)) - key = attn.norm_k(key) - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) - encoder_key = attn.norm_added_k(encoder_key) - key = torch.cat([encoder_key, key], dim=1) - if image_rotary_emb is not None: - key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) - value_all = _wait_tensor(value_all) - value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() - - query_all = _wait_tensor(query_all) - query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() - + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() - key_all = _wait_tensor(key_all) - key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + # out = npu_fusion_attention( + # query_all, + # key_all, + # value_all, + # H_LOCAL, # num_heads + # input_layout="BNSD", + # pse=None, + # scale=1.0 / math.sqrt(D), + # pre_tockens=65536, + # next_tockens=65536, + # keep_prob=1.0, + # sync=False, + # inner_precise=0, + # )[0] - out = npu_fusion_attention( + out = mindie_sd_attn_forward( query_all, key_all, value_all, - H_LOCAL, # num_heads - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(D), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - out = out.transpose(1, 2).contiguous() + opt_mode="manual", + op_type="ascend_laser_attention", + layout="BNSD" + ) + # out = out.transpose(1, 2).contiguous() out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() out = _all_to_all_single(out, group) hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() From 03dacd220bc0b9b217cdd5e6bc36a89055ffde90 Mon Sep 17 00:00:00 2001 From: yyt Date: Sat, 8 Nov 2025 09:39:45 +0000 Subject: [PATCH 16/17] wan - use mindie rope --- .../models/transformers/transformer_wan.py | 23 +++++++++++++------ src/diffusers/pipelines/wan/pipeline_wan.py | 1 - 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 65539fdcb96d..89b678de5095 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from mindiesd import rotary_position_embedding from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -196,18 +197,26 @@ def __call__( if rotary_emb is not None: + # def apply_rotary_emb( + # hidden_states: torch.Tensor, + # freqs_cos: torch.Tensor, + # freqs_sin: torch.Tensor, + # ): + # x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + # cos = freqs_cos[..., 0::2] + # sin = freqs_sin[..., 1::2] + # out = torch.empty_like(hidden_states) + # out[..., 0::2] = x1 * cos - x2 * sin + # out[..., 1::2] = x1 * sin + x2 * cos + # return out.type_as(hidden_states) + def apply_rotary_emb( hidden_states: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, ): - x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) - cos = freqs_cos[..., 0::2] - sin = freqs_sin[..., 1::2] - out = torch.empty_like(hidden_states) - out[..., 0::2] = x1 * cos - x2 * sin - out[..., 1::2] = x1 * sin + x2 * cos - return out.type_as(hidden_states) + out = rotary_position_embedding(hidden_states, freqs_cos, freqs_sin, rotated_mode="rotated_interleaved", fused=True) + return out query = apply_rotary_emb(query, *rotary_emb) key = apply_rotary_emb(key, *rotary_emb) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 5fdca1305f4f..53c3a12d4660 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -668,7 +668,6 @@ def __call__( torch.cuda.synchronize() if prof: prof.step() - prof.stop() t_vae = time.time() # Offload all models From 1b119b846351dcd323af1241ec243defa857cc30 Mon Sep 17 00:00:00 2001 From: yyt Date: Tue, 25 Nov 2025 06:41:39 +0000 Subject: [PATCH 17/17] wan: compute-communication-parallel and ulysses parallel optimize --- src/diffusers/models/attention_dispatch.py | 5 +- .../models/transformers/transformer_wan.py | 274 +++++++++++++++++- .../pipelines/wan/pipeline_wan_i2v.py | 6 +- 3 files changed, 269 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index f6b542331a05..8a39423ff2ef 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -266,7 +266,7 @@ def list_backends(cls): def _is_context_parallel_enabled( cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"] ) -> bool: - supports_context_parallel = backend in cls._supports_context_parallel + supports_context_parallel = backend in cls._supports_context_parallel and cls._supports_context_parallel[backend] is_degree_greater_than_1 = parallel_config is not None and ( parallel_config.context_parallel_config.ring_degree > 1 or parallel_config.context_parallel_config.ulysses_degree > 1 @@ -1170,6 +1170,7 @@ def forward( backward_op, _parallel_config: Optional["ParallelConfig"] = None, ): + # print(f"[YYT DEBUG] >>>>> ulysses") ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh world_size = _parallel_config.context_parallel_config.ulysses_degree group = ulysses_mesh.get_group() @@ -1820,6 +1821,7 @@ def _native_math_attention( @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_NPU, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _native_npu_attention( query: torch.Tensor, @@ -2144,6 +2146,7 @@ def _xformers_attention( @_AttentionBackendRegistry.register( AttentionBackendName._MINDIE_SD_LASER, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _mindie_sd_laser_attention( query: torch.Tensor, diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 89b678de5095..4d0e7fb4b219 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -37,6 +37,10 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if torch.distributed.is_available(): + import torch.distributed._functional_collectives as funcol + + def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): # encoder_hidden_states is only passed for cross-attention if encoder_hidden_states is None: @@ -66,6 +70,234 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t return key_img, value_img +def _wait_tensor(tensor): + if isinstance(tensor, funcol.AsyncCollectiveTensor): + tensor = tensor.wait() + return tensor + + +def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: + shape = x.shape + x = x.flatten() + x = funcol.all_to_all_single(x, None, None, group) + x = x.reshape(shape) + x = _wait_tensor(x) + return x + + +def ulysses_preforward( + x: torch.Tensor, + group, + world_size, + B, + S_LOCAL, + H, + D, + H_LOCAL +): + x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + x = x.flatten() + x = funcol.all_to_all_single(x, None, None, group) + return x + + +class WanAttnProcessorSP: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + # query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + value = value.unflatten(2, (attn.heads, -1)) + B, S_KV_LOCAL, H, D = value.shape + H_LOCAL = H // world_size + value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + out = rotary_position_embedding(hidden_states, freqs_cos, freqs_sin, rotated_mode="rotated_interleaved", fused=True) + return out + + query = attn.to_q(hidden_states) + query = attn.norm_q(query) + query = query.unflatten(2, (attn.heads, -1)) + if rotary_emb is not None: + query = apply_rotary_emb(query, *rotary_emb) + _, S_Q_LOCAL, _, _ = query.shape + query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + + key = attn.to_k(encoder_hidden_states) + key = attn.norm_k(key) + key = key.unflatten(2, (attn.heads, -1)) + if rotary_emb is not None: + key = apply_rotary_emb(key, *rotary_emb) + key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + value_all = _wait_tensor(value_all) + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + query_all = _wait_tensor(query_all) + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + key_all = _wait_tensor(key_all) + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + hidden_states = dispatch_attention_fn( + query_all, + key_all, + value_all, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=None, + ) + hidden_states = hidden_states.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + hidden_states = _all_to_all_single(hidden_states, group) + hidden_states = hidden_states.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class WanAttnProcessorCross: + _attention_backend_fix = "_native_npu" + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + + # def apply_rotary_emb( + # hidden_states: torch.Tensor, + # freqs_cos: torch.Tensor, + # freqs_sin: torch.Tensor, + # ): + # x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + # cos = freqs_cos[..., 0::2] + # sin = freqs_sin[..., 1::2] + # out = torch.empty_like(hidden_states) + # out[..., 0::2] = x1 * cos - x2 * sin + # out[..., 1::2] = x1 * sin + x2 * cos + # return out.type_as(hidden_states) + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + out = rotary_position_embedding(hidden_states, freqs_cos, freqs_sin, rotated_mode="rotated_interleaved", fused=True) + return out + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend_fix, + parallel_config=None, + ) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend_fix, + parallel_config=None, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + class WanAttnProcessorNPU: _attention_backend_fix = "_native_npu" _parallel_config = None @@ -102,18 +334,26 @@ def __call__( if rotary_emb is not None: + # def apply_rotary_emb( + # hidden_states: torch.Tensor, + # freqs_cos: torch.Tensor, + # freqs_sin: torch.Tensor, + # ): + # x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + # cos = freqs_cos[..., 0::2] + # sin = freqs_sin[..., 1::2] + # out = torch.empty_like(hidden_states) + # out[..., 0::2] = x1 * cos - x2 * sin + # out[..., 1::2] = x1 * sin + x2 * cos + # return out.type_as(hidden_states) + def apply_rotary_emb( hidden_states: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, ): - x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) - cos = freqs_cos[..., 0::2] - sin = freqs_sin[..., 1::2] - out = torch.empty_like(hidden_states) - out[..., 0::2] = x1 * cos - x2 * sin - out[..., 1::2] = x1 * sin + x2 * cos - return out.type_as(hidden_states) + out = rotary_position_embedding(hidden_states, freqs_cos, freqs_sin, rotated_mode="rotated_interleaved", fused=True) + return out query = apply_rotary_emb(query, *rotary_emb) key = apply_rotary_emb(key, *rotary_emb) @@ -511,6 +751,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + # if (pad_size := 8 - freqs_cos.shape[1] % 8) != 8: + # b, s, n, d = freqs_cos.shape + # freqs_cos = torch.cat([freqs_cos, freqs_cos.new_ones(b, pad_size, n, d)], dim=1) + # freqs_sin = torch.cat([freqs_sin, freqs_sin.new_ones(b, pad_size, n, d)], dim=1) + return freqs_cos, freqs_sin @@ -536,7 +781,7 @@ def __init__( dim_head=dim // num_heads, eps=eps, cross_attention_dim_head=None, - processor=WanAttnProcessor(), + processor=WanAttnProcessorSP(), ) # 2. Cross-attention @@ -547,7 +792,7 @@ def __init__( eps=eps, added_kv_proj_dim=added_kv_proj_dim, cross_attention_dim_head=dim // num_heads, - processor=WanAttnProcessorNPU(), + processor=WanAttnProcessorCross(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() @@ -655,9 +900,9 @@ class WanTransformer3DModel( "blocks.0": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, - "blocks.*": { - "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - }, + # "blocks.*": { + # "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + # }, "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), } @@ -755,6 +1000,10 @@ def forward( hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) + # if (pad_size := (8 - hidden_states.shape[1] % 8) % 8) != 0: + # b, s, d = hidden_states.shape + # hidden_states = torch.cat([hidden_states, hidden_states.new_zeros(b, pad_size, d)], dim=1) + # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) if timestep.ndim == 2: ts_seq_len = timestep.shape[1] @@ -804,6 +1053,7 @@ def forward( hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) hidden_states = self.proj_out(hidden_states) + # hidden_states = hidden_states[:, :-pad_size, :] hidden_states = hidden_states.reshape( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 626bd4883e83..0342836279ed 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -359,10 +359,10 @@ def check_inputs( if getattr(getattr(self.transformer, "_parallel_config", None), "context_parallel_config", None) is not None: mesh_size = self.transformer._parallel_config.context_parallel_config._flattened_mesh.size() - mod_size = 16 * mesh_size - if height % mod_size != 0 or width % mod_size != 0: + mod_size = 16 * 16 * mesh_size + if (height * width ) % mod_size != 0: raise ValueError( - f"`height` and `width` have to be divisible by 16 * {mesh_size} " \ + f"The product of `height` and `width` have to be divisible by 16 * 16 * {mesh_size} " \ f"when enable context parallel, but are {height} and {width}." )