diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index aa9a36409171..037dbc0a4eb1 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -52,20 +52,19 @@ from ..attention_dispatch import npu_fusion_attention def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True): - if cal_q: - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) encoder_query = encoder_key = encoder_value = None - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None and cal_q: encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) if cal_q: return query, key, value, encoder_query, encoder_key, encoder_value else: - return value, encoder_query, encoder_key, encoder_value + return query, key, value def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) @@ -117,6 +116,7 @@ class FluxAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + self.double_stream = bool(int(os.environ.get("DOUBLE_STREAM", 1))) def __call__( self, @@ -261,14 +261,15 @@ def _context_parallel_forward( torch_npu._npu_flash_attention_unpad(query_all, key_all, value_all, seq_len, 1/math.sqrt(D), N, N, out) out = out.view(B, S, N, D).contiguous() + out = out.to(query.dtype) out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() - out = _all_to_all_single(out, group) - hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() - - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: + out = _all_to_all_single(out, group) + hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 ) @@ -278,6 +279,9 @@ def _context_parallel_forward( return hidden_states, encoder_hidden_states else: + out = out.flatten() + out = funcol.all_to_all_single(out, None, None, group) + hidden_states = out.reshape(world_size, H_LOCAL, B, S_Q_LOCAL, D).flatten(0, 1).permute(1, 2, 0, 3) return hidden_states @@ -529,13 +533,19 @@ def forward( residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) - mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + # mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) joint_attention_kwargs = joint_attention_kwargs or {} attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs, ) + mlp_hidden_states = self.proj_mlp(norm_hidden_states) + attn_output = _wait_tensor(attn_output) + attn_output = attn_output.contiguous() + if attn_output.ndim == 4: + attn_output = attn_output.flatten(2, 3) + mlp_hidden_states = self.act_mlp(mlp_hidden_states) hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) gate = gate.unsqueeze(1) @@ -576,6 +586,8 @@ def __init__( self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + self.double_stream = bool(int(os.environ.get("DOUBLE_STREAM", 1))) + def forward( self, hidden_states: torch.Tensor, @@ -584,21 +596,44 @@ def forward( image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + joint_attention_kwargs = joint_attention_kwargs or {} norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) - joint_attention_kwargs = joint_attention_kwargs or {} - - # Attention. - attention_outputs = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - **joint_attention_kwargs, - ) - + + if self.double_stream: + emb = self.norm1.linear(self.norm1.silu(temb)) + current_event.record(current_stream) + + with torch.npu.stream(stream2): + stream2.wait_event(current_event) + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=emb, skip_matmul=True) + event2.record(stream2) + + pre_encoder_query = self.attn.add_q_proj(norm_encoder_hidden_states) + pre_encoder_key = self.attn.add_k_proj(norm_encoder_hidden_states) + pre_encoder_value = self.attn.add_v_proj(norm_encoder_hidden_states) + current_stream.wait_event(event2) + + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + pre_encoder_query=pre_encoder_query, + pre_encoder_key=pre_encoder_key, + pre_encoder_value=pre_encoder_value, + cal_q=False, + **joint_attention_kwargs, + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) if len(attention_outputs) == 2: attn_output, context_attn_output = attention_outputs elif len(attention_outputs) == 3: @@ -611,26 +646,56 @@ def forward( norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output + if self.double_stream: + current_event.record(current_stream) + with torch.npu.stream(stream2): + stream2.wait_event(current_event) + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output - hidden_states = hidden_states + ff_output - if len(attention_outputs) == 3: - hidden_states = hidden_states + ip_attn_output + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - # Process attention outputs for the `encoder_hidden_states`. - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output + event2.record(stream2) - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + current_stream.wait_event(event2) - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - if encoder_hidden_states.dtype == torch.float16: - encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + ff_output = gate_mlp.unsqueeze(1) * ff_output + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output - return encoder_hidden_states, hidden_states + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + else: + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states class FluxPosEmbed(nn.Module):