dsv4 decode_fwd: 7-layer perf config, drop lm_head tail #664
zhangqi-chen wants to merge 1 commit into
- Reduce HCA_NUM_LAYERS 20->2 (CSA 3, FWD 7): 2 SWA + 2x(CSA+HCA) + 1 tail CSA = 7 layers, for a fast ep2 smoke/perf run.
- Remove the lm_head_tp tail from l3_decode_fwd: host now outputs hidden_norm [N_RANKS, T, D] (post-final-norm hidden) instead of logits. Drops lm_head_weight input, the logits output, and the lm_head window setup loop; spec builder outputs hidden_norm.

PASS a2a3 ep2 (--ptoas 0.46) with l2-swimlane + scope-stats.
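The layer arithmetic in the first bullet can be checked with a small sketch. This is not code from the PR: `build_schedule`, `SWA_NUM_LAYERS`, and `CSA_NUM_LAYERS` are hypothetical names, and the ordering (SWA block, then CSA+HCA pairs, then the tail CSA) is an assumption read off the description.

```python
# Hypothetical reconstruction of the 7-layer schedule described in the PR text.
# Only HCA_NUM_LAYERS is named in the source; the other names are assumptions.
SWA_NUM_LAYERS = 2  # leading SWA layers
HCA_NUM_LAYERS = 2  # reduced from 20 in this PR
CSA_NUM_LAYERS = 3  # 2 paired with HCA layers + 1 tail


def build_schedule(n_swa, n_hca, n_csa):
    """Return layer kinds in order: SWA block, (CSA, HCA) pairs, tail CSA layers."""
    if n_csa < n_hca:
        raise ValueError("each HCA layer needs a paired CSA layer")
    schedule = ["SWA"] * n_swa
    for _ in range(n_hca):
        schedule += ["CSA", "HCA"]
    # Any CSA layers left over after pairing form the tail.
    schedule += ["CSA"] * (n_csa - n_hca)
    return schedule


schedule = build_schedule(SWA_NUM_LAYERS, HCA_NUM_LAYERS, CSA_NUM_LAYERS)
FWD_NUM_LAYERS = len(schedule)
print(FWD_NUM_LAYERS, schedule)
# 7 ['SWA', 'SWA', 'CSA', 'HCA', 'CSA', 'HCA', 'CSA']
```

Deriving the forward layer count from the schedule, rather than hard-coding 7, keeps the three counts from drifting apart when the config is scaled back up.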
Code Review
This pull request removes the language model head (lm_head) processing from the l3_decode_fwd function, shifting its output from logits to hidden_norm. It also reduces the layer counts for testing or smaller configurations and updates the tensor specifications accordingly. The review feedback recommends cleaning up now-unused imports (VOCAB_PER_TP) and dead code (LM_HEAD_NAMES and _make_lm_head_spec) resulting from these changes.
from lm_head import (
    TP_SIZE as LM_HEAD_ACTIVE_TP_SIZE,
    T_MAX as LM_HEAD_T_MAX,
    VOCAB_PER_TP,
    lm_head_tp,
)
Since the lm_head tail has been dropped from l3_decode_fwd, VOCAB_PER_TP is no longer used in any active code path. We should remove it from the import statement to keep the code clean. Note that LM_HEAD_ACTIVE_TP_SIZE is still used in an assertion on line 109, so we must keep it for now.
from lm_head import (
    TP_SIZE as LM_HEAD_ACTIVE_TP_SIZE,
)

    elif name in FINAL_NORM_NAMES:
        specs.append(_make_final_norm_spec(name))
    elif name in LM_HEAD_NAMES:
        specs.append(_make_lm_head_spec(name))
    else:
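The cleanup the review asks for (dropping imports that no longer have a reader) can be checked mechanically. The sketch below is one possible approach using only the standard-library `ast` module; `unused_imported_names` is a hypothetical helper, and `SAMPLE` is a trimmed, made-up stand-in for the real module, not its actual contents.

```python
import ast


def unused_imported_names(source):
    """Return names bound by import statements that are never read in `source`."""
    tree = ast.parse(source)
    imported = set()
    used = set()
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # The bound name is the alias if given, else the first dotted part.
                imported.add(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
            used.add(node.id)
    return imported - used


# Illustrative input only: the assertion value is an assumption.
SAMPLE = '''
from lm_head import (
    TP_SIZE as LM_HEAD_ACTIVE_TP_SIZE,
    VOCAB_PER_TP,
)

assert LM_HEAD_ACTIVE_TP_SIZE == 2
'''

print(unused_imported_names(SAMPLE))  # {'VOCAB_PER_TP'}
```

The source is only parsed, never executed, so `lm_head` does not need to be importable. The check is deliberately simple: it ignores `__all__` re-exports and names referenced only inside strings, so a linter remains the more thorough option.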
Summary
The host now outputs hidden_norm [N_RANKS, T, D] (post-final-norm hidden) instead of logits. Removes the lm_head_weight input, the logits output, and the lm_head window setup loop; the spec builder outputs hidden_norm.