[Bug fixes] Align clamped SwiGLU computation order and data types with Megatron, and reduce redundancy testcase#1078
Conversation
risemeup1111
left a comment
There was a problem hiding this comment.
本轮审查发现需要先修复的问题,主要集中在 fused_swiglu_scale_backward 的 fallback 返回形状与 CUDA/既有行为不一致;另外 MoE clamp 分支的判定条件也需要和新 wrapper 语义统一。具体建议已放在行内评论中。
当前 CI 中常规检查大多已通过,仍有构建任务在运行,且 approval 检查未满足。
| d_scale = paddle.sum( | ||
| swiglu_val.cast(x.dtype) * out_grad.cast(scale_dtype), | ||
| axis=-1, | ||
| keepdim=True, | ||
| ).cast(scale_dtype) |
There was a problem hiding this comment.
这里把 keepdim 固定为 True 后,CPU/XPU fallback 在 scale 是一维时会返回 [rows, 1],而不是原来的 [rows];CUDA 扩展的 FusedGradInferShape 也按 scale_shape 返回 DScale。这样同一个 fused_swiglu_scale_backward(x, scale, ..., clamp_value=...) 会因为是否走 fallback 得到不同的 d_scale 形状,现有一维 scale 调用/测试会被破坏。请按 scale 的 rank 决定是否保留最后一维,并在 non-clamp 分支做同样处理。
| d_scale = paddle.sum( | |
| swiglu_val.cast(x.dtype) * out_grad.cast(scale_dtype), | |
| axis=-1, | |
| keepdim=True, | |
| ).cast(scale_dtype) | |
| d_scale = paddle.sum( | |
| swiglu_val.cast(x.dtype) * out_grad.cast(scale_dtype), | |
| axis=-1, | |
| keepdim=scale.ndim > 1, | |
| ).cast(scale_dtype) |
| d_scale = paddle.sum( | ||
| out_grad.cast(paddle.float32) * swiglu_val.cast(paddle.float32), | ||
| axis=-1, | ||
| keepdim=True, | ||
| ).cast(scale.dtype) |
There was a problem hiding this comment.
non-clamp fallback 也被改成了始终 keepdim=True,这会把一维 scale 的 d_scale 从 [rows] 变成 [rows, 1],与已有 CPU fallback 行为、仓库中一维 scale 测试以及 CUDA DScale infer shape 都不一致。请保持与 scale 的 rank 对齐,避免仅 fallback 路径改变公共返回形状。
| d_scale = paddle.sum( | |
| out_grad.cast(paddle.float32) * swiglu_val.cast(paddle.float32), | |
| axis=-1, | |
| keepdim=True, | |
| ).cast(scale.dtype) | |
| d_scale = paddle.sum( | |
| out_grad.cast(paddle.float32) * swiglu_val.cast(paddle.float32), | |
| axis=-1, | |
| keepdim=scale.ndim > 1, | |
| ).cast(scale.dtype) |
|
|
||
| if self.clamp_value is not None: | ||
| o2 = fused_swiglu_scale_clamp_forward( | ||
| o2 = fused_swiglu_scale_forward( |
There was a problem hiding this comment.
新的 fused_swiglu_scale_forward 只有在 clamp_value > 0 时才启用 clamp,但这里仍然用 is not None 选择 clamp 分支。这样 activation_func_clamp_value=0.0 或负数时,forward 会通过 wrapper 实际走 non-clamp,而同类的 backward/FP8 分支仍按 is not None 调用 clamp kernel,容易造成前后向语义不一致。请把 MoE 相关分支统一成和 wrapper 一样的正数判定。
| o2 = fused_swiglu_scale_forward( | |
| if self.clamp_value is not None and self.clamp_value > 0: |
PaddleFleet Log Analysis
日志分析报告
失败的测试 case: 根本原因分析: PR 修改了 修复建议:
🔄 每次 Re-run 后自动更新 |
PaddlePaddle-bot
left a comment
There was a problem hiding this comment.
🤖 Paddle-CI-Agent | pr_review |
2026-05-31 02:00:20
📋 Review 摘要
PR 概述:对齐 clamped SwiGLU 的计算顺序和数据类型与 Megatron 一致,合并冗余的 clamp/non-clamp 代码路径,精简测试用例。
变更范围:fusions/、transformer/mlp.py、transformer/moe/fp8_utils.py、paddlefleet_ops/_extensions/fuse_swiglu_scale.cu
影响面 Tag:Fusions MoE OP
问题
未发现阻塞性问题。代码变更逻辑正确,精度对齐目标明确。
历史 Findings 修复情况
| Finding | 问题 | 状态 |
|---|---|---|
| F1 | weights_grad CPU fallback 累加精度与 CUDA kernel 存在差异 |
✅ 已修复 |
F1 修复说明:CPU 侧
clamped_weighted_swiglu_back现在使用clamped_swiglu(y, clamp_value) * g.cast(w_dtype)计算weights_grad,与 CUDA kernel 的sum(swiglu_val.cast(dtype) * d_out.cast(scale_dtype))公式对齐,消除了原先 CPU 全程 float32 与 CUDA native-type 乘法之间的精度差异。
📝 PR 规范检查
✓ 标题格式合规([Bug fixes] Tag 匹配 diff 内容),描述结构完整(含 PR Category / PR Types / Description / 精度变化说明)。
总体评价
本 PR 通过统一 clamp/non-clamp 代码路径(移除 ClampedWeightedSwiGLUFunction、合并 clamped_weighted_bias_swiglu_impl 到 weighted_bias_swiglu_impl)显著降低了维护复杂度。CUDA kernel 和 Python fallback 的 d_scale 计算均对齐到 sum(swiglu_val.cast(dtype) * d_out.cast(scale_dtype)) 公式,与 Megatron 参考实现保持 bit-exact 一致。代码质量良好,可合入。
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #1078 +/- ##
===========================================
Coverage ? 100.00%
===========================================
Files ? 4
Lines ? 99
Branches ? 12
===========================================
Hits ? 99
Misses ? 0
Partials ? 0
Flags with carried forward coverage won't be shown. Click here to find out more.
🚀 New features to boost your workflow:
|
PR Category
Operator Mechanism
PR Types
Bug fixes
Description
Align clamped SwiGLU computation order and data types with Megatron, and reduce redundancy testcase
是否引起精度变化
是, 只用造成clampswiglu的精度变化。