Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 105 additions & 14 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_torch_xla_version,
is_xformers_available,
is_xformers_version,
is_mindie_sd_available,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS

Expand All @@ -63,6 +64,7 @@
_CAN_USE_NPU_ATTN = is_torch_npu_available()
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
_CAN_USE_MINDIESD_ATTN = is_mindie_sd_available()


if _CAN_USE_FLASH_ATTN:
Expand Down Expand Up @@ -142,6 +144,13 @@
else:
xops = None


if _CAN_USE_MINDIESD_ATTN:
from mindiesd import attention_forward as mindie_sd_attn_forward
else:
mindie_sd_attn_forward = None


# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
if torch.__version__ >= "2.4.0":
_custom_op = torch.library.custom_op
Expand Down Expand Up @@ -215,6 +224,9 @@ class AttentionBackendName(str, Enum):
# `xformers`
XFORMERS = "xformers"

# mindie_sd
_MINDIE_SD_LASER = "_mindie_sd_la"


class _AttentionBackendRegistry:
_backends = {}
Expand Down Expand Up @@ -254,7 +266,7 @@ def list_backends(cls):
def _is_context_parallel_enabled(
cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
) -> bool:
supports_context_parallel = backend in cls._supports_context_parallel
supports_context_parallel = backend in cls._supports_context_parallel and cls._supports_context_parallel[backend]
is_degree_greater_than_1 = parallel_config is not None and (
parallel_config.context_parallel_config.ring_degree > 1
or parallel_config.context_parallel_config.ulysses_degree > 1
Expand Down Expand Up @@ -470,6 +482,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
)

elif backend == AttentionBackendName._MINDIE_SD_LASER:
if not _CAN_USE_MINDIESD_ATTN:
raise RuntimeError(
f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
)


@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
Expand Down Expand Up @@ -907,19 +925,14 @@ def _npu_attention_forward_op(
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
# if enable_gqa:
# raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")

# tensors_to_save = ()

# Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
# if the input tensors are not contiguous.
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
# tensors_to_save += (query, key, value)

out = npu_fusion_attention(
query,
Expand All @@ -936,14 +949,6 @@ def _npu_attention_forward_op(
inner_precise=0,
)[0]

# tensors_to_save += (out)
# if _save_ctx:
# ctx.save_for_backward(*tensors_to_save)
# ctx.dropout_p = dropout_p
# ctx.is_causal = is_causal
# ctx.scale = scale
# ctx.attn_mask = attn_mask

out = out.transpose(1, 2).contiguous()
return out

Expand All @@ -959,6 +964,45 @@ def _npu_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.")


def _mindie_sd_laser_attn_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for MindIE SD Laser Attention.")
if return_lse:
raise ValueError("MindIE SD attention backend does not support setting `return_lse=True`.")

out = mindie_sd_attn_forward(
query,
key,
value,
opt_mode="manual",
op_type="ascend_laser_attention",
layout="BNSD"
)

return out

def _mindie_sd_laser_attn_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
raise NotImplementedError("Backward pass is not implemented for MindIE SD Laser Attention.")


# ===== Context parallel =====


Expand Down Expand Up @@ -1126,6 +1170,7 @@ def forward(
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
):
# print(f"[YYT DEBUG] >>>>> ulysses")
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
world_size = _parallel_config.context_parallel_config.ulysses_degree
group = ulysses_mesh.get_group()
Expand Down Expand Up @@ -1776,6 +1821,7 @@ def _native_math_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_NPU,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
)
def _native_npu_attention(
query: torch.Tensor,
Expand Down Expand Up @@ -2095,3 +2141,48 @@ def _xformers_attention(
out = out.flatten(2, 3)

return out


@_AttentionBackendRegistry.register(
AttentionBackendName._MINDIE_SD_LASER,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
)
def _mindie_sd_laser_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if return_lse:
raise ValueError("MINDIE SD attention backend does not support setting `return_lse=True`.")
if _parallel_config is None:
# query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = mindie_sd_attn_forward(
query,
key,
value,
opt_mode="manual",
op_type="ascend_laser_attention",
layout="BNSD"
)
# out = out.transpose(1, 2).contiguous()
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
None,
scale,
None,
return_lse,
forward_op=_mindie_sd_laser_attn_forward_op,
backward_op=_mindie_sd_laser_attn_backward_op,
_parallel_config=_parallel_config,
)
return out
Loading