diff --git a/examples/mimo/blend_files/1t_phase1var_moresft_wrapper.yaml b/examples/mimo/blend_files/1t_phase1var_moresft_wrapper.yaml new file mode 100644 index 00000000000..ed19e692d6f --- /dev/null +++ b/examples/mimo/blend_files/1t_phase1var_moresft_wrapper.yaml @@ -0,0 +1,6 @@ +# RKarimi 3B-nano SOTA 1T text subset blend. +# The 3B-nano baseline uses TRAIN_SAMPLES=122070313 and SEQ_LEN=8192, +# which is 1,000,000,004,096 tokens. +__module__: megatron.energon +__class__: McoreBlend +mcore_json: /scratch/fsw/portfolios/llmservice/projects/llmservice_fm_text/users/rkarimimahab/workspace/blends/1T-phase1var-moresft.json diff --git a/examples/mimo/blend_files/text_omnicorpus_blend_10_90_hel.yaml b/examples/mimo/blend_files/text_omnicorpus_blend_10_90_hel.yaml new file mode 100644 index 00000000000..66f1f4ccb70 --- /dev/null +++ b/examples/mimo/blend_files/text_omnicorpus_blend_10_90_hel.yaml @@ -0,0 +1,365 @@ +# 90% RKarimi 1T text subset + 10% OmniCorpus (CC-MAIN-2021-25 excluded - corrupt tar) +__module__: megatron.energon +__class__: MetadatasetV2 +splits: + train: + blend: + - weight: 0.9 + path: __MEGATRON_ROOT__/examples/mimo/blend_files/1t_phase1var_moresft_wrapper.yaml + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2013-20 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2013-48 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2014-10 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2014-15 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2014-23 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2014-35 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2014-41 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2014-42 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2014-49 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2014-52 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2015-06 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2015-11 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2015-14 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2015-18 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2015-22 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2015-27 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2015-32 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2015-35 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2015-40 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2015-48 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2016-07 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2016-18 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2016-22 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2016-26 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2016-30 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2016-36 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2016-40 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2016-44 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2016-50 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-04 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-09 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-13 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-17 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-22 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-26 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-30 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-34 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-39 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-43 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-47 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2017-51 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-05 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-09 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-13 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-17 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-26 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-30 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-34 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-39 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-43 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-47 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2018-51 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-04 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-09 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-13 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-18 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-22 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-26 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-30 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-35 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-39 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-43 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-47 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2019-51 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2020-05 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2020-10 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2020-16 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2020-24 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2020-29 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2020-34 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2020-40 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2020-45 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2020-50 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2021-04 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2021-10 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2021-17 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2021-21 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2021-31 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2021-39 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2021-43 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2021-49 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2022-05 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2022-21 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2022-27 + subflavors: + cook: omnicorpus + - weight: 0.001163 + path: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/multimodal/datasets/OmniCorpus-CC-210M/webdataset/CC-MAIN-2022-33 + subflavors: + cook: omnicorpus + val: + blend: + - path: __MULTIMODAL_DATA_ROOT__/validation/text_arxiv_math/data.bin + subflavors: + cook: bin_idx + - path: __MULTIMODAL_DATA_ROOT__/validation/text_cc/data.bin + subflavors: + cook: bin_idx + - path: __MULTIMODAL_DATA_ROOT__/validation/text_python/data.bin + subflavors: + cook: bin_idx + - path: __MULTIMODAL_DATA_ROOT__/validation/mint_arxiv + subflavors: + cook: interleaved + - path: __MULTIMODAL_DATA_ROOT__/validation/mint_pdf + subflavors: + cook: interleaved diff --git a/examples/mimo/blend_files/text_only_1t_hel.yaml b/examples/mimo/blend_files/text_only_1t_hel.yaml new file mode 100644 index 00000000000..8ee3ced115b --- /dev/null +++ b/examples/mimo/blend_files/text_only_1t_hel.yaml @@ -0,0 +1,8 @@ +# HEL text-only Energon blend for MIMO jitter isolation. +__module__: megatron.energon +__class__: MetadatasetV2 +splits: + train: + blend: + - weight: 1.0 + path: __MEGATRON_ROOT__/examples/mimo/blend_files/1t_phase1var_moresft_wrapper.yaml diff --git a/examples/mimo/data/__init__.py b/examples/mimo/data/__init__.py index df73bc4abd5..be521ff65cd 100644 --- a/examples/mimo/data/__init__.py +++ b/examples/mimo/data/__init__.py @@ -1,5 +1,11 @@ -from .energon_avlm_task_encoder import VisionAudioQASample +"""MIMO data providers and task encoders.""" -all = [ - VisionAudioQASample, -] +__all__ = ["VisionAudioQASample"] + + +def __getattr__(name): + if name == "VisionAudioQASample": + from .energon_avlm_task_encoder import VisionAudioQASample + + return VisionAudioQASample + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/examples/mimo/data/energon_multimodal_provider.py b/examples/mimo/data/energon_multimodal_provider.py new file mode 100644 index 00000000000..0e754a52758 --- /dev/null +++ b/examples/mimo/data/energon_multimodal_provider.py @@ -0,0 +1,382 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Energon multimodal data provider for MIMO. + +This module intentionally mirrors the provider used by the previous +``feat/nemotron-moe-vlm-mimo`` branch. Energon's ``MultiModalPackingEncoder`` +owns sample cooking, preencoding, and packing; the MIMO-specific adapter only +expands each single ```` placeholder into one placeholder per image +embedding and remaps the batch to MIMO's forward signature. +""" + +from __future__ import annotations + +import inspect +import warnings +from typing import Optional + + +def _supported_kwargs(fn, kwargs): + """Drop kwargs the target callable doesn't accept. + + Lets the caller pass a superset of recipe args without erroring on fields + that the installed energon's VisionConfig doesn't recognize. + """ + params = inspect.signature(fn).parameters + if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()): + return kwargs + return {key: value for key, value in kwargs.items() if key in params} + + +import torch + +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.energon import WorkerConfig +from megatron.energon.task_encoder.multimodal import ( + MultiModalPackingEncoder, + PackingConfig, + VisionConfig, +) +from megatron.energon.task_encoder.multimodal.sample_types import PackedSample +from megatron.energon.task_encoder.multimodal.vision_tokens import get_num_image_embeddings + + +class TokenizerAdapter: + """Wrap Megatron tokenizers for Energon's tokenizer protocol.""" + + def __init__(self, megatron_tokenizer) -> None: + self._tok = megatron_tokenizer + inner = megatron_tokenizer + if hasattr(inner, "_tokenizer"): + inner = inner._tokenizer + if hasattr(inner, "tokenizer"): + inner = inner.tokenizer + self._hf = inner + + @property + def pad_token_id(self) -> int: + """Return the tokenizer pad id.""" + return self._tok.pad + + @property + def eos_token_id(self) -> int: + """Return the tokenizer EOS id.""" + return self._tok.eod + + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + """Encode text with the wrapped HuggingFace tokenizer.""" + return self._hf.encode(text, add_special_tokens=add_special_tokens) + + def decode(self, token_ids, skip_special_tokens: bool = False) -> str: + """Decode token ids with the wrapped HuggingFace tokenizer.""" + return self._hf.decode(token_ids, skip_special_tokens=skip_special_tokens) + + def convert_tokens_to_ids(self, tokens): + """Convert tokens to ids with the wrapped Megatron tokenizer.""" + return self._tok.convert_tokens_to_ids(tokens) + + @property + def chat_template(self): + """Forward HuggingFace chat_template so energon's tokenize_and_prepare can find it.""" + return getattr(self._hf, "chat_template", None) + + def apply_chat_template(self, *args, **kwargs): + """Forward to underlying HuggingFace tokenizer for energon's chat-template path.""" + return self._hf.apply_chat_template(*args, **kwargs) + + +class MimoMultiModalPackingEncoder(MultiModalPackingEncoder): + """Remap Energon multimodal packed samples to MIMO batch inputs.""" + + # Key under which the producing Energon worker's ``global_worker_id`` is + # stamped on each output batch when ``attach_provenance`` is enabled. + # Hetero MIMO uses this to route samples back to their LLM data lane when + # a single encoder-side Energon iterator multiplexes several lanes. + PROVENANCE_KEY = "__encoder_provenance__" + + def __init__( + self, + vision_config: VisionConfig, + packing_config: PackingConfig, + tokenizer, + encoder_name: str = "radio_encoder", + encoder_input_key: str = "x", + target_seq_length: Optional[int] = None, + attach_provenance: bool = False, + ) -> None: + super().__init__(vision_config, packing_config, tokenizer) + self.encoder_name = encoder_name + self.encoder_input_key = encoder_input_key + self._target_seq_length = target_seq_length + self._attach_provenance = attach_provenance + self._embeddings_per_tile = get_num_image_embeddings( + img_h=vision_config.img_h, + img_w=vision_config.img_w, + patch_dim=vision_config.patch_dim, + class_token_len=vision_config.class_token_len, + disable_vision_class_token=vision_config.disable_vision_class_token, + pixel_shuffle=vision_config.pixel_shuffle, + conv_merging=vision_config.conv_merging, + use_tile_tags=vision_config.use_tile_tags, + max_num_tiles=vision_config.max_num_tiles, + use_image_break_token=vision_config.use_image_break_token, + ) + # Stashed so batch() can compute per-image embedding counts under + # dynamic resolution (where the constant emb_per_tile doesn't apply). + self._dynamic_resolution = getattr(vision_config, "dynamic_resolution", False) + self._patch_dim = vision_config.patch_dim + self._pixel_shuffle = vision_config.pixel_shuffle + self._conv_merging = vision_config.conv_merging + + def batch(self, samples: list[PackedSample]) -> dict: + """Expand image placeholders and return a MIMO-compatible batch.""" + image_token_id = self.packing_config.image_token_id + ignore_index = self.packing_config.ignore_index + pad_id = self.packing_config.pad_id + emb_per_tile = self._embeddings_per_tile + + expanded_tokens_list = [] + expanded_labels_list = [] + all_images = [] + + for sample in samples: + tokens = sample.tokens + labels = sample.labels + num_tiles = sample.num_tiles + budget = self._target_seq_length + new_tokens = [] + new_labels = [] + img_idx = 0 + truncated = False + truncated_padding_only = False + kept_tile_count = 0 + + for idx, token in enumerate(tokens.tolist()): + if token == image_token_id: + n_tiles = num_tiles[img_idx] if img_idx < len(num_tiles) else 1 + if self._dynamic_resolution: + # Each image produces (h/p) * (w/p) patches; pixel_shuffle and + # conv_merging each halve both axes => divide by 4 each. + img_pix = sample.images[img_idx] + h_pix = img_pix.shape[-2] + w_pix = img_pix.shape[-1] + per_image = (h_pix // self._patch_dim) * (w_pix // self._patch_dim) + if self._pixel_shuffle: + per_image //= 4 + if self._conv_merging: + per_image //= 4 + n_tokens = per_image + else: + n_tokens = n_tiles * emb_per_tile + if budget is not None and len(new_tokens) + n_tokens > budget: + truncated = True + break + new_tokens.extend([image_token_id] * n_tokens) + new_labels.extend([ignore_index] * n_tokens) + kept_tile_count += n_tiles + img_idx += 1 + else: + if budget is not None and len(new_tokens) + 1 > budget: + truncated = True + truncated_padding_only = _remaining_tokens_are_padding( + tokens=tokens, + labels=labels, + start=idx, + pad_id=pad_id, + ignore_index=ignore_index, + ) + break + new_tokens.append(token) + new_labels.append(labels[idx].item()) + + if truncated and len(sample.cu_lengths) > 2 and not truncated_padding_only: + raise RuntimeError( + "Packed Energon sample exceeds target sequence length after MIMO image-token " + "expansion. Refusing to clamp packed cu_seqlens because that can create " + "zero-length packed segments. Increase --total-seq-length or lower image " + "tiling/packing settings." + ) + + if truncated and not truncated_padding_only: + warnings.warn( + f"Sample truncated to fit target_seq_length ({self._target_seq_length}): " + f"kept {len(new_tokens)} of ~{len(tokens)} original tokens, " + f"{img_idx}/{len(num_tiles)} images ({kept_tile_count} tiles). " + "Consider increasing --total-seq-length or reducing --max-num-tiles.", + stacklevel=2, + ) + + all_images.extend(sample.images[:kept_tile_count]) + expanded_tokens_list.append(torch.tensor(new_tokens, dtype=torch.long)) + expanded_labels_list.append(torch.tensor(new_labels, dtype=torch.long)) + + max_len = max(len(tokens) for tokens in expanded_tokens_list) + if self._target_seq_length is not None: + max_len = self._target_seq_length + + batch_size = len(samples) + tokens_batch = torch.full((batch_size, max_len), pad_id, dtype=torch.long) + labels_batch = torch.full((batch_size, max_len), ignore_index, dtype=torch.long) + + for idx, (tokens, labels) in enumerate(zip(expanded_tokens_list, expanded_labels_list)): + tokens_batch[idx, : len(tokens)] = tokens + labels_batch[idx, : len(labels)] = labels + + loss_mask = (labels_batch != ignore_index).float() + loss_mask[labels_batch == image_token_id] = 0.0 + position_ids = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1).contiguous() + + result = { + "input_ids": tokens_batch, + "labels": labels_batch, + "loss_mask": loss_mask, + "position_ids": position_ids, + } + + if all_images: + images, imgs_sizes, cu_lengths, max_seqlen = self.tiling_strategy.stack(all_images) + encoder_inputs = {self.encoder_input_key: images} + if imgs_sizes is not None: + encoder_inputs["imgs_sizes"] = imgs_sizes.to(torch.int32) + if cu_lengths is not None and max_seqlen is not None: + # THD packing metadata for RADIO's variable-length attention. + # Class-token offsets get applied inside RADIO.forward. + cu = cu_lengths.to(torch.int32) + max_q = max_seqlen.to(torch.int32) if torch.is_tensor(max_seqlen) else torch.tensor(int(max_seqlen), dtype=torch.int32) + encoder_inputs["packed_seq_params"] = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu, + cu_seqlens_kv=cu, + max_seqlen_q=max_q, + max_seqlen_kv=max_q, + ) + result["modality_inputs"] = { + "images": {self.encoder_name: encoder_inputs} + } + + is_packed = any(len(sample.cu_lengths) > 2 for sample in samples) + if is_packed: + if batch_size != 1: + raise RuntimeError(f"Packing requires micro_batch_size=1, got {batch_size}") + result["packing_kwargs"] = _build_packing_kwargs(samples[0], max_len) + + if self._attach_provenance: + active = WorkerConfig.active_worker_config + if active is None: + raise RuntimeError( + "attach_provenance=True requires an active Energon worker context" + ) + result[self.PROVENANCE_KEY] = active.global_worker_id() + + return result + + +def _remaining_tokens_are_padding( + tokens: torch.Tensor, labels: torch.Tensor, start: int, pad_id: int, ignore_index: int +) -> bool: + """Return whether truncation only drops right-padding tokens.""" + remaining_tokens = tokens[start:] + remaining_labels = labels[start:] + return bool( + remaining_tokens.numel() > 0 + and torch.all(remaining_tokens == pad_id).item() + and torch.all(remaining_labels == ignore_index).item() + ) + + +def _build_packing_kwargs(sample: PackedSample, max_len: int) -> dict[str, torch.Tensor]: + """Build validated packed-sequence metadata for the MIMO language model.""" + cu_seqlens = sample.cu_lengths.to(dtype=torch.int32) + if cu_seqlens.numel() < 2: + raise RuntimeError(f"Packed sample must have at least two cu_lengths, got {cu_seqlens}") + if torch.any(cu_seqlens[1:] < cu_seqlens[:-1]): + raise RuntimeError(f"Packed cu_lengths must be monotonic, got {cu_seqlens.tolist()}") + + if cu_seqlens[0] != 0: + cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), cu_seqlens]) + if cu_seqlens[-1] > max_len: + raise RuntimeError( + f"Packed cu_lengths end at {int(cu_seqlens[-1])}, beyond sequence length {max_len}" + ) + if cu_seqlens[-1] != max_len: + cu_seqlens = torch.cat([cu_seqlens, torch.tensor([max_len], dtype=torch.int32)]) + + segment_lens = cu_seqlens[1:] - cu_seqlens[:-1] + if torch.any(segment_lens <= 0): + raise RuntimeError( + "Packed cu_lengths must be strictly increasing after MIMO expansion, " + f"got {cu_seqlens.tolist()}" + ) + max_seqlen = segment_lens.max() + return { + "qkv_format": "thd", + "cu_seqlens_q": cu_seqlens, + "cu_seqlens_kv": cu_seqlens, + "cu_seqlens_q_padded": cu_seqlens, + "cu_seqlens_kv_padded": cu_seqlens, + "max_seqlen_q": int(max_seqlen.item()), + "max_seqlen_kv": int(max_seqlen.item()), + "total_tokens": int(max_len), + } + + +def build_multimodal_encoder( + args, + tokenizer, + encoder_name: str = "radio_encoder", + encoder_input_key: str = "x", + attach_provenance: bool = False, +) -> MimoMultiModalPackingEncoder: + """Build the MIMO Energon encoder from train args.""" + target_seq_length = _resolve_target_seq_length(args) + image_token_id = getattr(args, "image_token_id", None) + if image_token_id is None: + image_token_id = tokenizer.convert_tokens_to_ids(getattr(args, "image_token", "")) + pad_id = getattr(args, "pad_token_id", tokenizer.pad) + + vision_config_kwargs = dict( + img_h=args.img_h, + img_w=args.img_w, + patch_dim=args.patch_dim, + vision_model_type=getattr(args, "vision_model_type", "radio"), + disable_vision_class_token=getattr(args, "disable_vision_class_token", False), + pixel_shuffle=getattr(args, "pixel_shuffle", False), + max_num_tiles=getattr(args, "max_num_tiles", getattr(args, "num_image_tiles", 1)), + use_tiling=getattr(args, "use_tiling", False), + use_thumbnail=getattr(args, "use_thumbnail", False), + class_token_len=getattr(args, "class_token_len", None) or 1, + conv_merging=getattr(args, "conv_merging", False), + use_tile_tags=getattr(args, "use_tile_tags", False), + use_image_break_token=getattr(args, "image_break_token", None) is not None, + use_area_weighted_aspect_ratio=getattr(args, "use_area_weighted_aspect_ratio", False), + dynamic_resolution=getattr(args, "dynamic_resolution", False), + dynamic_resolution_min_patches=getattr(args, "dynamic_resolution_min_patches", 4), + dynamic_resolution_max_patches=getattr(args, "dynamic_resolution_max_patches", 0), + dynamic_resolution_min_side=getattr(args, "dynamic_resolution_min_side", None), + dynamic_resolution_max_side=getattr(args, "dynamic_resolution_max_side", None), + ) + # Drop kwargs the installed energon's VisionConfig doesn't accept (e.g. + # dynamic_resolution_max_side is only on newer forks). + vision_config = VisionConfig(**_supported_kwargs(VisionConfig, vision_config_kwargs)) + packing_config = PackingConfig( + seq_length=target_seq_length, pad_id=pad_id, image_token_id=image_token_id + ) + return MimoMultiModalPackingEncoder( + vision_config=vision_config, + packing_config=packing_config, + tokenizer=TokenizerAdapter(tokenizer), + encoder_name=encoder_name, + encoder_input_key=encoder_input_key, + target_seq_length=target_seq_length, + attach_provenance=attach_provenance, + ) + + +def _resolve_target_seq_length(args) -> int: + """Return the sequence length used by Energon and MIMO expansion.""" + target_seq_length = getattr(args, "total_seq_length", None) + if target_seq_length is None: + target_seq_length = getattr(args, "seq_length", None) + if target_seq_length is None: + raise AttributeError("Energon multimodal provider requires total_seq_length or seq_length") + return target_seq_length diff --git a/examples/mimo/data/hetero_energon.py b/examples/mimo/data/hetero_energon.py new file mode 100644 index 00000000000..1fb2380f91e --- /dev/null +++ b/examples/mimo/data/hetero_energon.py @@ -0,0 +1,732 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Heterogeneous-rank wrapper for the MIMO Energon multimodal provider.""" + +from __future__ import annotations + +import hashlib +import random +from collections import deque +from typing import Callable, Optional + +import torch +import torch.distributed as dist + +from examples.mimo.training.hetero.topology import get_grid_coordinate, is_rank_in_grid +from examples.mimo.utils.hetero import debug_rank, is_process_group_member +from megatron.core.packed_seq_params import PackedSeqParams + + +def build_energon_iterator(args, topology): + """Build an Energon iterator for the current rank, or return None if unused.""" + from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage + + encoder_grid = topology.encoder_grid + llm_grid = topology.llm_grid + encoder_needs_data = ( + encoder_grid is not None + and is_rank_in_grid(encoder_grid) + and is_pp_first_stage(encoder_grid.get_pg("pp")) + ) + llm_needs_data = is_rank_in_grid(llm_grid) and ( + is_pp_first_stage(llm_grid.get_pg("pp")) or is_pp_last_stage(llm_grid.get_pg("pp")) + ) + + if encoder_needs_data: + return _build_encoder_iterator(args, encoder_grid) + if llm_needs_data: + return _build_llm_iterator(args, llm_grid) + return None + + +def validate_energon_data_alignment(data_iterator, _topology) -> None: + """Check the first actual-data batch aligns across non-colocated module grids.""" + if not dist.is_initialized(): + return + + gathered = [None for _ in range(dist.get_world_size())] + dist.all_gather_object( + gathered, data_iterator.peek_alignment() if data_iterator is not None else None + ) + + encoder_signatures_by_lane = {} + llm_signatures_by_lane = {} + for candidate in gathered: + if candidate is None: + continue + target = ( + encoder_signatures_by_lane if candidate["role"] == "encoder" else llm_signatures_by_lane + ) + for lane, signature in zip(candidate["llm_lanes"], candidate["signatures"]): + target.setdefault(lane, set()).add(signature) + + mismatched = {} + for lane in sorted(set(encoder_signatures_by_lane) | set(llm_signatures_by_lane)): + encoder_values = encoder_signatures_by_lane.get(lane, set()) + llm_values = llm_signatures_by_lane.get(lane, set()) + if len(encoder_values) != 1 or len(llm_values) != 1 or encoder_values != llm_values: + mismatched[lane] = {"encoder": sorted(encoder_values), "llm": sorted(llm_values)} + if mismatched: + raise RuntimeError(f"hetero Energon data loaders diverged across grids: {mismatched}") + + +def _build_llm_iterator(args, grid): + """Build the single-lane LLM iterator for this grid coordinate.""" + tp_group = grid.get_pg("tp") + if get_grid_coordinate(grid, "tp") != 0: + lane = get_grid_coordinate(grid, "dp") + return EnergonIterator( + None, tp_group=tp_group, source_rank=False, alignment_role="llm", llm_lanes=[lane] + ) + + lane = get_grid_coordinate(grid, "dp") + return _build_single_lane_iterator( + args, tp_group=tp_group, lane=lane, role="llm", random_seed=args.seed + lane + ) + + +def _build_encoder_iterator(args, grid): + """Build the encoder iterator, composing LLM-lane samples for DP fan-out.""" + tp_group = grid.get_pg("tp") + encoder_dp_rank = get_grid_coordinate(grid, "dp") + llm_lanes = _llm_lanes_for_encoder_rank(args, encoder_dp_rank) + if get_grid_coordinate(grid, "tp") != 0: + return EnergonIterator( + None, + tp_group=tp_group, + source_rank=False, + alignment_role="encoder", + llm_lanes=llm_lanes, + ) + + if len(llm_lanes) == 1: + return _build_single_lane_iterator( + args, + tp_group=tp_group, + lane=llm_lanes[0], + role="encoder", + # energon's WorkerConfig(rank=lane, world_size=llm_dp) already + # salts per-rank, so the seed here must be unsalted. + random_seed=args.seed, + ) + + return _build_routed_encoder_iterator( + args, tp_group=tp_group, encoder_dp_rank=encoder_dp_rank, llm_lanes=llm_lanes + ) + + +def _route_samples_to_lanes( + loader_iter, + *, + lanes_per_encoder: int, + lane_offset: int, + num_workers_per_lane: int, + encoder_dp_rank: int, + pending_by_lane: list, + max_pulls_per_step: int, + provenance_key: str, +) -> tuple[list, int]: + """Pull samples from a single multiplexed loader and route each one to its LLM lane. + + Samples are routed by reading the producing worker's + ``WorkerConfig.global_worker_id()``, which the encoder batcher stamps under + ``provenance_key``. The mapping from worker id back to local lane is: + + global_worker_id = encoder_dp_rank * num_workers_enc + local_worker_id + global_llm_lane = global_worker_id // num_workers_per_lane + local_lane = global_llm_lane - lane_offset + + Surplus samples (a worker yields a second sample for a lane that's already + filled this step) are stashed in ``pending_by_lane`` and consumed on the + next encoder step. ``max_pulls_per_step`` bounds the loop so a stuck or + skewed worker pool fails loudly instead of silently stalling. + + Returns ``(lane_batches, pulls)`` where ``lane_batches[lane]`` is the sample + routed to local lane ``lane``. + """ + lane_batches: list = [None] * lanes_per_encoder + filled = 0 + for lane in range(lanes_per_encoder): + if pending_by_lane[lane]: + lane_batches[lane] = pending_by_lane[lane].popleft() + filled += 1 + pulls = 0 + while filled < lanes_per_encoder: + if pulls >= max_pulls_per_step: + missing = [i for i, b in enumerate(lane_batches) if b is None] + raise RuntimeError( + f"encoder dataloader did not yield samples for local_lanes={missing} " + f"in {max_pulls_per_step} pulls (encoder_dp_rank={encoder_dp_rank}); " + "check Energon worker rotation contract" + ) + sample = next(loader_iter) + pulls += 1 + wid = sample.pop(provenance_key, None) + if wid is None: + raise RuntimeError( + f"encoder sample missing {provenance_key!r}; " + "ensure build_multimodal_encoder was called with attach_provenance=True" + ) + global_llm_lane = wid // num_workers_per_lane + local_lane = global_llm_lane - lane_offset + if not (0 <= local_lane < lanes_per_encoder): + raise RuntimeError( + f"worker_id={wid} maps to global_llm_lane={global_llm_lane}, " + f"outside encoder rank {encoder_dp_rank} range " + f"[{lane_offset}, {lane_offset + lanes_per_encoder})" + ) + if lane_batches[local_lane] is None: + lane_batches[local_lane] = sample + filled += 1 + else: + pending_by_lane[local_lane].append(sample) + return lane_batches, pulls + + +def _build_routed_encoder_iterator(args, tp_group, encoder_dp_rank, llm_lanes): + """Build one Energon iterator per encoder rank and route samples back to LLM lanes. + + The previous implementation built ``lanes_per_encoder`` independent Energon + iterators per encoder rank — one per LLM data lane — which produces + ``lanes_per_encoder × num_workers`` shard-open events at construction. + This collapses that to a single Energon iterator with + ``num_workers = args.num_workers * lanes_per_encoder``; each emitted batch + is routed to its owning lane using the producing worker's + ``WorkerConfig.global_worker_id()`` that the encoder batcher stamps onto + every batch. + + Bit-wise sample parity with the per-lane iterator path is preserved by + Energon's design: ``global_workers = world_size * num_workers`` is invariant + under this reshape and per-worker seeds depend only on ``global_worker_id`` + and ``seed_offset`` (see ``megatron/energon/worker.py``), so each worker + here produces the same shards in the same order as the per-lane worker it + replaces. + """ + from examples.mimo.data.energon_multimodal_provider import ( + MimoMultiModalPackingEncoder, + build_multimodal_encoder, + ) + from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset + from megatron.energon.cache.no_cache import NoCachePool + + if args.num_workers < 1: + raise ValueError( + "routed encoder iterator requires args.num_workers >= 1 " + "(global_worker_id -> lane mapping divides by num_workers_per_lane); " + f"got {args.num_workers}" + ) + lanes_per_encoder = len(llm_lanes) + num_workers_per_lane = args.num_workers + num_workers_enc = num_workers_per_lane * lanes_per_encoder + lane_offset = llm_lanes[0] + + tokenizer = _build_tokenizer(args) + encoder = build_multimodal_encoder( + args, + tokenizer, + encoder_name=getattr(args, "vision_encoder_key", "radio_encoder"), + encoder_input_key="x", + attach_provenance=True, + ) + worker_config = WorkerConfig( + rank=encoder_dp_rank, + world_size=args.encoder_dp, + num_workers=num_workers_enc, + data_parallel_group=None, + ) + debug_rank( + "building routed encoder dataloader " + f"encoder_dp_rank={encoder_dp_rank} encoder_dp={args.encoder_dp} " + f"num_workers_enc={num_workers_enc} lanes_per_encoder={lanes_per_encoder} " + f"lane_offset={lane_offset}" + ) + dataset = get_train_dataset( + args.data_path, + batch_size=args.micro_batch_size, + task_encoder=encoder, + worker_config=worker_config, + packing_buffer_size=args.packing_buffer_size, + shuffle_buffer_size=args.shuffle_buffer_size, + max_samples_per_sequence=args.max_samples_per_sequence, + ) + loader = get_savable_loader( + dataset, + cache_pool=NoCachePool(), + watchdog_timeout_seconds=5 * 60, + watchdog_initial_timeout_seconds=5 * 60, + ) + + loader_iter_holder: list = [iter(loader)] + # Dense integer keys (0..lanes_per_encoder-1) → use a list so the hot-path + # routing in ``_route_samples_to_lanes`` does O(1) array indexing rather + # than dict probing. + pending_by_lane: list[deque] = [deque() for _ in range(lanes_per_encoder)] + # Energon's SavableDataLoader rotates through every worker in one round, + # so a step worst case needs ``num_workers_enc`` pulls to fill every lane + # (one batch per worker, including the surplus to lanes that filled + # early). The 4× factor adds slack for transient rotation skew; we cap + # below by 2*num_workers_enc so configurations with high + # ``num_workers_per_lane`` aren't bounded too tightly. A genuine stall + # surfaces as a loud failure in ``_route_samples_to_lanes``. + max_pulls_per_step = max(4 * lanes_per_encoder, 2 * num_workers_enc) + provenance_key = MimoMultiModalPackingEncoder.PROVENANCE_KEY + + def next_encoder_batch(): + try: + lane_batches, _pulls = _route_samples_to_lanes( + loader_iter_holder[0], + lanes_per_encoder=lanes_per_encoder, + lane_offset=lane_offset, + num_workers_per_lane=num_workers_per_lane, + encoder_dp_rank=encoder_dp_rank, + pending_by_lane=pending_by_lane, + max_pulls_per_step=max_pulls_per_step, + provenance_key=provenance_key, + ) + except StopIteration: + # One-shot per epoch on savable-loader exhaustion. Any partial + # ``lane_batches`` accumulated before the exception is dropped — + # those samples count against the worker's seed sequence and are + # never delivered. Acceptable because webdataset is streamed as + # a pseudo-infinite source; this branch is rarely hit in practice. + loader_iter_holder[0] = iter(loader) + lane_batches, _pulls = _route_samples_to_lanes( + loader_iter_holder[0], + lanes_per_encoder=lanes_per_encoder, + lane_offset=lane_offset, + num_workers_per_lane=num_workers_per_lane, + encoder_dp_rank=encoder_dp_rank, + pending_by_lane=pending_by_lane, + max_pulls_per_step=max_pulls_per_step, + provenance_key=provenance_key, + ) + signatures = [EnergonIterator._batch_signature(batch) for batch in lane_batches] + return _combine_encoder_batches(lane_batches), signatures + + return EnergonIterator( + None, + tp_group=tp_group, + source_rank=True, + random_seed=args.seed, + local_batch_fn=next_encoder_batch, + alignment_role="encoder", + llm_lanes=llm_lanes, + ) + + +def _llm_lanes_for_encoder_rank(args, encoder_dp_rank: int) -> list[int]: + """Return the contiguous LLM DP lanes owned by one encoder DP lane.""" + scale = args.llm_dp // args.encoder_dp + start = encoder_dp_rank * scale + return list(range(start, start + scale)) + + +def _build_single_lane_iterator(args, tp_group, lane: int, role: str, random_seed: int): + """Build a deterministic loader for one LLM data lane.""" + from examples.mimo.data.energon_multimodal_provider import build_multimodal_encoder + from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset + + tokenizer = _build_tokenizer(args) + encoder = build_multimodal_encoder( + args, + tokenizer, + encoder_name=getattr(args, "vision_encoder_key", "radio_encoder"), + encoder_input_key="x", + ) + worker_config = WorkerConfig( + rank=lane, world_size=args.llm_dp, num_workers=args.num_workers, data_parallel_group=None + ) + debug_rank( + "building energon dataloader " + f"role={role} lane={lane} dp_world={args.llm_dp} batch_size={args.micro_batch_size}" + ) + dataset = get_train_dataset( + args.data_path, + batch_size=args.micro_batch_size, + task_encoder=encoder, + worker_config=worker_config, + packing_buffer_size=args.packing_buffer_size, + shuffle_buffer_size=args.shuffle_buffer_size, + max_samples_per_sequence=args.max_samples_per_sequence, + ) + from megatron.energon.cache.no_cache import NoCachePool + + loader = get_savable_loader( + dataset, + cache_pool=NoCachePool(), + watchdog_timeout_seconds=5 * 60, + watchdog_initial_timeout_seconds=5 * 60, + ) + return EnergonIterator( + loader, + tp_group=tp_group, + source_rank=True, + random_seed=random_seed, + alignment_role="encoder" if role.startswith("encoder") else "llm", + llm_lanes=[lane], + ) + + +def _combine_encoder_batches(batches: list[dict]) -> dict: + """Combine LLM-lane batches into one encoder batch and drop LLM-only metadata.""" + if not batches: + raise RuntimeError("cannot combine an empty encoder batch list") + + combined = {} + for key in ("input_ids", "labels", "loss_mask", "position_ids"): + values = [batch.get(key) for batch in batches if batch.get(key) is not None] + if values: + combined[key] = torch.cat(values, dim=0) + + modality_values = [ + batch.get("modality_inputs") + for batch in batches + if batch.get("modality_inputs") is not None + ] + if modality_values: + combined["modality_inputs"] = _merge_modality_inputs(modality_values) + + return combined + + +# --------------------------------------------------------------------------- +# Schema-aware merge of ``modality_inputs`` across LLM lanes served by one +# encoder rank. The structure produced by the dataset is fixed: +# +# modality_inputs = { +# "": { # e.g. "images" +# "": { # e.g. "radio_encoder" +# : Tensor of shape (1, T_lane, C), +# "imgs_sizes": Tensor of shape (N_images_lane, 2), +# "packed_seq_params": PackedSeqParams describing the T axis, +# } +# } +# } +# +# Each per-lane tensor has a known concat semantics; we encode them +# explicitly rather than inferring from runtime shape variation: +# +# * packed image buffer: leading dim is always 1 (lane batch == MBS=1); +# dim 1 is the variable token axis -> concat along dim 1. +# * ``imgs_sizes``: dim 0 = per-lane image count -> concat along dim 0. +# * ``packed_seq_params``: cu_seqlens need offset-shifting -> custom merge. +# --------------------------------------------------------------------------- + + +def _merge_modality_inputs(per_lane_modality_inputs): + """Merge the ``modality_inputs`` field of N per-lane batches.""" + merged = {} + modality_types = set().union( + *(p.keys() for p in per_lane_modality_inputs if isinstance(p, dict)) + ) + for mod_type in sorted(modality_types): + per_lane_mod = [p[mod_type] for p in per_lane_modality_inputs if mod_type in p] + merged_per_encoder = {} + encoder_names = set().union( + *(p.keys() for p in per_lane_mod if isinstance(p, dict)) + ) + for enc_name in sorted(encoder_names): + per_lane_enc = [p[enc_name] for p in per_lane_mod if enc_name in p] + merged_per_encoder[enc_name] = _merge_encoder_inputs(per_lane_enc) + merged[mod_type] = merged_per_encoder + return merged + + +def _merge_encoder_inputs(per_lane_enc_inputs): + """Merge per-lane encoder-input dicts using a key-explicit schema. + + Keys are categorized by name / value type: + * ``packed_seq_params`` -> ``_concat_packed_seq_params`` + * ``imgs_sizes`` -> ``torch.cat(..., dim=0)`` + * any other ``Tensor`` -> packed image buffer ``(1, T, C)``, + concat along dim 1 + Anything else triggers a loud error so a future schema change has to be + handled here rather than guessed at by a heuristic. + """ + merged = {} + keys = set().union(*(p.keys() for p in per_lane_enc_inputs if isinstance(p, dict))) + for key in sorted(keys): + vals = [p[key] for p in per_lane_enc_inputs if key in p] + if not vals: + continue + first = vals[0] + if isinstance(first, PackedSeqParams): + merged[key] = _concat_packed_seq_params(vals) + elif key == "imgs_sizes": + assert all(isinstance(v, torch.Tensor) for v in vals), ( + f"imgs_sizes must be tensors, got {[type(v).__name__ for v in vals]}" + ) + merged[key] = torch.cat(vals, dim=0) + elif isinstance(first, torch.Tensor): + # Packed image buffer: leading dim is the lane batch (==1); the + # variable token axis is dim 1. + assert first.dim() >= 2 and first.shape[0] == 1, ( + f"unexpected packed-buffer shape for encoder key {key!r}: " + f"{tuple(first.shape)} (expected leading dim 1)" + ) + merged[key] = torch.cat(vals, dim=1) + else: + raise TypeError( + f"unsupported encoder-input value for key {key!r}: " + f"{type(first).__name__}; extend _merge_encoder_inputs" + ) + return merged + + +def _concat_packed_seq_params(values: list) -> PackedSeqParams: + """Merge per-lane PackedSeqParams into one set covering the merged flat buffer. + + The dim-0 image buffers from each lane are concatenated by the surrounding + tensor merge; here we re-number cu_seqlens so they index into that merged + buffer. Mirrors the offset-shift rule in + ``megatron.energon.task_encoder.multimodal.encoder``. + """ + first = values[0] + for v in values[1:]: + if v.qkv_format != first.qkv_format: + raise ValueError( + f"qkv_format mismatch across encoder lanes: " + f"{first.qkv_format!r} vs {v.qkv_format!r}" + ) + if v.local_cp_size != first.local_cp_size or v.cp_group is not first.cp_group: + raise ValueError("CP fields mismatch across encoder lanes; refusing to merge") + + def _concat_cu(attr: str): + per_lane = [getattr(v, attr) for v in values] + if per_lane[0] is None: + if not all(x is None for x in per_lane): + raise ValueError(f"{attr} present on some lanes but not others") + return None + merged = [per_lane[0]] + offset = int(per_lane[0][-1].item()) + for cu in per_lane[1:]: + merged.append(cu[1:] + offset) + offset += int(cu[-1].item()) + return torch.cat(merged) + + def _max_scalar(attr: str): + per_lane = [getattr(v, attr) for v in values] + if per_lane[0] is None: + if not all(x is None for x in per_lane): + raise ValueError(f"{attr} present on some lanes but not others") + return None + if torch.is_tensor(per_lane[0]): + return torch.stack([x.reshape(()) for x in per_lane]).max() + return max(per_lane) + + def _sum_or_none(attr: str): + per_lane = [getattr(v, attr) for v in values] + if all(x is None for x in per_lane): + return None + if any(x is None for x in per_lane): + raise ValueError(f"{attr} present on some lanes but not others") + return sum(per_lane) + + return PackedSeqParams( + qkv_format=first.qkv_format, + cu_seqlens_q=_concat_cu("cu_seqlens_q"), + cu_seqlens_kv=_concat_cu("cu_seqlens_kv"), + cu_seqlens_q_padded=_concat_cu("cu_seqlens_q_padded"), + cu_seqlens_kv_padded=_concat_cu("cu_seqlens_kv_padded"), + max_seqlen_q=_max_scalar("max_seqlen_q"), + max_seqlen_kv=_max_scalar("max_seqlen_kv"), + total_tokens=_sum_or_none("total_tokens"), + local_cp_size=first.local_cp_size, + cp_group=first.cp_group, + ) + + +def _build_tokenizer(args): + from megatron.core.tokenizers.vision.libraries.multimodal_tokenizer import ( + MegatronMultimodalTokenizer, + ) + + return MegatronMultimodalTokenizer( + path=args.tokenizer_model, + prompt_format=args.tokenizer_prompt_format, + special_tokens=[args.image_token], + image_tag_type=args.image_tag_type, + force_system_message=args.force_system_message, + ) + + +class EnergonIterator: + """Endless wrapper around an Energon dataloader with TP-rank-0 ownership.""" + + def __init__( + self, + dataloader, + tp_group=None, + source_rank: bool = True, + random_seed: Optional[int] = None, + local_batch_fn: Optional[Callable[[], dict]] = None, + alignment_role: Optional[str] = None, + llm_lanes: Optional[list[int]] = None, + ) -> None: + self._dataloader = dataloader + self._iterator = None + self._tp_group = tp_group + self._source_rank = source_rank + self._local_batch_fn = local_batch_fn + self._alignment_role = alignment_role + self._llm_lanes = llm_lanes or [] + self._prefetched = None + self._prefetched_component_signatures = None + self._local_component_signatures = None + self._python_random_state = None + if random_seed is not None: + rng = random.Random(random_seed) + self._python_random_state = rng.getstate() + + def __iter__(self): + return self + + def __next__(self): + if self._prefetched is not None: + batch = self._prefetched + self._prefetched = None + return batch + + batch = self._next_local_batch() if self._source_rank else None + component_signatures = self._current_component_signatures(batch) + if is_process_group_member(self._tp_group) and self._tp_group.size() > 1: + obj = [(batch, component_signatures)] + dist.broadcast_object_list(obj, src=self._tp_source_rank(), group=self._tp_group) + batch, component_signatures = obj[0] + self._prefetched_component_signatures = component_signatures + return batch + + def peek_alignment(self): + """Read and retain the next batch, returning lane signatures from TP source ranks.""" + if self._prefetched is None: + self._prefetched = next(self) + if not self._source_rank or self._alignment_role is None: + return None + signatures = self._prefetched_component_signatures + if signatures is None: + signatures = [self._batch_signature(self._prefetched)] + return { + "role": self._alignment_role, + "llm_lanes": self._llm_lanes, + "signatures": signatures, + } + + def _next_local_batch(self): + """Read the next local Energon batch on the TP source rank.""" + if self._python_random_state is None: + result = self._read_next_local_batch() + return self._extract_batch_and_signatures(result) + + global_random_state = random.getstate() + try: + random.setstate(self._python_random_state) + result = self._read_next_local_batch() + batch = self._extract_batch_and_signatures(result) + self._python_random_state = random.getstate() + return batch + finally: + random.setstate(global_random_state) + + def _extract_batch_and_signatures(self, result): + """Handle local batch providers that also return component signatures.""" + self._local_component_signatures = None + if isinstance(result, tuple) and len(result) == 2: + batch, signatures = result + self._local_component_signatures = signatures + return batch + return result + + def _read_next_local_batch(self): + """Read from the underlying dataloader, cycling at epoch boundaries.""" + if self._local_batch_fn is not None: + return self._local_batch_fn() + if self._iterator is None: + self._iterator = iter(self._dataloader) + try: + return next(self._iterator) + except StopIteration: + self._iterator = iter(self._dataloader) + return next(self._iterator) + + def _current_component_signatures(self, batch): + """Return per-lane signatures for the current batch if they can be inferred.""" + if batch is None: + return None + if self._local_component_signatures is not None: + return self._local_component_signatures + return [self._batch_signature(batch)] + + def _tp_source_rank(self) -> int: + """Return the global source rank for the local TP batch broadcast.""" + if hasattr(dist, "get_global_rank"): + return dist.get_global_rank(self._tp_group, 0) + return dist.get_process_group_ranks(self._tp_group)[0] + + @classmethod + def _batch_signature(cls, batch: dict) -> tuple[int, ...]: + """Return a compact signature for cross-grid data-alignment checks.""" + image_tensor = cls._nested_get(batch, ("modality_inputs", "images")) + if isinstance(image_tensor, dict): + image_tensor = cls._first_tensor(image_tensor) + packing_kwargs = batch.get("packing_kwargs") + return ( + cls._checksum_tensor(batch.get("input_ids")), + cls._checksum_tensor(batch.get("labels")), + int(batch.get("loss_mask", torch.zeros(1)).sum().item()), + 0 if image_tensor is None else int(image_tensor.shape[0]), + cls._checksum_tensor(image_tensor), + cls._checksum_packing_kwargs(packing_kwargs), + ) + + @staticmethod + def _nested_get(value: dict, keys: tuple[str, ...]): + """Return a nested dict value if every key exists.""" + current = value + for key in keys: + if not isinstance(current, dict) or key not in current: + return None + current = current[key] + return current + + @classmethod + def _first_tensor(cls, value): + """Return the first tensor inside a nested mapping.""" + if isinstance(value, torch.Tensor): + return value + if isinstance(value, dict): + for item in value.values(): + tensor = cls._first_tensor(item) + if tensor is not None: + return tensor + return None + + @classmethod + def _checksum_packing_kwargs(cls, packing_kwargs: Optional[dict]) -> int: + """Checksum packed-sequence metadata used by the language model.""" + if packing_kwargs is None: + return 0 + checksum = 0 + for key in sorted(packing_kwargs): + value = packing_kwargs[key] + if isinstance(value, torch.Tensor): + value_checksum = cls._checksum_tensor(value) + elif value is None: + value_checksum = 0 + elif isinstance(value, str): + value_checksum = sum(value.encode("utf-8")) + else: + value_checksum = int(value) + checksum = (checksum * 131 + value_checksum) % 2_147_483_647 + return checksum + + @staticmethod + def _checksum_tensor(tensor: Optional[torch.Tensor]) -> int: + """Return a stable full-tensor checksum for a CPU tensor-like batch field.""" + if tensor is None or tensor.numel() == 0: + return 0 + tensor = tensor.detach().cpu().contiguous() + digest = hashlib.blake2b(digest_size=8) + digest.update(str(tuple(tensor.shape)).encode("ascii")) + digest.update(str(tensor.dtype).encode("ascii")) + digest.update(memoryview(tensor.numpy()).cast("B")) + return int.from_bytes(digest.digest(), byteorder="big", signed=False) diff --git a/examples/mimo/data/hetero_mock.py b/examples/mimo/data/hetero_mock.py new file mode 100644 index 00000000000..137028a68e1 --- /dev/null +++ b/examples/mimo/data/hetero_mock.py @@ -0,0 +1,133 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Mock VLM data provider for heterogeneous MIMO training examples.""" + +from __future__ import annotations + +import argparse + +import torch + +from examples.mimo.utils.hetero import debug_rank + + +def validate_mock_data_args(args: argparse.Namespace) -> None: + """Validate synthetic next-token VLM data constraints.""" + image_seq_length = args.image_seq_length or args.seq_length // 2 + if image_seq_length >= args.seq_length: + raise ValueError("--image-seq-length must be smaller than --seq-length") + if args.seq_length - image_seq_length < 2: + raise ValueError("mock next-token training needs at least two text tokens") + + +class MockVLMIterator: + """Infinite iterator yielding synthetic VLM-like microbatches.""" + + def __init__( + self, args: argparse.Namespace, micro_batch_size: int, encoder_name: str, seed: int + ) -> None: + self.args = args + self.micro_batch_size = micro_batch_size + self.encoder_name = encoder_name + self.image_seq_length = args.image_seq_length or args.seq_length // 2 + self.vision_encoder_key = getattr(args, "vision_encoder_key", "clip_encoder") + self.vision_input_mode = getattr(args, "vision_input_mode", "hidden_states") + self.dtype = torch.float32 if args.fp32 else torch.bfloat16 + self.generator = torch.Generator(device="cuda") + self.generator.manual_seed(seed) + if self.image_seq_length >= args.seq_length: + raise ValueError("--image-seq-length must be smaller than --seq-length") + + def __iter__(self): + return self + + def __next__(self): + args = self.args + debug_rank( + f"mock batch start: micro_batch_size={self.micro_batch_size}, " + f"image_seq_length={self.image_seq_length}" + ) + image_tokens = torch.full( + (self.micro_batch_size, self.image_seq_length), + args.image_token_id, + dtype=torch.long, + device="cuda", + ) + text_tokens = torch.randint( + 1, + args.vocab_size, + (self.micro_batch_size, args.seq_length - self.image_seq_length), + device="cuda", + generator=self.generator, + ) + special_token_ids = {args.image_token_id, args.pad_token_id} + replacement_token_id = next( + ( + token_id + for token_id in range(1, args.vocab_size) + if token_id not in special_token_ids + ), + None, + ) + if replacement_token_id is None: + raise RuntimeError("mock data needs at least one non-special token id") + if 1 <= args.image_token_id < args.vocab_size: + text_tokens[text_tokens == args.image_token_id] = replacement_token_id + if 1 <= args.pad_token_id < args.vocab_size: + text_tokens[text_tokens == args.pad_token_id] = replacement_token_id + input_ids = torch.cat([image_tokens, text_tokens], dim=1) + + labels = torch.full_like(input_ids, -100) + labels[:, :-1] = input_ids[:, 1:] + labels[(labels == args.image_token_id) | (labels == args.pad_token_id)] = -100 + loss_mask = (labels != -100).to(dtype=torch.float32) + + if self.vision_input_mode == "pixels": + encoder_inputs = { + self.vision_encoder_key: { + "x": torch.randn( + self.micro_batch_size * args.num_image_tiles, + 3, + args.img_h, + args.img_w, + device="cuda", + dtype=self.dtype, + generator=self.generator, + ) + } + } + else: + encoder_hidden_states = torch.randn( + self.image_seq_length, + self.micro_batch_size, + args.hidden_size, + device="cuda", + dtype=self.dtype, + generator=self.generator, + ) + encoder_inputs = { + self.vision_encoder_key: { + "hidden_states": encoder_hidden_states, + "attention_mask": None, + } + } + + num_image_placeholders = (input_ids == args.image_token_id).sum().item() + expected_image_placeholders = self.image_seq_length * self.micro_batch_size + if num_image_placeholders != expected_image_placeholders: + raise RuntimeError( + f"mock batch has {num_image_placeholders} image placeholders, " + f"expected {expected_image_placeholders}" + ) + + debug_rank("mock batch ready") + return { + "input_ids": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": torch.arange(args.seq_length, device="cuda") + .unsqueeze(0) + .expand(self.micro_batch_size, -1) + .clone(), + "modality_inputs": {self.encoder_name: {**encoder_inputs}}, + } diff --git a/examples/mimo/docs/e2e_training_parity_plan.md b/examples/mimo/docs/e2e_training_parity_plan.md new file mode 100644 index 00000000000..b2f84090a08 --- /dev/null +++ b/examples/mimo/docs/e2e_training_parity_plan.md @@ -0,0 +1,72 @@ +# E2E Training Parity Plan + +This note tracks the plan for checking end-to-end training parity between the +previous `examples/mimo/train.py` flow from `feat/nemotron-moe-vlm-mimo` and the +new heterogenous `examples/mimo/train_hetero.py` flow. + +## Goal + +Verify that the new heterogenous MIMO training loop matches the previous +Megatron `pretrain()`-based flow for the Nemotron 20L VLM workflow. The strongest +parity signal is matching behavior on a frozen batch stream before comparing live +Energon training runs. + +## Plan + +1. Compare resolved training configuration. + - Dump the final args used by old `train.py`. + - Dump the final args used by new `train_hetero.py`. + - Compare behavior-relevant fields: model config, vision config, MoE config, + TP/PP/EP/ETP/EDP, batch sizes, optimizer, scheduler, seeds, loss scaling, + per-token loss, and dataloader settings. + +2. Start both runs from the same initial weights. + - Prefer a canonical initialized checkpoint or state dict over relying only on + seed-based initialization. + - Compare parameter hashes by logical module: vision encoder, LLM backbone, + MoE experts, router parameters, and projector/MIMO bridge. + +3. Validate data parity before training. + - First use a recorded frozen batch stream, not live Energon. + - Dump exact batch tensors and metadata from the old path: tokens, labels, + loss mask, position ids, modality inputs, packed sequence params, and sample + signatures if available. + - Feed the same frozen batches to the new heterogenous loop and compare batch + hashes before forward. + +4. Run forward-only parity. + - Use the same initialized weights and same frozen batch. + - Disable optimizer updates. + - Compare logits checksums where practical, unreduced loss numerator, token + denominator, normalized loss, and auxiliary/router losses. + +5. Run single-step training parity. + - Use the same frozen batch. + - Run forward, backward, optimizer step, and LR scheduler step. + - Compare loss before step, grad norm, skipped/nan flags, LR, selected + parameter deltas, and post-step parameter hashes. + +6. Run short frozen-stream loss-curve parity. + - Use a fixed stream of 10 to 20 frozen batches. + - Compare per-iteration loss, grad norm, LR, loss scale, skipped/nan counts, + consumed samples, and token counts. + +7. Run actual Energon parity. + - Run the old `train.py` flow and the new `train_hetero.py` flow against the + real Nemotron 20L Energon setup. + - Log sample signatures per global step in both paths. + - First verify that both paths consume the same samples in the same order. + - Compare loss curves only after sample order parity is established. + +## Expected Limits + +Bitwise parity may not be realistic between the old colocated Megatron +`pretrain()` path and the new non-colocated heterogenous grids because collective +ordering, parameter partitioning, and optimizer sharding can differ. The first +strict gates should therefore be configuration parity, initial-weight parity, +frozen-batch forward parity, token-count parity, LR schedule parity, and a short +frozen-batch training curve within a tight tolerance. + +The known parity gap is the old `--use-loss-scaling` path. The new heterogenous +loop uses per-token global loss normalization, but it does not yet implement the +old optional sqrt-weighted scaled loss behavior. diff --git a/examples/mimo/model_providers/nemotron_moe_vlm.py b/examples/mimo/model_providers/nemotron_moe_vlm.py new file mode 100644 index 00000000000..cc50f5a2b61 --- /dev/null +++ b/examples/mimo/model_providers/nemotron_moe_vlm.py @@ -0,0 +1,745 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Model providers and configs for MIMO VLM examples.""" + +from __future__ import annotations + +import argparse +from contextlib import nullcontext +from typing import Optional + +import torch + +from examples.mimo.utils.hetero import ( + debug_rank, + get_grid_dim_size, + get_group_rank_or, + get_group_size_or, + is_process_group_member, +) +from megatron.core.activations import fast_gelu, squared_relu +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec +from megatron.core.models.mamba.mamba_model import MambaModel +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.multimodal.llava_model import pixel_shuffle +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.models.vision.radio import RADIOViTModel +from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.enums import AttnBackend +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import sharded_state_dict_default + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TELayerNormColumnParallelLinear, + TERowParallelLinear, + ) +except ImportError: + TEColumnParallelLinear = None + TELayerNormColumnParallelLinear = None + TERowParallelLinear = None + +MOCK_MODEL_PROVIDER = "mock" +NEMOTRON_20L_MODEL_PROVIDER = "nemotron-moe-vlm-20l" +NEMOTRON_54L_MODEL_PROVIDER = "nemotron-moe-vlm-54l" +NEMOTRON_20L_IMAGE_SEQ_PER_TILE = 256 +NEMOTRON_20L_MAX_NUM_TILES = 12 +NEMOTRON_20L_DEFAULT_STAGE = "stage2" +MOCK_VISION_ENCODER_KEY = "clip_encoder" +NEMOTRON_VISION_ENCODER_KEY = "radio_encoder" + + +def is_nemotron_20l(args: argparse.Namespace) -> bool: + """Return whether the Nemotron6-MoE VLM 20L provider is active.""" + return args.model_provider == NEMOTRON_20L_MODEL_PROVIDER + + +def is_nemotron_moe_vlm(args: argparse.Namespace) -> bool: + """Return whether a Nemotron6-MoE VLM provider is active.""" + return args.model_provider in (NEMOTRON_20L_MODEL_PROVIDER, NEMOTRON_54L_MODEL_PROVIDER) + + +def add_model_provider_args(parser: argparse.ArgumentParser) -> None: + """Register model-provider arguments for hetero MIMO examples.""" + provider = parser.add_argument_group("model provider") + provider.add_argument( + "--model-provider", + choices=[MOCK_MODEL_PROVIDER, NEMOTRON_20L_MODEL_PROVIDER, NEMOTRON_54L_MODEL_PROVIDER], + default=MOCK_MODEL_PROVIDER, + ) + provider.add_argument("--hidden-size", type=int, default=128) + provider.add_argument("--num-layers", type=int, default=2) + provider.add_argument("--num-attention-heads", type=int, default=8) + provider.add_argument("--vocab-size", type=int, default=512) + provider.add_argument("--seq-length", type=int, default=32) + provider.add_argument("--image-seq-length", type=int, default=None) + provider.add_argument("--image-token-id", type=int, default=511) + provider.add_argument("--pad-token-id", type=int, default=0) + provider.add_argument("--image-token", type=str, default="") + provider.add_argument("--tokenizer-model", type=str, default=None) + provider.add_argument("--tokenizer-prompt-format", type=str, default="nemotron6-moe") + provider.add_argument("--image-tag-type", type=str, default="") + provider.add_argument("--force-system-message", action="store_true") + provider.add_argument("--num-moe-experts", type=int, default=4) + provider.add_argument("--moe-router-topk", type=int, default=1) + provider.add_argument( + "--moe-router-force-load-balancing", + action="store_true", + help="Use random router logits to force MoE load balancing for benchmark/debug runs.", + ) + provider.add_argument("--moe-grouped-gemm", action="store_true") + provider.add_argument("--img-h", type=int, default=512) + provider.add_argument("--img-w", type=int, default=512) + provider.add_argument("--patch-dim", type=int, default=16) + provider.add_argument("--class-token-len", type=int, default=8) + provider.add_argument( + "--num-image-tiles", + "--max-num-tiles", + dest="num_image_tiles", + type=int, + default=NEMOTRON_20L_MAX_NUM_TILES, + ) + provider.add_argument("--vision-model-type", type=str, default="radio") + provider.add_argument("--pixel-shuffle", action="store_true") + provider.add_argument("--disable-vision-class-token", action="store_true") + provider.add_argument("--use-tiling", action="store_true") + provider.add_argument("--use-thumbnail", action="store_true") + provider.add_argument( + "--dynamic-resolution", + action=argparse.BooleanOptionalAction, + default=None, + help=( + "Patchify each image at its native aspect ratio with a token budget instead of " + "fixed-tile resize. Enabled by default for Nemotron6-MoE VLM providers. " + "Pass --no-dynamic-resolution to disable." + ), + ) + provider.add_argument( + "--dynamic-resolution-min-patches", + type=int, + default=4, + help="Lower bound on per-image patch count under dynamic resolution.", + ) + provider.add_argument( + "--dynamic-resolution-max-patches", + type=int, + default=0, + help="Upper bound on per-image patch count under dynamic resolution; 0 = uncapped.", + ) + provider.add_argument("--freeze-lm", action="store_true") + provider.add_argument("--freeze-vit", action="store_true") + provider.add_argument("--freeze-projection", action="store_true") + provider.add_argument("--training-stage", choices=["stage1", "stage2", "stage3"], default=None) + provider.add_argument("--fp32", action="store_true") + + +def prepare_model_provider_args(args: argparse.Namespace) -> None: + """Apply provider defaults and derived tokenizer/vision settings.""" + apply_model_provider_defaults(args) + apply_training_stage(args) + resolve_image_token_id(args) + args.vision_encoder_key = get_encoder_module_name(args) + args.vision_input_mode = "pixels" if is_nemotron_moe_vlm(args) else "hidden_states" + + +def apply_model_provider_defaults(args: argparse.Namespace) -> None: + """Apply Nemotron6-MoE VLM model defaults.""" + if not is_nemotron_moe_vlm(args): + return + + args.num_layers = 54 if args.model_provider == NEMOTRON_54L_MODEL_PROVIDER else 20 + args.hidden_size = 2688 + args.num_attention_heads = 32 + args.num_moe_experts = 128 + args.moe_router_topk = 6 + args.moe_grouped_gemm = True + args.hybrid_layer_pattern = ( + "MEMEM*EMEM*EMEM*EMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEME" + if args.model_provider == NEMOTRON_54L_MODEL_PROVIDER + else "MEMEM*EMEMEM*EMEMEM*" + ) + args.seq_length = 8192 + args.image_seq_length = NEMOTRON_20L_IMAGE_SEQ_PER_TILE * args.num_image_tiles + args.pixel_shuffle = True + args.disable_vision_class_token = True + if args.dynamic_resolution is None: + args.dynamic_resolution = True + if args.dynamic_resolution: + # Dynamic-resolution strategy reads `use_thumbnail` inside + # `DynamicResolutionImageTilingStrategy` and emits an extra thumbnail + # tile when True. `use_tiling` is inert in this branch (the fixed-tile + # path is unreachable), but pin it False for args-dump parity. + args.use_tiling = False + args.use_thumbnail = False + else: + args.use_tiling = True + args.use_thumbnail = True + + +def apply_training_stage(args: argparse.Namespace) -> None: + """Apply stage-specific freeze flags for the Nemotron VLM recipe.""" + if not is_nemotron_moe_vlm(args): + return + + stage = args.training_stage or NEMOTRON_20L_DEFAULT_STAGE + if stage == "stage1": + args.freeze_vit = True + args.freeze_lm = True + elif stage == "stage2": + args.freeze_vit = True + elif stage != "stage3": + raise ValueError(f"unsupported Nemotron VLM training stage: {stage}") + args.training_stage = stage + + +def resolve_image_token_id(args: argparse.Namespace) -> None: + """Resolve image, pad, and vocab ids from the configured tokenizer.""" + if not is_nemotron_moe_vlm(args) or not args.tokenizer_model: + return + + from megatron.core.tokenizers.vision.libraries.multimodal_tokenizer import ( + MegatronMultimodalTokenizer, + ) + + tokenizer = MegatronMultimodalTokenizer( + path=args.tokenizer_model, + prompt_format=args.tokenizer_prompt_format, + special_tokens=[args.image_token], + image_tag_type=args.image_tag_type, + force_system_message=args.force_system_message, + ) + image_token_id = tokenizer.convert_tokens_to_ids(args.image_token) + if image_token_id is None: + raise RuntimeError( + f"tokenizer at {args.tokenizer_model} did not produce an id for {args.image_token}" + ) + args.image_token_id = int(image_token_id) + if tokenizer.pad is not None: + args.pad_token_id = int(tokenizer.pad) + if tokenizer.vocab_size is not None: + args.vocab_size = int(tokenizer.vocab_size) + + +def validate_model_provider_args(args: argparse.Namespace) -> None: + """Validate derived model-provider arguments.""" + if args.hidden_size % args.num_attention_heads != 0: + raise ValueError("--hidden-size must be divisible by --num-attention-heads") + if not 0 <= args.image_token_id < args.vocab_size: + raise ValueError("--image-token-id must be within --vocab-size") + if not 0 <= args.pad_token_id < args.vocab_size: + raise ValueError("--pad-token-id must be within --vocab-size") + + +def _pixel_shuffle_dynamic_res(x, imgs_sizes, patch_dim, scale_factor=0.5, version=2): + """Pixel shuffle for dynamic resolution (variable tile sizes). + + Splits the packed sequence by per-tile lengths, applies pixel shuffle to each + tile, then re-concatenates. Mirrors sasatheesh/pre-vlm-05's + llava_model.pixel_shuffle_dynamic_res; vendored here to avoid touching the + upstream-owned llava_model.py. + """ + seq_lens = torch.prod(imgs_sizes // patch_dim, dim=-1) + splits = torch.split(x, seq_lens.tolist(), dim=-2) + + out = [] + for i, sv in enumerate(splits): + h = imgs_sizes[i][0] // patch_dim + w = imgs_sizes[i][1] // patch_dim + sv = sv.reshape(sv.shape[0], h, w, -1) + + n, h, w, c = sv.size() + sv = sv.view(n, h, int(w * scale_factor), int(c / scale_factor)) + sv = sv.permute(0, 2, 1, 3).contiguous() + sv = sv.view( + n, + int(w * scale_factor), + int(h * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + + if version == 2: + sv = sv.permute(0, 2, 1, 3).contiguous() + + sv = sv.reshape(sv.shape[0], -1, sv.shape[-1]) + out.append(sv) + + return torch.cat(out, dim=-2) + + +class RADIOEncoderWrapper(torch.nn.Module): + """RADIO encoder wrapper matching the Nemotron6-MoE VLM provider.""" + + def __init__( + self, + transformer_config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + pg_collection: Optional[ProcessGroupCollection], + img_h: int, + img_w: int, + patch_dim: int, + class_token_len: int, + drop_class_token: bool = True, + apply_pixel_shuffle: bool = True, + force_eval_mode: bool = False, + dynamic_resolution: bool = False, + ) -> None: + super().__init__() + self.class_token_len = class_token_len + self.drop_class_token = drop_class_token + self.apply_pixel_shuffle = apply_pixel_shuffle + self.force_eval_mode = force_eval_mode + self.dynamic_resolution = dynamic_resolution + self.radio_model = RADIOViTModel( + transformer_config=transformer_config, + transformer_layer_spec=transformer_layer_spec, + patch_dim=patch_dim, + img_h=img_h, + img_w=img_w, + class_token_len=class_token_len, + add_class_token=True, + max_img_h=2048, + max_img_w=2048, + has_cpe=True, + embedder_bias=False, + dynamic_resolution=dynamic_resolution, + pg_collection=pg_collection, + ) + if self.force_eval_mode: + self.radio_model.eval() + + def train(self, mode: bool = True): + """Keep frozen RADIO in eval mode while allowing the projection to train.""" + super().train(mode) + if self.force_eval_mode: + self.radio_model.eval() + return self + + @property + def config(self): + """Expose the underlying RADIO config for DDP wrapping.""" + return self.radio_model.config + + def forward( + self, + x: torch.Tensor, + imgs_sizes: Optional[torch.Tensor] = None, + packed_seq_params=None, + ) -> torch.Tensor: + """Run RADIO, drop class tokens, and apply pixel shuffle.""" + context = torch.no_grad() if self.force_eval_mode else nullcontext() + debug_rank(f"RADIO forward start: input_shape={tuple(x.shape)}") + with context: + x = x.to(dtype=self.radio_model.embedder.weight.dtype) + embeddings = self.radio_model( + x, imgs_sizes=imgs_sizes, packed_seq_params=packed_seq_params + ) + debug_rank(f"RADIO forward done: output_shape={tuple(embeddings.shape)}") + if self.drop_class_token: + if ( + self.dynamic_resolution + and imgs_sizes is not None + and self.class_token_len > 0 + ): + # Class tokens are interleaved between tiles; build mask to remove them. + remove_mask = torch.full( + (embeddings.shape[-2],), True, dtype=torch.bool, device=embeddings.device + ) + patch_dim = self.radio_model.patch_dim + if torch.is_tensor(imgs_sizes): + seq_lens = torch.prod(imgs_sizes // patch_dim, dim=-1) + else: + seq_lens = torch.tensor( + [(h // patch_dim) * (w // patch_dim) for h, w in imgs_sizes] + ) + current_length = 0 + for sl in seq_lens: + remove_mask[current_length : current_length + self.class_token_len] = False + current_length += int(sl) + self.class_token_len + embeddings = embeddings[:, remove_mask, :] + else: + embeddings = embeddings[:, self.class_token_len :, :] + debug_rank(f"RADIO class tokens dropped: output_shape={tuple(embeddings.shape)}") + if self.apply_pixel_shuffle: + if self.dynamic_resolution and imgs_sizes is not None: + embeddings = _pixel_shuffle_dynamic_res( + embeddings, imgs_sizes, self.radio_model.patch_dim + ) + else: + embeddings = pixel_shuffle(embeddings, scale_factor=0.5) + debug_rank(f"RADIO pixel shuffle done: output_shape={tuple(embeddings.shape)}") + return embeddings + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + """Delegate checkpoint sharding to the wrapped RADIO model.""" + sharded_sd = {} + for name, child in self.named_children(): + sharded_sd.update( + sharded_state_dict_default(child, f"{prefix}{name}.", sharded_offsets, metadata) + ) + return sharded_sd + + +def get_encoder_module_name(args: argparse.Namespace) -> str: + """Return the concrete encoder key for the active vision provider.""" + return NEMOTRON_VISION_ENCODER_KEY if is_nemotron_moe_vlm(args) else MOCK_VISION_ENCODER_KEY + + +def get_vision_encoder_module(args: argparse.Namespace, vision_submodule): + """Return the provider-owned encoder module used for DDP config and freezing.""" + return vision_submodule.encoders[get_encoder_module_name(args)] + + +def iter_vision_projection_modules(vision_submodule): + """Return the provider-owned projection modules used for freeze-stage policy.""" + return iter(vision_submodule.input_projections) + + +def projection_layer_spec() -> ModuleSpec: + """Return the TE-backed projection MLP spec.""" + if TEColumnParallelLinear is None or TERowParallelLinear is None: + raise RuntimeError("TEColumnParallelLinear and TERowParallelLinear are required") + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules(linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear), + ) + + +def nemotron_projection_layer_spec() -> ModuleSpec: + """Return the Nemotron VLM RADIO-to-language projector layer spec.""" + if TELayerNormColumnParallelLinear is None or TERowParallelLinear is None: + raise RuntimeError("TELayerNormColumnParallelLinear and TERowParallelLinear are required") + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ) + + +def nemotron_language_config( + args: argparse.Namespace, tp_size: int, pp_size: int, ep_size: int, expt_tp_size: int +) -> TransformerConfig: + """Build the Nemotron6-MoE language TransformerConfig.""" + bf16 = not args.fp32 + dtype = torch.bfloat16 if bf16 else torch.float32 + config = TransformerConfig( + num_layers=54 if args.model_provider == NEMOTRON_54L_MODEL_PROVIDER else 20, + hidden_size=2688, + num_attention_heads=32, + attention_backend=AttnBackend.flash, + num_query_groups=8, + ffn_hidden_size=1856, + kv_channels=128, + activation_func=squared_relu, + gated_linear_unit=False, + attention_dropout=0.0, + hidden_dropout=0.0, + normalization="RMSNorm", + add_bias_linear=False, + init_method_std=0.0173, + use_cpu_initialization=True, + variable_seq_lengths=True, + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=expt_tp_size, + sequence_parallel=tp_size > 1, + params_dtype=dtype, + pipeline_dtype=dtype, + bf16=bf16, + calculate_per_token_loss=True, + cross_entropy_loss_fusion=True, + cross_entropy_fusion_impl="te", + bias_activation_fusion=False, + masked_softmax_fusion=True, + persist_layer_norm=True, + bias_dropout_fusion=True, + recompute_granularity="selective", + recompute_modules=["core_attn"], + moe_ffn_hidden_size=1856, + num_moe_experts=128, + moe_router_topk=6, + moe_grouped_gemm=True, + moe_router_score_function="sigmoid", + moe_router_topk_scaling_factor=2.5, + moe_router_enable_expert_bias=True, + moe_router_dtype="fp32", + moe_router_load_balancing_type="seq_aux_loss", + moe_router_force_load_balancing=args.moe_router_force_load_balancing, + moe_router_fusion=False, + moe_aux_loss_coeff=1.0e-4, + moe_shared_expert_intermediate_size=3712, + moe_shared_expert_overlap=True, + moe_token_dispatcher_type="alltoall", + moe_permute_fusion=True, + use_fused_weighted_squared_relu=True, + is_hybrid_model=True, + mamba_num_heads=64, + mamba_head_dim=64, + mamba_num_groups=8, + mamba_state_dim=128, + linear_conv_kernel_dim=4, + ) + config.position_embedding_type = "none" + config.seq_length = 8192 + config.max_position_embeddings = 8192 + return config + + +def require_per_token_loss(config: TransformerConfig) -> None: + """The hetero MIMO loop scales both language and vision grads by real LM tokens.""" + if not config.calculate_per_token_loss: + raise ValueError("train_hetero.py requires calculate_per_token_loss=True") + + +def radio_vision_config(args: argparse.Namespace, tp_size: int, pp_size: int) -> TransformerConfig: + """Build the exact RADIO vision TransformerConfig from the 20L reference provider.""" + bf16 = not args.fp32 + dtype = torch.bfloat16 if bf16 else torch.float32 + config = TransformerConfig( + num_layers=32, + hidden_size=1280, + num_attention_heads=16, + use_cpu_initialization=True, + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + params_dtype=dtype, + pipeline_dtype=dtype, + bf16=bf16, + ) + config.kv_channels = 80 + config.num_query_groups = 16 + config.ffn_hidden_size = 5120 + config.gated_linear_unit = False + config.activation_func = fast_gelu + config.add_bias_linear = True + config.add_qkv_bias = True + config.normalization = "LayerNorm" + config.layernorm_epsilon = 1.0e-6 + config.layernorm_zero_centered_gamma = False + config.apply_rope_fusion = False + config.qk_layernorm = False + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.attention_softmax_in_fp32 = True + config.attention_dropout = 0.0 + config.hidden_dropout = 0.0 + # Trigger TransformerBlock's final_layernorm allocation (matches sanj path). + config.mtp_num_layers = 0 + return config + + +def nemotron_projection_config(args: argparse.Namespace, tp_size: int) -> TransformerConfig: + """Build the exact RADIO-to-Nemotron projection config.""" + bf16 = not args.fp32 + dtype = torch.bfloat16 if bf16 else torch.float32 + config = TransformerConfig( + num_layers=1, + hidden_size=2688, + num_attention_heads=1, + use_cpu_initialization=True, + params_dtype=dtype, + pipeline_dtype=dtype, + bf16=bf16, + ) + config.tensor_model_parallel_size = tp_size + config.ffn_hidden_size = 4 * 5120 + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.add_bias_linear = False + config.activation_func = squared_relu + config.normalization = "RMSNorm" + return config + + +def language_model_spec( + args: argparse.Namespace, + pg_collection: Optional[ProcessGroupCollection], + llm_grid: HyperCommGrid, +) -> ModuleSpec: + """Create the language ModuleSpec for the local language grid.""" + pp_pg = getattr(pg_collection, "pp", None) if pg_collection is not None else None + tp_pg = getattr(pg_collection, "tp", None) if pg_collection is not None else None + ep_pg = getattr(pg_collection, "ep", None) if pg_collection is not None else None + expt_tp_pg = getattr(pg_collection, "expt_tp", None) if pg_collection is not None else None + + fallback_tp_size = get_grid_dim_size(llm_grid, "tp") + pp_rank = get_group_rank_or(pp_pg) + pp_size = get_group_size_or(pp_pg, get_grid_dim_size(llm_grid, "pp")) + tp_size = get_group_size_or(tp_pg, fallback_tp_size) + ep_size = get_group_size_or(ep_pg, args.llm_ep) + expt_tp_size = get_group_size_or(expt_tp_pg, args.llm_expt_tp or fallback_tp_size) + if is_nemotron_moe_vlm(args): + config = nemotron_language_config(args, tp_size, pp_size, ep_size, expt_tp_size) + require_per_token_loss(config) + return ModuleSpec( + module=MambaModel, + params={ + "config": config, + "mamba_stack_spec": mamba_stack_spec, + "vocab_size": args.vocab_size, + "max_sequence_length": args.seq_length, + "pre_process": pp_rank == 0, + "post_process": pp_rank == pp_size - 1, + "hybrid_layer_pattern": args.hybrid_layer_pattern, + "position_embedding_type": "none", + "share_embeddings_and_output_weights": False, + "scatter_embedding_sequence_parallel": False, + "pg_collection": pg_collection, + }, + ) + + num_moe_experts = args.num_moe_experts if args.num_moe_experts > 0 else None + bf16 = not args.fp32 + moe_kwargs = {} + if num_moe_experts is not None: + moe_kwargs = { + "num_moe_experts": num_moe_experts, + "moe_router_topk": args.moe_router_topk, + "moe_router_pre_softmax": args.moe_router_topk == 1, + "expert_model_parallel_size": ep_size, + "expert_tensor_parallel_size": expt_tp_size, + "moe_grouped_gemm": args.moe_grouped_gemm, + } + + config = TransformerConfig( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + use_cpu_initialization=True, + variable_seq_lengths=True, + moe_token_dispatcher_type="alltoall", + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + pipeline_dtype=torch.bfloat16 if bf16 else torch.float32, + bf16=bf16, + calculate_per_token_loss=True, + cross_entropy_loss_fusion=True, + cross_entropy_fusion_impl="te", + **moe_kwargs, + ) + require_per_token_loss(config) + return ModuleSpec( + module=GPTModel, + params={ + "config": config, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec( + num_experts=num_moe_experts, moe_grouped_gemm=args.moe_grouped_gemm + ), + "vocab_size": args.vocab_size, + "max_sequence_length": args.seq_length, + "pre_process": pp_rank == 0, + "post_process": pp_rank == pp_size - 1, + "pg_collection": pg_collection, + }, + ) + + +def vision_submodules_spec( + args: argparse.Namespace, + pg_collection: Optional[ProcessGroupCollection], + encoder_grid: HyperCommGrid, +) -> ModuleSpec: + """Create the vision ModuleSpec for the local encoder grid.""" + from megatron.core.transformer.transformer_block import TransformerBlock + + pp_pg = getattr(pg_collection, "pp", None) if pg_collection is not None else None + tp_pg = getattr(pg_collection, "tp", None) if pg_collection is not None else None + tp_size = get_group_size_or(tp_pg, get_grid_dim_size(encoder_grid, "tp")) + pp_size = get_group_size_or(pp_pg, get_grid_dim_size(encoder_grid, "pp")) + pp_rank = get_group_rank_or(pp_pg) + bf16 = not args.fp32 + + if is_nemotron_moe_vlm(args): + vision_config = radio_vision_config(args, tp_size, pp_size) + vision_encoder_spec = ModuleSpec( + module=RADIOEncoderWrapper, + params={ + "transformer_config": vision_config, + "transformer_layer_spec": get_vit_layer_with_transformer_engine_spec(), + "pg_collection": pg_collection, + "img_h": args.img_h, + "img_w": args.img_w, + "patch_dim": args.patch_dim, + "class_token_len": args.class_token_len, + "drop_class_token": True, + "apply_pixel_shuffle": True, + "force_eval_mode": args.freeze_vit, + "dynamic_resolution": bool(getattr(args, "dynamic_resolution", False)), + }, + ) + vision_projection_spec = ModuleSpec( + module=MultimodalProjector, + params={ + "config": nemotron_projection_config(args, tp_size), + "submodules": nemotron_projection_layer_spec().submodules, + "projector_type": "mlp", + "input_size": 5120, + "tp_group": tp_pg if is_process_group_member(tp_pg) else None, + }, + ) + return ModuleSpec( + module=VisionModalitySubmodules, + params={"pg_collection": pg_collection}, + submodules={ + "encoders": {NEMOTRON_VISION_ENCODER_KEY: vision_encoder_spec}, + "input_projections": [vision_projection_spec], + }, + ) + + vision_config = TransformerConfig( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + use_cpu_initialization=True, + variable_seq_lengths=True, + moe_token_dispatcher_type="alltoall", + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + pipeline_dtype=torch.bfloat16 if bf16 else torch.float32, + bf16=bf16, + calculate_per_token_loss=True, + ) + vision_encoder_spec = ModuleSpec( + module=TransformerBlock, + params={ + "config": vision_config, + "spec": get_gpt_layer_with_transformer_engine_spec(), + "pg_collection": pg_collection, + "pre_process": pp_rank == 0, + "post_process": pp_rank == pp_size - 1, + }, + ) + + projection_config = TransformerConfig( + num_layers=1, hidden_size=args.hidden_size, num_attention_heads=1 + ) + projection_config.ffn_hidden_size = args.hidden_size + projection_config.activation_func = torch.nn.functional.gelu + + vision_projection_spec = ModuleSpec( + module=MultimodalProjector, + params={ + "config": projection_config, + "submodules": projection_layer_spec().submodules, + "projector_type": "mlp", + "input_size": vision_config.hidden_size, + "tp_group": tp_pg if is_process_group_member(tp_pg) else None, + }, + ) + + return ModuleSpec( + module=VisionModalitySubmodules, + params={"pg_collection": pg_collection}, + submodules={ + "encoders": {MOCK_VISION_ENCODER_KEY: vision_encoder_spec}, + "input_projections": [vision_projection_spec], + }, + ) diff --git a/examples/mimo/scripts/README.md b/examples/mimo/scripts/README.md new file mode 100644 index 00000000000..3b8ef84e91c --- /dev/null +++ b/examples/mimo/scripts/README.md @@ -0,0 +1,12 @@ +# MIMO hetero training sbatches + +| Script | Nodes | Layout | GBS | Purpose | +|---|---|---|---|---| +| sbatch_hetero_parity_gbs192.sh | 9 | 1 enc + 8 LLM, TP=2 EP=16 | 192 | 9n Sanjeev parity test (5000 iters, paired with sbatch_sanjeev_parity_gbs192.sh) | +| sbatch_hetero_prod_gbs768_33n_ep8.sh | 33 | 1 enc + 32 LLM, TP=2 EP=8 | 768 | 33n production | +| sbatch_hetero_prod_gbs768_68n_ep8.sh | 68 | 4 enc + 64 LLM, TP=2 EP=8 | 768 | 68n production | +| sbatch_hetero_prod_gbs768_100n.sh | 100 | 4 enc + 96 LLM, TP=2 EP=8 | 768 | 100n production | + +Production sbatches use Sanjeev's WSD schedule (`TRAIN_SAMPLES=122070313`, `LR_WARMUP_SAMPLES=1024000`, `LR_WSD_DECAY_SAMPLES=18310547`) with EP=8 (vs Sanjeev's EP=16), no MTP, no force-LB. Load LLM weights via `--load-nemotron-checkpoint` from sasatheesh's `iter_0001000`. + +Launch: `sbatch examples/mimo/scripts/