Skip to content

support ring fp4 comm and sla sparse#933

Merged
helloyongyang merged 4 commits intomainfrom
yr/ring_fp4_comm_and_sla_sparse
Mar 10, 2026
Merged

support ring fp4 comm and sla sparse#933
helloyongyang merged 4 commits intomainfrom
yr/ring_fp4_comm_and_sla_sparse

Conversation

@STwangyingrui
Copy link
Contributor

No description provided.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the lightx2v framework by integrating FP4 communication capabilities into the Ring Attention mechanism, offering a more diverse range of quantization options for tensor exchange. It also refines the sparse attention functionality with a new control parameter and improves the flexibility of inference step management within the audio scheduler.

Highlights

  • FP4 Communication Support: Introduced support for FP4 quantization and dequantization in Ring Attention, allowing for more flexible communication alongside existing FP8 capabilities.
  • SLA Sparse Attention Parameter: Added a use_sla_sparse parameter to the sageattn3_sparse_blackwell function call, enabling control over sparse attention behavior.
  • Refactored Scheduler infer_steps: Modified the audio scheduler to make the infer_steps parameter optional in the prepare method, allowing dynamic override or fallback to the configured value.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • lightx2v/common/ops/attn/ring_attn.py
    • Imported FP4 quantization and dequantization functions from sageattn3_sparse.
    • Updated the apply method to accept use_fp4_comm and added an assertion for mutual exclusivity with use_fp8_comm.
    • Modified tensor send/receive and dequantization logic to conditionally handle FP4 or FP8 based on the provided flags.
    • Refactored _quant_and_send and _dequantize_received methods to support both FP4 and FP8 quantization/dequantization with appropriate tensor reshaping.
    • Updated _send_recv_tensor to pass the FP4 communication flag.
  • lightx2v/common/ops/attn/sage_attn.py
    • Added use_sla_sparse=False argument to the sageattn3_sparse_blackwell function call.
  • lightx2v/models/schedulers/wan/audio/scheduler.py
    • Removed infer_steps initialization from the __init__ method.
    • Modified the prepare method to allow infer_steps to be an optional argument, defaulting to the configured value if not explicitly provided.
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for FP4 communication in Ring Attention and adds a flag for SLA sparse attention. The changes in ring_attn.py correctly add the FP4 communication path alongside the existing FP8 path. The modifications in sage_attn.py and scheduler.py are minor and appropriate. My main feedback is on refactoring a section of duplicated code in ring_attn.py to improve maintainability. Overall, the changes are well-structured to support the new features.

Comment on lines +343 to +368
assert not (use_fp8_comm and use_fp4_comm), "use_fp8_comm and use_fp4_comm can't be enabled at the same time."
B, H, N, D2 = next_tensor_quant.shape
D = D2 * 2
D16 = D // 16
if use_kv_fusion and is_kv_fusion:
# KV 融合模式
return dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype)
if use_fp8_comm:
return dequant_fp8_vllm(next_tensor_quant, next_tensor_scale, original_dtype)
else:
return dequant_fp4_sage3(next_tensor_quant.reshape(1, 1, -1, D2), next_tensor_scale.reshape(1, 1, -1, D16)).reshape(B, H, N, D)
elif not use_kv_fusion:
# 分离模式
k = dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype)
v = dequant_fp8_vllm(v_fp8, v_scale, original_dtype)
return k, v
if use_fp8_comm:
k = dequant_fp8_vllm(next_tensor_quant, next_tensor_scale, original_dtype)
v = dequant_fp8_vllm(v_quant, v_scale, original_dtype)
return k, v
else:
k = dequant_fp4_sage3(next_tensor_quant.reshape(1, 1, -1, D2), next_tensor_scale.reshape(1, 1, -1, D16)).reshape(B, H, N, D)
v = dequant_fp4_sage3(v_quant.reshape(1, 1, -1, D2), v_scale.reshape(1, 1, -1, D16)).reshape(B, H, N, D)
return k, v
else:
# 默认返回单个张量
return dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype)
if use_fp8_comm:
return dequant_fp8_vllm(next_tensor_quant, next_tensor_scale, original_dtype)
else:
return dequant_fp4_sage3(next_tensor_quant.reshape(1, 1, -1, D2), next_tensor_scale.reshape(1, 1, -1, D16)).reshape(B, H, N, D)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function _dequantize_received contains duplicated logic for handling FP8 and FP4 dequantization within its conditional branches (if use_kv_fusion and is_kv_fusion, elif not use_kv_fusion, and else). This makes the code harder to read and maintain.

You can refactor this by extracting the dequantization logic into a nested helper function. This will remove redundancy and make the parent function's logic clearer.

        assert not (use_fp8_comm and use_fp4_comm), "use_fp8_comm and use_fp4_comm can't be enabled at the same time."

        def _dequant(tensor_quant, scale):
            if use_fp8_comm:
                return dequant_fp8_vllm(tensor_quant, scale, original_dtype)
            
            B, H, N, D2 = tensor_quant.shape
            D = D2 * 2
            D16 = D // 16
            return dequant_fp4_sage3(tensor_quant.reshape(1, 1, -1, D2), scale.reshape(1, 1, -1, D16)).reshape(B, H, N, D)

        if use_kv_fusion and is_kv_fusion:
            # KV 融合模式
            return _dequant(next_tensor_quant, next_tensor_scale)
        elif not use_kv_fusion:
            # 分离模式
            k = _dequant(next_tensor_quant, next_tensor_scale)
            v = _dequant(v_quant, v_scale)
            return k, v
        else:
            # 默认返回单个张量
            return _dequant(next_tensor_quant, next_tensor_scale)

@helloyongyang helloyongyang merged commit b845787 into main Mar 10, 2026
2 checks passed
@helloyongyang helloyongyang deleted the yr/ring_fp4_comm_and_sla_sparse branch March 10, 2026 08:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants