Skip to content

[rl] Register customized config parser to vllm + less vllm config dependency#3242

Merged
wwwjn merged 12 commits into
mainfrom
gh/wwwjn/20/head
May 11, 2026
Merged

[rl] Register customized config parser to vllm + less vllm config dependency#3242
wwwjn merged 12 commits into
mainfrom
gh/wwwjn/20/head

Conversation

@wwwjn

@wwwjn wwwjn commented May 6, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this:

  • get rid of dependency on a HF format checkpoint folder when initializing. Don't implicitly depend on config.json as config source of truth

Another changes in this PR:

  • remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config. Using closure to bypass.

[ghstack-poisoned]
@wwwjn wwwjn requested review from fegin, tianyu-l and wconstab as code owners May 6, 2026 19:07
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 6, 2026
@wwwjn wwwjn changed the title config parser [rl] Register customized config parser to vllm May 6, 2026
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
Comment thread torchtitan/hf_datasets/text_datasets.py
Comment thread torchtitan/experiments/rl/models/vllm_config_parser.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. It serves as 2 purpose:
1. get rid of dependency on a HF format checkpoint folder when initializing 
2. Passing customized args to VLLMModelWrapper, eg CompileConfig, skip_init_load_weights



[ghstack-poisoned]
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_wrapper.py Outdated
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_registry.py
wwwjn added 4 commits May 6, 2026 13:33
…vllm"


vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. It serves as 2 purpose:
1. get rid of dependency on a HF format checkpoint folder when initializing 
2. Passing customized args to VLLMModelWrapper, eg CompileConfig, skip_init_load_weights



[ghstack-poisoned]
vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. It serves as 2 purpose:
1. get rid of dependency on a HF format checkpoint folder when initializing 
2. Passing customized args to VLLMModelWrapper, eg CompileConfig, skip_init_load_weights



[ghstack-poisoned]
…vllm"


vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. It serves as 2 purpose:
1. get rid of dependency on a HF format checkpoint folder when initializing 
2. Passing customized args to VLLMModelWrapper, eg CompileConfig, skip_init_load_weights



[ghstack-poisoned]
vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. It serves as 2 purpose:
1. get rid of dependency on a HF format checkpoint folder when initializing 
2. Passing customized args to VLLMModelWrapper, eg CompileConfig, skip_init_load_weights



[ghstack-poisoned]
@pytorch-bot pytorch-bot Bot added the ciflow/rl label May 7, 2026
@wwwjn wwwjn changed the title [rl] Register customized config parser to vllm [rl] Register customized config parser to vllm + less vllm config dependency May 7, 2026
wwwjn added 2 commits May 8, 2026 15:55
…vllm + less vllm config dependency"


vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this:
- get rid of dependency on a HF format checkpoint folder when initializing. Don't implicitly depend on `config.json` as config source of truth

Another changes in this PR:
- remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config. Using closure to bypass.


[ghstack-poisoned]
… config dependency"


vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this:
- get rid of dependency on a HF format checkpoint folder when initializing. Don't implicitly depend on `config.json` as config source of truth

Another changes in this PR:
- remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config. Using closure to bypass.


[ghstack-poisoned]
Comment thread torchtitan/distributed/utils.py
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
from torchtitan.config.configs import (
from torchtitan.config import (

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is just consolidate the import path

Comment thread torchtitan/experiments/rl/actors/generator.py
@@ -199,14 +214,17 @@ def __init__(
engine_kwargs = dict(
model=model_path,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this path for?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it serves 2 purpose:

  1. Loading tokenizer. This can be removed by passing tokenizer=tokenizer_path to EngineArgs.
    2.Initial_weight_loading: Will sort out the weight loading part for both trainer and generator in next PR.

After lifting both, we can pass some fake path, say "torchtitan", to vllm


assert vllm_config is not None, "vllm_config is required"

# PP and CP are not supported on this inference path

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this "raise ValueError" may better happen at grpo trainer post_init, to be consistent

here we only need assert

Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
**kwargs,
):
config_dict = model_spec_to_hf_config_dict(model_spec)
return config_dict, PretrainedConfig.from_dict(config_dict)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually very weird that the contract is both a config_dict and a cls(config_dict), sounds redundant to me

Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
wwwjn added 2 commits May 9, 2026 20:43
…vllm + less vllm config dependency"


vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this:
- get rid of dependency on a HF format checkpoint folder when initializing. Don't implicitly depend on `config.json` as config source of truth

Another changes in this PR:
- remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config. Using closure to bypass.


[ghstack-poisoned]
… config dependency"


vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this:
- get rid of dependency on a HF format checkpoint folder when initializing. Don't implicitly depend on `config.json` as config source of truth

Another changes in this PR:
- remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config. Using closure to bypass.


[ghstack-poisoned]
Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
Comment thread torchtitan/experiments/rl/grpo.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_wrapper.py Outdated
wwwjn added 2 commits May 11, 2026 08:27
…vllm + less vllm config dependency"


vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this:
- get rid of dependency on a HF format checkpoint folder when initializing. Don't implicitly depend on `config.json` as config source of truth

Another changes in this PR:
- remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config. Using closure to bypass.


[ghstack-poisoned]
… config dependency"


vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this:
- get rid of dependency on a HF format checkpoint folder when initializing. Don't implicitly depend on `config.json` as config source of truth

Another changes in this PR:
- remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config. Using closure to bypass.


[ghstack-poisoned]
@wwwjn wwwjn changed the base branch from gh/wwwjn/20/base to main May 11, 2026 15:49
@wwwjn wwwjn merged commit ca4c7f2 into main May 11, 2026
10 of 11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/rl ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants