[DO NOT MERGE] trainer ft #21

Draft
fzyzcjy wants to merge 22 commits into miles-main from trainer_ft/dev

@fzyzcjy (Collaborator) commented Apr 2, 2026

No description provided.

fzyzcjy added 3 commits April 1, 2026 22:53
Add _pre_decoder_hooks list and register_pre_decoder_hook() method.
Hooks run between _preprocess and the decoder, so external code can
transform decoder_input without Megatron needing to know the specifics.
The list is already initialized in __init__, so no getattr fallback is needed.
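
A minimal sketch of the hook pattern these commits describe, using a toy stand-in class rather than Megatron's GPTModel. Only `_pre_decoder_hooks`, `register_pre_decoder_hook`, `_preprocess`, `decoder`, and `decoder_input` come from the commit messages; the hook signature (takes `decoder_input`, returns the transformed value) and the toy embedding/decoder are assumptions for illustration.

```python
from typing import Callable, List


class ToyModel:
    """Toy stand-in for a GPT-style model; NOT Megatron's GPTModel."""

    def __init__(self) -> None:
        # Created eagerly in __init__, so callers never need
        # getattr(self, "_pre_decoder_hooks", []) as a fallback.
        self._pre_decoder_hooks: List[Callable[[list], list]] = []

    def register_pre_decoder_hook(self, hook: Callable[[list], list]) -> None:
        # Assumed signature: hook(decoder_input) -> new decoder_input.
        self._pre_decoder_hooks.append(hook)

    def _preprocess(self, input_ids: list) -> list:
        # Toy "embedding": map each id to a float.
        return [float(i) for i in input_ids]

    def decoder(self, decoder_input: list) -> list:
        # Toy "decoder": double every value.
        return [2.0 * x for x in decoder_input]

    def forward(self, input_ids: list) -> list:
        decoder_input = self._preprocess(input_ids)
        # Hooks run between _preprocess and the decoder, in registration order.
        for hook in self._pre_decoder_hooks:
            decoder_input = hook(decoder_input)
        return self.decoder(decoder_input)


model = ToyModel()
model.register_pre_decoder_hook(lambda x: [v + 1.0 for v in x])
output = model.forward([1, 2, 3])
print(output)
```

The model only iterates over opaque callables, which is why it does not need to know what a given hook does to `decoder_input`.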
@fzyzcjy fzyzcjy changed the title Add register_pre_decoder_hook [DO NOT MERGE] trainer ft Apr 2, 2026
@fzyzcjy fzyzcjy marked this pull request as draft April 2, 2026 02:28
fzyzcjy added 19 commits April 2, 2026 10:36
Remove _pre_decoder_hooks list and register_pre_decoder_hook() method.
Add witness_ids parameter to forward() and build_schedule_plan().
Witness logic is inline: hasattr(self, 'head_witness') check + add to
decoder_input or decoder.input_tensor depending on PP stage.
- build_schedule_plan: accept witness_ids, pass to schedule plan
- TransformerModelChunkSchedulePlan: store witness_ids in chunk_state
- PreProcessNode.forward_impl: apply witness after _preprocess
  (same logic as GPTModel.forward)
- GPTModel.forward: add tail_witness after decoder, before _postprocess
- build_schedule_plan: revert witness_ids param (not supported)
- model_chunk_schedule_plan: revert witness_ids in chunk_state
- fine_grained_callables: revert witness logic in Pre/PostProcessNode
_DataWitness.forward returns [b, s, 1] but Megatron's decoder_input
is in [s, b, h] format after the embedding layer. Without transposing,
broadcasting [s, b, h] + [b, s, 1] (with b=1) creates a [s, s, h] tensor,
causing OOM (648 GiB for s=18432).
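
The shape blow-up above can be checked with plain arithmetic. The sketch below implements the standard NumPy/PyTorch broadcasting rule in pure Python and applies it to the two layouts. `s=18432` is from the commit message and `b=1` matches the shapes quoted in the sequence-parallel commit; `h=1024` and a 2-byte dtype are assumptions, chosen only because they are one combination that reproduces the 648 GiB figure (the real hidden size and dtype are not stated).

```python
def broadcast_shape(a, b):
    """Result shape of a + b under NumPy/PyTorch-style broadcasting."""
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)  # left-pad with 1s
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)
    return tuple(out)


def size_gib(shape, bytes_per_elem):
    """Memory footprint of a dense tensor with the given shape, in GiB."""
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_elem / 2**30


s, b = 18432, 1          # s from the commit message; b=1 as in the quoted shapes
h, elem_bytes = 1024, 2  # ASSUMED: hidden size and bf16/fp16 element size

# Witness left in [b, s, 1]: the s axis lines up against b, not against s.
bad = broadcast_shape((s, b, h), (b, s, 1))
# Witness transposed to [s, b, 1]: shapes line up, result stays [s, b, h].
good = broadcast_shape((s, b, h), (s, b, 1))

bad_gib = size_gib(bad, elem_bytes)
good_gib = size_gib(good, elem_bytes)
print(bad, bad_gib)
print(good, good_gib)
```

Under these assumptions the untransposed sum needs 648 GiB while the transposed one needs 36 MiB, a factor of s.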
When sequence parallel is active, decoder_input and hidden_states
are scattered along the sequence dimension ([s/tp, b, h]). The
witness output must also be scattered to match, otherwise shapes
mismatch (e.g. [18432, 1, 1] vs [9216, 1, h] with TP=2).
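
A rough illustration of the scatter described above, assuming the sequence dimension is split into contiguous, equal chunks with one chunk per tensor-parallel rank. `scatter_along_sequence` is a hypothetical helper written for this sketch, not a Megatron function, and the witness tensor is modelled as a plain list of per-position values.

```python
def scatter_along_sequence(rows, tp_rank, tp_size):
    """Return this rank's contiguous chunk along the leading (sequence) dim.

    Hypothetical helper for illustration; assumes equal contiguous chunks.
    """
    if len(rows) % tp_size != 0:
        raise ValueError("sequence length must be divisible by tp_size")
    chunk = len(rows) // tp_size
    return rows[tp_rank * chunk:(tp_rank + 1) * chunk]


s, tp_size = 18432, 2

# Stand-in for a witness output of shape [s, 1, 1]: one value per position.
witness = list(range(s))

# Each rank keeps only the positions its [s/tp, b, h] hidden_states cover.
local = scatter_along_sequence(witness, tp_rank=1, tp_size=tp_size)
print(len(local), local[0])
```

With TP=2 each rank holds 9216 positions, matching the leading dimension of the scattered decoder_input.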
…ptimizer

The distributed optimizer replaces optimizer param_groups with shard
main params (fp32 copies). get_main_grads_for_grad_norm checks
_is_witness_param on these main params, but the flag was only set on
the original model params. Copy the flag when building main param
groups for both float16→fp32 and fp32 paths.
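
The flag-propagation fix can be sketched with toy objects. Only `_is_witness_param` and the idea of fp32 "main" copies come from the commit message; `Param`, `make_main_param`, and `params_for_grad_norm` are hypothetical names, and the assumption that flagged params are excluded from the grad-norm set is a guess at what the check is used for.

```python
class Param:
    """Toy parameter; stands in for a tensor with a dtype tag."""

    def __init__(self, data, dtype):
        self.data = list(data)
        self.dtype = dtype


def make_main_param(model_param):
    """Build the optimizer-side fp32 copy of a model param (toy version)."""
    main = Param(model_param.data, "fp32")
    # The fix: carry the marker over, since later checks see only the copy.
    if getattr(model_param, "_is_witness_param", False):
        main._is_witness_param = True
    return main


def params_for_grad_norm(main_params):
    # ASSUMED behaviour: witness params are left out of the grad norm.
    return [p for p in main_params
            if not getattr(p, "_is_witness_param", False)]


regular = Param([0.5, -0.5], "bf16")
witness = Param([0.0], "fp32")
witness._is_witness_param = True  # flag set on the original model param

mains = [make_main_param(p) for p in (regular, witness)]
kept = params_for_grad_norm(mains)
print(len(kept), kept[0].data)
```

Without the copy step in `make_main_param`, both main params would pass the filter, because the flag would exist only on the original model params.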