diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 289c3e82955b..8a39423ff2ef 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -41,6 +41,7 @@ is_torch_xla_version, is_xformers_available, is_xformers_version, + is_mindie_sd_available, ) from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS @@ -63,6 +64,7 @@ _CAN_USE_NPU_ATTN = is_torch_npu_available() _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) +_CAN_USE_MINDIESD_ATTN = is_mindie_sd_available() if _CAN_USE_FLASH_ATTN: @@ -142,6 +144,13 @@ else: xops = None + +if _CAN_USE_MINDIESD_ATTN: + from mindiesd import attention_forward as mindie_sd_attn_forward +else: + mindie_sd_attn_forward = None + + # Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4 if torch.__version__ >= "2.4.0": _custom_op = torch.library.custom_op @@ -215,6 +224,9 @@ class AttentionBackendName(str, Enum): # `xformers` XFORMERS = "xformers" + # mindie_sd + _MINDIE_SD_LASER = "_mindie_sd_la" + class _AttentionBackendRegistry: _backends = {} @@ -254,7 +266,7 @@ def list_backends(cls): def _is_context_parallel_enabled( cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"] ) -> bool: - supports_context_parallel = backend in cls._supports_context_parallel + supports_context_parallel = backend in cls._supports_context_parallel and cls._supports_context_parallel[backend] is_degree_greater_than_1 = parallel_config is not None and ( parallel_config.context_parallel_config.ring_degree > 1 or parallel_config.context_parallel_config.ulysses_degree > 1 @@ -470,6 +482,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`." ) + elif backend == AttentionBackendName._MINDIE_SD_LASER: + if not _CAN_USE_MINDIESD_ATTN: + raise RuntimeError( + f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`." + ) + @functools.lru_cache(maxsize=128) def _prepare_for_flash_attn_or_sage_varlen_without_mask( @@ -907,19 +925,14 @@ def _npu_attention_forward_op( _save_ctx: bool = True, _parallel_config: Optional["ParallelConfig"] = None, ): - # if enable_gqa: - # raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.") if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") - # tensors_to_save = () - # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results # if the input tensors are not contiguous. query = query.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() - # tensors_to_save += (query, key, value) out = npu_fusion_attention( query, @@ -936,14 +949,6 @@ def _npu_attention_forward_op( inner_precise=0, )[0] - # tensors_to_save += (out) - # if _save_ctx: - # ctx.save_for_backward(*tensors_to_save) - # ctx.dropout_p = dropout_p - # ctx.is_causal = is_causal - # ctx.scale = scale - # ctx.attn_mask = attn_mask - out = out.transpose(1, 2).contiguous() return out @@ -959,6 +964,45 @@ def _npu_attention_backward_op( raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.") +def _mindie_sd_laser_attn_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for MindIE SD Laser Attention.") + if return_lse: + raise ValueError("MindIE SD attention backend does not support setting `return_lse=True`.") + + out = mindie_sd_attn_forward( + query, + key, + value, + opt_mode="manual", + op_type="ascend_laser_attention", + layout="BNSD" + ) + + return out + +def _mindie_sd_laser_attn_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + raise NotImplementedError("Backward pass is not implemented for MindIE SD Laser Attention.") + + # ===== Context parallel ===== @@ -1126,6 +1170,7 @@ def forward( backward_op, _parallel_config: Optional["ParallelConfig"] = None, ): + # print(f"[YYT DEBUG] >>>>> ulysses") ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh world_size = _parallel_config.context_parallel_config.ulysses_degree group = ulysses_mesh.get_group() @@ -1776,6 +1821,7 @@ def _native_math_attention( @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_NPU, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _native_npu_attention( query: torch.Tensor, @@ -2095,3 +2141,48 @@ def _xformers_attention( out = out.flatten(2, 3) return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._MINDIE_SD_LASER, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, +) +def _mindie_sd_laser_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, +) -> torch.Tensor: + if return_lse: + raise ValueError("MINDIE SD attention backend does not support setting `return_lse=True`.") + if _parallel_config is None: + # query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) + out = mindie_sd_attn_forward( + query, + key, + value, + opt_mode="manual", + op_type="ascend_laser_attention", + layout="BNSD" + ) + # out = out.transpose(1, 2).contiguous() + else: + out = _templated_context_parallel_attention( + query, + key, + value, + None, + dropout_p, + None, + scale, + None, + return_lse, + forward_op=_mindie_sd_laser_attn_forward_op, + backward_op=_mindie_sd_laser_attn_backward_op, + _parallel_config=_parallel_config, + ) + return out \ No newline at end of file diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 1a72aa3cfeb3..f11c7db25386 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -15,11 +15,12 @@ import torch import torch.nn as nn +import torch.distributed as dist from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import deprecate +from ...utils import deprecate, logging from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -35,6 +36,9 @@ from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. @@ -127,6 +131,7 @@ def __init__( self.use_slicing = False self.use_tiling = False + self.use_dp = False # only relevant if vae tiling is enabled self.tile_sample_min_size = self.config.sample_size @@ -214,9 +219,58 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) + def enable_dp( + self, + world_size: Optional[int] = None, + hw_splits: Optional[Tuple[int, int]] = None, + overlap_ratio: Optional[float] = None, + overlap_pixels: Optional[int] = None + ) -> None: + r""" + """ + if world_size is None: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if world_size <= 1 or world_size > dist.get_world_size(): + logger.warning( + f"Supported world_size for vae dp is between 2 - {dist.get_world_size}, but got {world_size}. " \ + f"Fall back to normal vae") + return + + if hw_splits is None: + hw_splits = (1, int(world_size)) + + assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 int, but got length {len(hw_splits)}" + + h_split, w_split = map(int, hw_splits) + + self.use_dp = True + self.h_split, self.w_split = h_split, w_split + self.world_size = world_size + self.overlap_ratio = overlap_ratio + self.overlap_pixels = overlap_pixels + self.spatial_compression_ratio = 2 ** (len(self.config.block_out_channels) - 1) + + dp_ranks = list(range(0, world_size)) + self.vae_dp_group = dist.new_group(ranks=dp_ranks) + self.rank = dist.get_rank() + # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)] + # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split) + self.tile_idxs_per_rank = [[] for _ in range(self.world_size)] + self.num_tiles_per_rank = [0] * self.world_size + rank_idx = 0 + for h_idx in range(self.h_split): + for w_idx in range(self.w_split): + rank_idx %= self.world_size + self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx)) + self.num_tiles_per_rank[rank_idx] += 1 + rank_idx += 1 + def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = x.shape + if self.use_dp: + return self._tiled_encode(x) if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): return self._tiled_encode(x) @@ -256,6 +310,8 @@ def encode( return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_dp: + return self.tiled_decode_with_dp(z, return_dict=return_dict) if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): return self.tiled_decode(z, return_dict=return_dict) @@ -310,6 +366,20 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b + def blend_v_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + y = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (y / blend_extent)[None, None, :, None].to(a.dtype) + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - blend_ratio) + b[:, :, y, :] * blend_ratio + return b + + def blend_h_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + x = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (x / blend_extent)[None, None, None, :].to(a.dtype) + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - blend_ratio) + b[:, :, :, x] * blend_ratio + return b + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. @@ -469,6 +539,157 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return DecoderOutput(sample=dec) + def calculate_tiled_parallel_size(self, latent_height, latent_width): + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * latent_height) + overlap_latent_width = int(self.overlap_ratio * latent_width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + blend_latent_height = tile_latent_min_height - tile_latent_stride_height + blend_latent_width = tile_latent_min_width - tile_latent_stride_width + + blend_sample_height = tile_sample_min_height - tile_sample_stride_height + blend_sample_width = tile_sample_min_width - tile_sample_stride_width + + return \ + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width + + def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + _, _, height, width = x.shape + device = x.device + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + + def vae_encode_op( + x, patch_height_start, patch_height_end, patch_width_start, patch_width_end + ) -> torch.Tensor: + tile = x[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + return tile + + rows = self.run_vae_tile_parallel( + x, vae_encode_op, + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, device + ) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v_(rows[i - 1][j], tile, blend_latent_height) + if j > 0: + tile = self.blend_h_(row[j - 1], tile, blend_latent_width) + result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width] + return enc + + def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, latent_height, latent_width = z.shape + device = z.device + sample_height = latent_height * self.spatial_compression_ratio + sample_width = latent_width * self.spatial_compression_ratio + + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + + def vae_decode_op( + z, patch_height_start, patch_height_end, patch_width_start, patch_width_end + ) -> torch.Tensor: + + tile = z[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + return decoded + + rows = self.run_vae_tile_parallel( + z, vae_decode_op, + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, device + ) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v_(rows[i - 1][j], tile, blend_sample_height) + if j > 0: + tile = self.blend_h_(row[j - 1], tile, blend_sample_width) + result_row.append(tile[:, :, :tile_sample_stride_height, :tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2)[:, :, :sample_height, :sample_width] + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 431a3d5e6f3b..f034cebde12b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -18,7 +18,6 @@ import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist -import torch.distributed as dist from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin @@ -1077,8 +1076,6 @@ def __init__( self.use_dp = False - self.use_dp = False - # Precompute and cache conv counts for encoder and decoder for clear_cache speedup self._cached_conv_counts = { "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) @@ -1129,9 +1126,12 @@ def enable_dp( r""" """ if world_size is None: - world_size = dist.get_world_size() + world_size = dist.get_world_size() if dist.is_initialized() else 1 if world_size <= 1 or world_size > dist.get_world_size(): + logger.warning( + f"Supported world_size for vae dp is between 2 - {dist.get_world_size}, but got {world_size}. " \ + f"Fall back to normal vae") return if hw_splits is None: @@ -1140,57 +1140,6 @@ def enable_dp( assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 int, but got length {len(hw_splits)}" h_split, w_split = map(int, hw_splits) - num_tiles = h_split * w_split - - # assert h_split * w_split == world_size, \ - # (f"world_size must be {w_split} * {h_split} = {w_split * h_split}, but got {world_size}") - - self.use_dp = True - self.h_split, self.w_split = h_split, w_split - self.world_size = world_size - self.overlap_ratio = overlap_ratio - self.overlap_pixels = overlap_pixels - - dp_ranks = list(range(0, world_size)) - self.vae_dp_group = dist.new_group(ranks=dp_ranks) - self.rank = dist.get_rank() - # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)] - # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split) - self.tile_idxs_per_rank = [[] for _ in range(self.world_size)] - self.num_tiles_per_rank = [0] * self.world_size - rank_idx = 0 - for h_idx in range(self.h_split): - for w_idx in range(self.w_split): - rank_idx %= self.world_size - self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx)) - self.num_tiles_per_rank[rank_idx] += 1 - rank_idx += 1 - - def enable_dp( - self, - world_size: Optional[int] = None, - hw_splits: Optional[Tuple[int, int]] = None, - overlap_ratio: Optional[float] = None, - overlap_pixels: Optional[int] = None - ) -> None: - r""" - """ - if world_size is None: - world_size = dist.get_world_size() - - if world_size <= 1 or world_size > dist.get_world_size(): - return - - if hw_splits is None: - hw_splits = (1, int(world_size)) - - assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 int, but got length {len(hw_splits)}" - - h_split, w_split = map(int, hw_splits) - num_tiles = h_split * w_split - - # assert h_split * w_split == world_size, \ - # (f"world_size must be {w_split} * {h_split} = {w_split * h_split}, but got {world_size}") self.use_dp = True self.h_split, self.w_split = h_split, w_split @@ -1230,6 +1179,9 @@ def _encode(self, x: torch.Tensor): if self.config.patch_size is not None: x = patchify(x, patch_size=self.config.patch_size) + if self.use_dp: + return self.tiled_encode_with_dp(x) + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): return self.tiled_encode(x) @@ -1282,9 +1234,6 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio - if self.use_dp: - return self.tiled_decode_with_dp(z, return_dict=return_dict) - if self.use_dp: return self.tiled_decode_with_dp(z, return_dict=return_dict) @@ -1355,6 +1304,20 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. ) return b + def blend_v_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + y = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (y / blend_extent)[None, None, None, :, None].to(a.dtype) + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - blend_ratio) + b[:, :, :, y, :] * blend_ratio + return b + + def blend_h_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + x = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (x / blend_extent)[None, None, None, None, :].to(a.dtype) + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - blend_ratio) + b[:, :, :, :, x] * blend_ratio + return b + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. @@ -1499,28 +1462,10 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return (dec,) return DecoderOutput(sample=dec) - def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - r""" - Decode a batch of images using a tiled decoder. - - Args: - z (`torch.Tensor`): Input batch of latent vectors. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - - Returns: - [`~models.vae.DecoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is - returned. - """ - _, _, num_frames, height, width = z.shape - device = z.device - sample_height = height * self.spatial_compression_ratio - sample_width = width * self.spatial_compression_ratio - + def calculate_tiled_parallel_size(self, latent_height, latent_width): # Calculate stride based on h_split and w_split - tile_latent_stride_height = int((height + self.h_split - 1) / self.h_split) - tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split) + tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) # Calculate overlap in latent space overlap_latent_height = 3 @@ -1530,52 +1475,135 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni overlap_latent_height = overlap_latent overlap_latent_width = overlap_latent elif self.overlap_ratio is not None: - overlap_latent_height = int(self.overlap_ratio * height) - overlap_latent_width = int(self.overlap_ratio * width) + overlap_latent_height = int(self.overlap_ratio * latent_height) + overlap_latent_width = int(self.overlap_ratio * latent_width) # Calculate minimum tile size in latent space tile_latent_min_height = tile_latent_stride_height + overlap_latent_height tile_latent_min_width = tile_latent_stride_width + overlap_latent_width - # Convert min/stride to sample space tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio - self.tile_sample_min_height = tile_sample_min_height - self.tile_sample_min_width = tile_sample_min_width - self.tile_sample_stride_height = tile_sample_stride_height - self.tile_sample_stride_width = tile_sample_stride_width + blend_latent_height = tile_latent_min_height - tile_latent_stride_height + blend_latent_width = tile_latent_min_width - tile_latent_stride_width if self.config.patch_size is not None: sample_height = sample_height // self.config.patch_size sample_width = sample_width // self.config.patch_size tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size - blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height - blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + blend_sample_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_sample_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width else: - blend_height = self.tile_sample_min_height - tile_sample_stride_height - blend_width = self.tile_sample_min_width - tile_sample_stride_width + blend_sample_height = tile_sample_min_height - tile_sample_stride_height + blend_sample_width = tile_sample_min_width - tile_sample_stride_width + + return \ + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width + + def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, sample_height, sample_width = x.shape + device = x.device + latent_height = sample_height // self.spatial_compression_ratio + latent_width = sample_width // self.spatial_compression_ratio + + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + + def vae_encode_op( + x, patch_height_start, patch_height_end, patch_width_start, patch_width_end, num_frames + ) -> torch.Tensor: - # Split z into overlapping tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. - # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] - num_tile_rows = self.h_split - num_tile_cols = self.w_split - - # Each rank computes only tiles assigned to it based on tile_idxs_per_rank - # local_tiles = [] # List to store tiles computed by this rank - local_tiles = [] - local_hw_shapes = [] - - for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: self.clear_cache() - patch_height_start = h_idx * tile_latent_stride_height - patch_height_end = patch_height_start + tile_latent_min_height - patch_width_start = w_idx * tile_latent_stride_width - patch_width_end = patch_width_start + tile_latent_min_width + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + patch_height_start : patch_height_end, + patch_width_start : patch_width_end, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + time = torch.cat(time, dim=2) + self.clear_cache() + return time + + rows = self.run_vae_tile_parallel( + x, vae_encode_op, + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, device, + num_frames=num_frames + ) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v_(rows[i - 1][j], tile, blend_latent_height) + if j > 0: + tile = self.blend_h_(row[j - 1], tile, blend_latent_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, latent_height, latent_width = z.shape + device = z.device + sample_height = latent_height * self.spatial_compression_ratio + sample_width = latent_width * self.spatial_compression_ratio + + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + + def vae_decode_op( + z, patch_height_start, patch_height_end, patch_width_start, patch_width_end, num_frames + ) -> torch.Tensor: + + self.clear_cache() + time = [] for k in range(num_frames): self._conv_idx = [0] @@ -1586,40 +1614,14 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni ) time.append(decoded) time = torch.cat(time, dim=2) - local_tiles.append(time.flatten(3, 4)) # flatten h,w dim for concate all tiles in one rank - local_hw_shapes.append(torch.Tensor([*time.shape[3:5]]).to(device).int()) # record hw for futher unflatten self.clear_cache() + return time - # concat all tiles on local rank - local_tiles = torch.cat(local_tiles, dim=3) - local_hw_shapes = torch.stack(local_hw_shapes) - - # get all hw shapes for each rank (perhaps has different shapes for last tile) - gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) - for num_tiles in self.num_tiles_per_rank] - dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) - - # gather tiles on all ranks - b, c, n = local_tiles.shape[:3] - gathered_tiles = [ - torch.empty( - (b, c, n, tiles_shape.prod(dim=1).sum().item()), - dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list - ] - dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) - - # put tiles in rows based on tile_idxs_per_rank - rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] - for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): - rank_tile_hw_shapes = gathered_shape_list[rank_idx] - hw_start_idx = 0 - # perhaps has more than one tile in each rank, get each by hw_shapes - for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): - rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] - hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw - rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten( - 3, rank_tile_hw_shape.tolist()) # unflatten hw dim - hw_start_idx = hw_end_idx + rows = self.run_vae_tile_parallel( + z, vae_decode_op, + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, device, + num_frames=num_frames + ) # combine all tiles, same as tiled decode result_rows = [] @@ -1629,9 +1631,9 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_height) + tile = self.blend_v_(rows[i - 1][j], tile, blend_sample_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_width) + tile = self.blend_h_(row[j - 1], tile, blend_sample_width) result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 9c6031a988f9..d798711ec240 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, List import numpy as np import torch import torch.nn as nn +import torch.distributed as dist from ...utils import BaseOutput from ...utils.torch_utils import randn_tensor @@ -926,3 +927,78 @@ def disable_slicing(self): decoding in one step. """ self.use_slicing = False + + def enable_dp(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + if not hasattr(self, "use_tiling"): + raise NotImplementedError(f"Tiling Parallel doesn't seem to be implemented for {self.__class__.__name__}.") + self.use_dp = True + + def disable_dp(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_dp = False + + def run_vae_tile_parallel( + self, + input: torch.Tensor, + vae_op, + min_height, + min_width, + stride_height, + stride_width, + device, + **kwargs) -> List[List[torch.Tensor]]: + + local_tiles = [] + local_hw_shapes = [] + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + patch_height_start = h_idx * stride_height + patch_height_end = patch_height_start + min_height + patch_width_start = w_idx * stride_width + patch_width_end = patch_width_start + min_width + tile = vae_op(input, patch_height_start, patch_height_end, patch_width_start, patch_width_end, **kwargs) + local_tiles.append(tile.flatten(-2, -1)) + local_hw_shapes.append(torch.Tensor([*tile.shape[-2:]]).to(device).int()) + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=-1) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + tile_shape_first = local_tiles.shape[:-1] + gathered_tiles = [ + torch.empty( + (*tile_shape_first, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * self.w_split for _ in range(self.h_split)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( + -1, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + return rows \ No newline at end of file diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 16662e8d8fe8..6762c665e79e 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -39,6 +39,10 @@ from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from mindiesd import attention_forward as mindie_sd_attn_forward + +STREAM_VECTOR = torch.npu.Stream() + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -131,7 +135,7 @@ def __call__( if hasattr(self._parallel_config, "context_parallel_config") and \ self._parallel_config.context_parallel_config is not None: - return self._context_parallel_forward( + return self._context_parallel_forward_qkv( attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, pre_query, pre_key, cal_q ) @@ -287,6 +291,357 @@ def _context_parallel_forward( else: return hidden_states + def _context_parallel_forward_cv( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + pre_query: Optional[torch.Tensor] = None, + pre_key: Optional[torch.Tensor] = None, + cal_q=True + ) -> torch.Tensor: + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + ev_q = torch.npu.Event() + ev_k = torch.npu.Event() + ev_v = torch.npu.Event() + + query = attn.to_q(hidden_states) + query = query.unflatten(-1, (attn.heads, -1)) + ev_q.record() + key = attn.to_k(hidden_states) + key = key.unflatten(-1, (attn.heads, -1)) + ev_k.record() + + value = attn.to_v(hidden_states) + value = value.unflatten(-1, (attn.heads, -1)) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + value = torch.cat([encoder_value, value], dim=1) + B, S_KV_LOCAL, H, D = value.shape + H_LOCAL = H // world_size + value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous().flatten() + v_size = value.numel() + ev_v.record() + + with torch.npu.stream(STREAM_VECTOR): + ev_q.wait() + query = attn.norm_q(query) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_query = attn.norm_added_q(encoder_query) + query = torch.cat([encoder_query, query], dim=1) + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + + ev_k.wait() + key = attn.norm_k(key) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_key = attn.norm_added_k(encoder_key) + key = torch.cat([encoder_key, key], dim=1) + if image_rotary_emb is not None: + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + B, S_Q_LOCAL, H, D = query.shape + _, S_KV_LOCAL, _, _ = key.shape + H_LOCAL = H // world_size + + query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous().flatten() + key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous().flatten() + q_size = query.numel() + k_size = query.numel() + + ev_v.wait() + qkv_all = funcol.all_to_all_single(torch.cat([query, key, value], dim=0), None, None, group) + + qkv_all = _wait_tensor(qkv_all) + query_all = qkv_all[:q_size].reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + key_all = qkv_all[q_size: q_size + k_size].reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + value_all = qkv_all[-v_size:].reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + # query_all = _wait_tensor(query_all) + # query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + # key_all = _wait_tensor(key_all) + # key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + # value_all = _wait_tensor(value_all) + # value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous() + + out = npu_fusion_attention( + query_all, + key_all, + value_all, + H_LOCAL, # num_heads + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(D), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + out = out.transpose(1, 2).contiguous() + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + def _context_parallel_forward_vqk( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + pre_query: Optional[torch.Tensor] = None, + pre_key: Optional[torch.Tensor] = None, + cal_q=True + ) -> torch.Tensor: + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + ev_q = torch.npu.Event() + ev_k = torch.npu.Event() + + value = attn.to_v(hidden_states) + value = value.unflatten(-1, (attn.heads, -1)) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + value = torch.cat([encoder_value, value], dim=1) + + B, S_KV_LOCAL, H, D = value.shape + H_LOCAL = H // world_size + value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + query = attn.to_q(hidden_states) + query = query.unflatten(-1, (attn.heads, -1)) + ev_q.record() + key = attn.to_k(hidden_states) + key = key.unflatten(-1, (attn.heads, -1)) + ev_k.record() + + with torch.npu.stream(STREAM_VECTOR): + ev_q.wait() + query = attn.norm_q(query) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_query = attn.norm_added_q(encoder_query) + query = torch.cat([encoder_query, query], dim=1) + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + + _, S_Q_LOCAL, _, _ = query.shape + query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + + ev_k.wait() + key = attn.norm_k(key) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_key = attn.norm_added_k(encoder_key) + key = torch.cat([encoder_key, key], dim=1) + if image_rotary_emb is not None: + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + value_all = _wait_tensor(value_all) + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + query_all = _wait_tensor(query_all) + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + key_all = _wait_tensor(key_all) + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + # out = npu_fusion_attention( + # query_all, + # key_all, + # value_all, + # H_LOCAL, # num_heads + # input_layout="BNSD", + # pse=None, + # scale=1.0 / math.sqrt(D), + # pre_tockens=65536, + # next_tockens=65536, + # keep_prob=1.0, + # sync=False, + # inner_precise=0, + # )[0] + + out = mindie_sd_attn_forward( + query_all, + key_all, + value_all, + opt_mode="manual", + op_type="ascend_laser_attention", + layout="BNSD" + ) + + # out = out.transpose(1, 2).contiguous() + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + def _context_parallel_forward_qkv( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + pre_query: Optional[torch.Tensor] = None, + pre_key: Optional[torch.Tensor] = None, + cal_q=True + ) -> torch.Tensor: + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + ev_q = torch.npu.Event() + ev_k = torch.npu.Event() + + query = attn.to_q(hidden_states) + query = query.unflatten(-1, (attn.heads, -1)) + ev_q.record() + key = attn.to_k(hidden_states) + key = key.unflatten(-1, (attn.heads, -1)) + ev_k.record() + + value = attn.to_v(hidden_states) + value = value.unflatten(-1, (attn.heads, -1)) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + value = torch.cat([encoder_value, value], dim=1) + + with torch.npu.stream(STREAM_VECTOR): + ev_q.wait() + query = attn.norm_q(query) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_query = attn.norm_added_q(encoder_query) + query = torch.cat([encoder_query, query], dim=1) + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + + B, S_Q_LOCAL, H, D = query.shape + H_LOCAL = H // world_size + query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + + ev_k.wait() + key = attn.norm_k(key) + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_key = attn.norm_added_k(encoder_key) + key = torch.cat([encoder_key, key], dim=1) + if image_rotary_emb is not None: + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + _, S_KV_LOCAL, _, _ = key.shape + key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + query_all = _wait_tensor(query_all) + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + key_all = _wait_tensor(key_all) + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + value_all = _wait_tensor(value_all) + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + # out = npu_fusion_attention( + # query_all, + # key_all, + # value_all, + # H_LOCAL, # num_heads + # input_layout="BNSD", + # pse=None, + # scale=1.0 / math.sqrt(D), + # pre_tockens=65536, + # next_tockens=65536, + # keep_prob=1.0, + # sync=False, + # inner_precise=0, + # )[0] + + out = mindie_sd_attn_forward( + query_all, + key_all, + value_all, + opt_mode="manual", + op_type="ascend_laser_attention", + layout="BNSD" + ) + # out = out.transpose(1, 2).contiguous() + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + class FluxIPAdapterAttnProcessor(torch.nn.Module): """Flux Attention processor for IP-Adapter.""" diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 2e566f1daaaa..4d0e7fb4b219 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from mindiesd import rotary_position_embedding from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -36,6 +37,10 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if torch.distributed.is_available(): + import torch.distributed._functional_collectives as funcol + + def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): # encoder_hidden_states is only passed for cross-attention if encoder_hidden_states is None: @@ -65,6 +70,337 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t return key_img, value_img +def _wait_tensor(tensor): + if isinstance(tensor, funcol.AsyncCollectiveTensor): + tensor = tensor.wait() + return tensor + + +def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: + shape = x.shape + x = x.flatten() + x = funcol.all_to_all_single(x, None, None, group) + x = x.reshape(shape) + x = _wait_tensor(x) + return x + + +def ulysses_preforward( + x: torch.Tensor, + group, + world_size, + B, + S_LOCAL, + H, + D, + H_LOCAL +): + x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + x = x.flatten() + x = funcol.all_to_all_single(x, None, None, group) + return x + + +class WanAttnProcessorSP: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh.get_group() + + # query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + value = value.unflatten(2, (attn.heads, -1)) + B, S_KV_LOCAL, H, D = value.shape + H_LOCAL = H // world_size + value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + out = rotary_position_embedding(hidden_states, freqs_cos, freqs_sin, rotated_mode="rotated_interleaved", fused=True) + return out + + query = attn.to_q(hidden_states) + query = attn.norm_q(query) + query = query.unflatten(2, (attn.heads, -1)) + if rotary_emb is not None: + query = apply_rotary_emb(query, *rotary_emb) + _, S_Q_LOCAL, _, _ = query.shape + query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + + key = attn.to_k(encoder_hidden_states) + key = attn.norm_k(key) + key = key.unflatten(2, (attn.heads, -1)) + if rotary_emb is not None: + key = apply_rotary_emb(key, *rotary_emb) + key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + + value_all = _wait_tensor(value_all) + value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + query_all = _wait_tensor(query_all) + query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + key_all = _wait_tensor(key_all) + key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + hidden_states = dispatch_attention_fn( + query_all, + key_all, + value_all, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=None, + ) + hidden_states = hidden_states.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + hidden_states = _all_to_all_single(hidden_states, group) + hidden_states = hidden_states.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class WanAttnProcessorCross: + _attention_backend_fix = "_native_npu" + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + + # def apply_rotary_emb( + # hidden_states: torch.Tensor, + # freqs_cos: torch.Tensor, + # freqs_sin: torch.Tensor, + # ): + # x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + # cos = freqs_cos[..., 0::2] + # sin = freqs_sin[..., 1::2] + # out = torch.empty_like(hidden_states) + # out[..., 0::2] = x1 * cos - x2 * sin + # out[..., 1::2] = x1 * sin + x2 * cos + # return out.type_as(hidden_states) + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + out = rotary_position_embedding(hidden_states, freqs_cos, freqs_sin, rotated_mode="rotated_interleaved", fused=True) + return out + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend_fix, + parallel_config=None, + ) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend_fix, + parallel_config=None, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class WanAttnProcessorNPU: + _attention_backend_fix = "_native_npu" + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + + # def apply_rotary_emb( + # hidden_states: torch.Tensor, + # freqs_cos: torch.Tensor, + # freqs_sin: torch.Tensor, + # ): + # x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + # cos = freqs_cos[..., 0::2] + # sin = freqs_sin[..., 1::2] + # out = torch.empty_like(hidden_states) + # out[..., 0::2] = x1 * cos - x2 * sin + # out[..., 1::2] = x1 * sin + x2 * cos + # return out.type_as(hidden_states) + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + out = rotary_position_embedding(hidden_states, freqs_cos, freqs_sin, rotated_mode="rotated_interleaved", fused=True) + return out + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend_fix, + parallel_config=self._parallel_config, + ) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend_fix, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + class WanAttnProcessor: _attention_backend = None _parallel_config = None @@ -101,18 +437,26 @@ def __call__( if rotary_emb is not None: + # def apply_rotary_emb( + # hidden_states: torch.Tensor, + # freqs_cos: torch.Tensor, + # freqs_sin: torch.Tensor, + # ): + # x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + # cos = freqs_cos[..., 0::2] + # sin = freqs_sin[..., 1::2] + # out = torch.empty_like(hidden_states) + # out[..., 0::2] = x1 * cos - x2 * sin + # out[..., 1::2] = x1 * sin + x2 * cos + # return out.type_as(hidden_states) + def apply_rotary_emb( hidden_states: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, ): - x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) - cos = freqs_cos[..., 0::2] - sin = freqs_sin[..., 1::2] - out = torch.empty_like(hidden_states) - out[..., 0::2] = x1 * cos - x2 * sin - out[..., 1::2] = x1 * sin + x2 * cos - return out.type_as(hidden_states) + out = rotary_position_embedding(hidden_states, freqs_cos, freqs_sin, rotated_mode="rotated_interleaved", fused=True) + return out query = apply_rotary_emb(query, *rotary_emb) key = apply_rotary_emb(key, *rotary_emb) @@ -407,6 +751,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + # if (pad_size := 8 - freqs_cos.shape[1] % 8) != 8: + # b, s, n, d = freqs_cos.shape + # freqs_cos = torch.cat([freqs_cos, freqs_cos.new_ones(b, pad_size, n, d)], dim=1) + # freqs_sin = torch.cat([freqs_sin, freqs_sin.new_ones(b, pad_size, n, d)], dim=1) + return freqs_cos, freqs_sin @@ -432,7 +781,7 @@ def __init__( dim_head=dim // num_heads, eps=eps, cross_attention_dim_head=None, - processor=WanAttnProcessor(), + processor=WanAttnProcessorSP(), ) # 2. Cross-attention @@ -443,7 +792,7 @@ def __init__( eps=eps, added_kv_proj_dim=added_kv_proj_dim, cross_attention_dim_head=dim // num_heads, - processor=WanAttnProcessor(), + processor=WanAttnProcessorCross(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() @@ -551,9 +900,9 @@ class WanTransformer3DModel( "blocks.0": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, - "blocks.*": { - "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - }, + # "blocks.*": { + # "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + # }, "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), } @@ -651,6 +1000,10 @@ def forward( hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) + # if (pad_size := (8 - hidden_states.shape[1] % 8) % 8) != 0: + # b, s, d = hidden_states.shape + # hidden_states = torch.cat([hidden_states, hidden_states.new_zeros(b, pad_size, d)], dim=1) + # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) if timestep.ndim == 2: ts_seq_len = timestep.shape[1] @@ -700,6 +1053,7 @@ def forward( hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) hidden_states = self.proj_out(hidden_states) + # hidden_states = hidden_states[:, :-pad_size, :] hidden_states = hidden_states.reshape( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 732e98c048f0..375a04ea1bd4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -26,6 +26,10 @@ T5TokenizerFast, ) +import time + +import time + from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel @@ -680,6 +684,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + prof=None ): r""" Function invoked when calling the pipeline for generation. @@ -779,6 +784,10 @@ def __call__( is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ + torch.cuda.synchronize() + if prof: + prof.start() + t_start = time.time() height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor @@ -925,6 +934,10 @@ def __call__( batch_size * num_images_per_prompt, ) + torch.cuda.synchronize() + if prof: + prof.step() + t_preprocess = time.time() # 6. Denoising loop # We set the index here to remove DtoH sync, helpful especially during compilation. # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 @@ -996,6 +1009,12 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + if prof: + prof.step() + + torch.cuda.synchronize() + t_dit = time.time() + self._current_timestep = None if output_type == "latent": @@ -1006,10 +1025,30 @@ def __call__( image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) + torch.cuda.synchronize() + if prof: + prof.step() + + t_vae = time.time() + # Offload all models self.maybe_free_model_hooks() self.image_rotary_emb = None + is_print = torch.distributed.get_rank() == 0 if torch.distributed.is_initialized() else True + if is_print: + headers = ["Total", "Prepare", "DIT", "VAE", "DIT_PER_STEP"] + time_list = [t_vae - t_start, t_preprocess - t_start, t_dit - t_preprocess, t_vae - t_dit, (t_dit - t_start) / num_inference_steps] + time_list = [f"{t:.3f}" for t in time_list] + widths = [10, 10, 10, 10, 10] + def _fmt_row(values): + return " | ".join(str(values[i]).ljust(widths[i]) for i in range(len(headers))) + + + sep = "-+-".join("-" * w for w in widths) + print(_fmt_row(headers)) + print(_fmt_row(time_list)) + if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index e7226d336ac8..53c3a12d4660 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -14,6 +14,7 @@ import html from typing import Any, Callable, Dict, List, Optional, Union +import time import regex as re import torch @@ -402,6 +403,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + prof=None, ): r""" The call function to the pipeline for generation. @@ -473,6 +475,11 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + torch.cuda.synchronize() + if prof: + prof.start() + t_start = time.time() + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -564,6 +571,11 @@ def __call__( else: boundary_timestep = None + torch.cuda.synchronize() + if prof: + prof.step() + t_preprocess = time.time() + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -629,6 +641,12 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + if prof: + prof.step() + + torch.cuda.synchronize() + t_dit = time.time() + self._current_timestep = None if not output_type == "latent": @@ -647,9 +665,26 @@ def __call__( else: video = latents + torch.cuda.synchronize() + if prof: + prof.step() + t_vae = time.time() + # Offload all models self.maybe_free_model_hooks() + if torch.distributed.get_rank() == 0: + headers = ["Total", "Prepare", "DIT", "VAE", "DIT_PER_STEP"] + time_list = [t_vae - t_start, t_preprocess - t_start, t_dit - t_preprocess, t_vae - t_dit, (t_dit - t_start) / num_inference_steps] + time_list = [f"{t:.3f}" for t in time_list] + widths = [10, 10, 10, 10, 10] + def _fmt_row(values): + return " | ".join(str(values[i]).ljust(widths[i]) for i in range(len(headers))) + + + sep = "-+-".join("-" * w for w in widths) + print(_fmt_row(headers)) + print(_fmt_row(time_list)) self.transformer.rotary_emb = None if self.transformer_2 is not None: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index b7fd0b05980f..0342836279ed 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -14,6 +14,7 @@ import html from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import time import PIL import regex as re @@ -356,6 +357,15 @@ def check_inputs( if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + if getattr(getattr(self.transformer, "_parallel_config", None), "context_parallel_config", None) is not None: + mesh_size = self.transformer._parallel_config.context_parallel_config._flattened_mesh.size() + mod_size = 16 * 16 * mesh_size + if (height * width ) % mod_size != 0: + raise ValueError( + f"The product of `height` and `width` have to be divisible by 16 * 16 * {mesh_size} " \ + f"when enable context parallel, but are {height} and {width}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -532,6 +542,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + prof=None, ): r""" The call function to the pipeline for generation. @@ -613,6 +624,11 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + torch.cuda.synchronize() + t_start = time.time() + if prof: + prof.start() + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -724,6 +740,11 @@ def __call__( else: boundary_timestep = None + torch.cuda.synchronize() + t_preprocess = time.time() + if prof: + prof.step() + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -794,6 +815,12 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + if prof: + prof.step() + + torch.cuda.synchronize() + t_dit = time.time() + self._current_timestep = None if self.config.expand_timesteps: @@ -815,9 +842,33 @@ def __call__( else: video = latents + torch.cuda.synchronize() + t_vae = time.time() + if prof: + prof.step() + prof.stop() + # Offload all models self.maybe_free_model_hooks() + self.transformer.rotary_emb = None + if self.transformer_2 is not None: + self.transformer_2.rotary_emb = None + + is_print = torch.distributed.get_rank() == 0 if torch.distributed.is_initialized() else True + if is_print: + headers = ["Total", "Prepare", "DIT", "VAE", "DIT_PER_STEP"] + time_list = [t_vae - t_start, t_preprocess - t_start, t_dit - t_preprocess, t_vae - t_dit, (t_dit - t_start) / num_inference_steps] + time_list = [f"{t:.3f}" for t in time_list] + widths = [10, 10, 10, 10, 10] + def _fmt_row(values): + return " | ".join(str(values[i]).ljust(widths[i]) for i in range(len(headers))) + + + sep = "-+-".join("-" * w for w in widths) + print(_fmt_row(headers)) + print(_fmt_row(time_list)) + if not return_dict: return (video,) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cf77aaee8205..49758d1b2454 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -122,6 +122,7 @@ is_wandb_available, is_xformers_available, is_xformers_version, + is_mindie_sd_available, requires_backends, ) from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index adf8ed8b0694..985bb896fa20 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -229,6 +229,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _aiter_available, _aiter_version = _is_package_available("aiter") _kornia_available, _kornia_version = _is_package_available("kornia") _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) +_mindie_sd_available, _mindie_sd_version = _is_package_available("mindiesd") def is_torch_available(): @@ -414,6 +415,9 @@ def is_aiter_available(): def is_kornia_available(): return _kornia_available +def is_mindie_sd_available(): + return _mindie_sd_available + # docstyle-ignore FLAX_IMPORT_ERROR = """