I'm trying to reproduce the 4K image generation setup from the GNA (Generalized Neighborhood Attention) paper using NATTEN. The paper shows GNA accelerating Flux for high-resolution image generation.
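For a sense of scale, the sequence lengths that Flux attention sees at 4K can be estimated as follows. This sketch assumes FLUX.1-dev defaults: 8x VAE downsampling, 2x2 latent packing (one image token per 16x16 pixels), and 512 text tokens, which matches the 512 offset used in the slicing of the code in this issue. The helper `flux_token_counts` is hypothetical. Under these assumptions, a 4096x4096 image is a 256x256 grid of image tokens, flattened row by row into 65,536 tokens that follow the 512 text tokens (66,048 in total).

```python
def flux_token_counts(height: int, width: int, text_tokens: int = 512) -> dict:
    """Estimate the attention sequence lengths for a Flux image of the given size.

    Assumptions (FLUX.1-dev defaults, not read from any config):
      * the VAE downsamples by 8 and latents are packed into 2x2 patches,
        so each image token covers a 16x16 pixel area;
      * the T5 prompt is padded to `text_tokens` tokens, placed before the image tokens.
    """
    if height % 16 or width % 16:
        raise ValueError("height and width must be multiples of 16")
    grid_h, grid_w = height // 16, width // 16  # image tokens per column / per row
    image_tokens = grid_h * grid_w              # flattened in raster (row-major) order
    return {
        "text_tokens": text_tokens,
        "image_grid": (grid_h, grid_w),
        "image_tokens": image_tokens,
        "total_tokens": text_tokens + image_tokens,
    }


if __name__ == "__main__":
    print(flux_token_counts(4096, 4096))
```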
I'm modifying the attention processor of URAE (which adapts Flux to generate 4K images) to compute attention with na1d. In Flux, every image token must also attend to the text tokens. This pattern could be described as "self-cross attention", as discussed in issue #82.
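The self-cross pattern for the image queries can be made concrete with a small pure-PyTorch reference. Each image token attends to a 1D window of image tokens plus every text token, and all of these keys share a single softmax. This is a sketch for sanity checks on small inputs only, because it builds a dense mask. It assumes stride 1, dilation 1, and an odd kernel size, so it does not reproduce the GNA `stride` used in the code in this issue. The helpers `neighborhood_mask_1d` and `self_cross_reference` are hypothetical and are not part of NATTEN or URAE.

```python
import torch
import torch.nn.functional as F


def neighborhood_mask_1d(length: int, kernel_size: int, device=None) -> torch.Tensor:
    """Boolean [length, length] mask: (i, j) is True if key j is in query i's window.

    Standard 1D neighborhood attention (odd kernel_size, stride 1, dilation 1):
    every query sees exactly kernel_size keys, centered on itself and shifted
    inward near the sequence borders.
    """
    if kernel_size % 2 == 0 or kernel_size > length:
        raise ValueError("kernel_size must be odd and <= length")
    idx = torch.arange(length, device=device)
    start = (idx - kernel_size // 2).clamp(0, length - kernel_size)
    return (idx[None, :] >= start[:, None]) & (idx[None, :] < start[:, None] + kernel_size)


def self_cross_reference(q_img, k_img, v_img, k_txt, v_txt, kernel_size):
    """Image queries attend to their image neighborhood plus ALL text tokens.

    Inputs use NATTEN's [batch, seq_len, heads, head_dim] layout; so does the output.
    """
    n_img, n_txt = k_img.shape[1], k_txt.shape[1]
    # Concatenate keys/values along the sequence axis: [image tokens, text tokens].
    k = torch.cat([k_img, k_txt], dim=1)
    v = torch.cat([v_img, v_txt], dim=1)
    # Image keys are visible only inside the window; text keys are always visible.
    mask = torch.cat(
        [
            neighborhood_mask_1d(n_img, kernel_size, device=q_img.device),
            torch.ones(n_img, n_txt, dtype=torch.bool, device=q_img.device),
        ],
        dim=1,
    )
    # SDPA expects [batch, heads, seq_len, head_dim]; True in attn_mask means "attend".
    out = F.scaled_dot_product_attention(
        q_img.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=mask
    )
    return out.transpose(1, 2)
```

Comparing this reference against `na1d(..., additional_keys=..., additional_values=...)` on small random tensors with `stride=1` is one way to check that the NATTEN call expresses the intended pattern.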
My implementation uses na1d's additional_keys and additional_values parameters to provide this image-text interaction. It is based on the URAE attention processor (URAE/attention_processor.py#L82), modified as follows:
```python
import torch
from natten import na1d

# NATTEN expects [batch, seq_len, heads, head_dim]; Flux puts the 512 text tokens first.
q, k, v = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)

# Image queries: 1D neighborhood attention over image tokens, plus all text tokens as additional KV.
hidden_states = na1d(q[:, 512:], k[:, 512:], v[:, 512:],
                     kernel_size=80, stride=16,
                     additional_keys=k[:, :512], additional_values=v[:, :512],
                     backend="cutlass-fna", attention_kwargs={"backend": "cutlass-fmha"})

# Text queries: kernel_size=512 spans every text token; all image tokens are additional KV.
text_hidden_states = na1d(q[:, :512], k[:, :512], v[:, :512],
                          kernel_size=512,
                          additional_keys=k[:, 512:], additional_values=v[:, 512:],
                          backend="cutlass-fna", attention_kwargs={"backend": "cutlass-fmha"})

# Restore Flux's [text, image] token order, then merge heads: [batch, seq_len, heads * head_dim].
hidden_states = torch.cat([text_hidden_states, hidden_states], dim=1)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
```
With this modified processor, inference produces noisy, corrupted images instead of clean high-resolution outputs. I've tried several parameter configurations but haven't been able to reproduce the results reported in the paper.
Environment
Hardware: A800 GPU
NATTEN version: 0.20.0
CUDA: 12.6
PyTorch: 2.7.0
I appreciate any guidance or pointers to relevant resources. Thank you for your time!