[qwen3_5] evolve qwen3_vl to qwen3_5 #3371

Merged
shuhuayu merged 7 commits into pytorch:main from shuhuayu:modeldev on Jun 10, 2026

@shuhuayu commented May 15, 2026

Contributor

Qwen3.5 supersedes Qwen3-VL with a hybrid attention architecture: 75% GatedDeltaNet (linear attention) + 25% full attention with output gating and partial RoPE.

Model changes:

  • Hybrid decoder with GatedDeltaNet and Qwen35Attention
  • Head-sharded TP on GatedDeltaNet projections (ColwiseParallel/RowwiseParallel)
  • OffsetRMSNorm, RMSNormGated, MoE with shared expert
  • Removed DeepStack
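
The 75%/25% split described above can be pictured as a repeating per-layer pattern. The sketch below is illustrative only: it assumes that every fourth decoder layer uses full attention, and both the helper name `hybrid_layer_types` and the `full_attn_interval` parameter are hypothetical rather than taken from this PR.

```python
def hybrid_layer_types(n_layers: int, full_attn_interval: int = 4) -> list[str]:
    """Return a per-layer schedule of "linear_attn" / "full_attn" labels.

    Assumption (not confirmed by the PR): layers come in groups of
    `full_attn_interval`, and the last layer of each group uses full attention.
    """
    if n_layers <= 0 or full_attn_interval <= 0:
        raise ValueError("n_layers and full_attn_interval must be positive")
    return [
        "full_attn" if (i + 1) % full_attn_interval == 0 else "linear_attn"
        for i in range(n_layers)
    ]


# With an interval of 4, one layer in four is full attention (25%).
schedule = hybrid_layer_types(8)
full_fraction = schedule.count("full_attn") / len(schedule)
```

With this pattern the fraction of full-attention layers is exactly 1/4 whenever the layer count is a multiple of the interval.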

Parallelisms: FSDP, TP+SP, EP, and PP, verified to produce identical logits via numerical tests (scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py).

Numerical parity: KL divergence ~3e-7 against HF models (4B, multimodal) and 100% top-1/top-5 match (scripts/checkpoint_conversion/numerical_tests_qwen3_5.py).
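
For readers unfamiliar with this kind of parity check, here is a minimal pure-Python sketch of a per-token KL divergence and a top-k match between a reference and a test model. It is not the script referenced above: the helper names are hypothetical, the logits are made up, and the definition of "top-k match" (the reference argmax appears in the test top-k) is an assumption.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_divergence(ref_logits: list[float], test_logits: list[float]) -> float:
    """KL(ref || test) for one token position, computed from raw logits."""
    ref_lp = log_softmax(ref_logits)
    test_lp = log_softmax(test_logits)
    return sum(math.exp(r) * (r - t) for r, t in zip(ref_lp, test_lp))


def top_k(logits: list[float], k: int) -> list[int]:
    """Indices of the k largest logits, highest first."""
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]


def top_k_match(ref_logits: list[float], test_logits: list[float], k: int) -> bool:
    """Assumed definition: the reference argmax appears in the test top-k."""
    return top_k(ref_logits, 1)[0] in top_k(test_logits, k)


# Made-up logits: the test model differs slightly on one entry.
ref = [2.0, 1.0, 0.1, -1.0]
test = [2.0001, 1.0, 0.1, -1.0]
kl = kl_divergence(ref, test)
top1 = top_k_match(ref, test, k=1)
```

In a real comparison these quantities would be averaged over all token positions of the evaluation inputs.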

Many thanks to @gali-leilei for initiating the effort of enabling the qwen3.5 decoder in torchtitan in #2545; some components are reused in this PR.

@meta-cla bot added the CLA Signed label on May 15, 2026
@shuhuayu force-pushed the modeldev branch 2 times, most recently from ee4a27a to af12fc7 on May 15, 2026 21:25
@shuhuayu force-pushed the modeldev branch 2 times, most recently from e8fb20e to 6c60af1 on May 15, 2026 21:34
Comment thread torchtitan/models/utils.py Outdated
head_dims: int,
seq_len: int,
*,
num_full_attn: int | None = None,

Contributor

can compute this from model_config right?

Contributor Author

correct, removed.

Contributor

remove

Comment thread torchtitan/models/qwen3_5/README.md Outdated

End-to-end KL divergence against HuggingFace Transformers (4B, multimodal inputs): **~3e-7** average, with **100% top-1 and top-5 match**.

Parallelism correctness: bitwise identical logits across no-parallel, FSDP, FSDP+EP, FSDP+EP+TP, and FSDP+EP+TP+CP configs.

Contributor

hmm, how could this be true? Different parallelisms have different reductions

Contributor Author

you are right. what the script checked was only numerically near-identical, not bitwise identical.
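
The reviewer's point about reductions can be illustrated without any distributed setup. Floating-point addition is not associative, so summing the same partial results in a different order can change the last bits. The values below are made up and the snippet is only a sketch of the effect, not of the PR's test script.

```python
import math

# Partial sums that a reduction would combine (made-up values).
partials = [0.1, 0.2, 0.3]

# Two reduction orders, as different parallel layouts might produce.
left_to_right = (partials[0] + partials[1]) + partials[2]
right_to_left = partials[0] + (partials[1] + partials[2])

# The two results differ in the last bit but agree within a tolerance.
bitwise_equal = left_to_right == right_to_left
close_enough = math.isclose(left_to_right, right_to_left, rel_tol=1e-12)
```

This is why comparisons across parallelism configurations are usually stated with a tolerance rather than as bitwise equality.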

Comment thread torchtitan/models/qwen3_5/sharding.py Outdated
mesh, plc = x.device_mesh, x.placements
w = self.weight
if isinstance(w, DTensor):
    w = w.to_local()

Contributor

With spmd_types, hopefully we don't need to do this manual conversion.

For now, let's do to_local in the module, similar to GroupedExperts, and use LocalMapConfig to convert inputs, instead of patching forward.

Contributor Author

sounds great, refactored to the style used in GroupedExperts.

Comment thread torchtitan/models/qwen3_5/sharding.py Outdated
F.interpolate's decomposition uses _unsafe_index which doesn't support
DTensor. Since pos_embed is Replicate, to_local is a no-op for data.

TODO: Remove once F.interpolate on FSDP2-managed DTensors is fixed upstream.

Contributor

If this can be fixed soon, let's wait.

)
edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names)

apply_fsdp(

Contributor

do we not need things like _apply_fsdp_to_vision_encoder any more?

Contributor Author

This was previously handled in fully_shard(model, **fsdp_config), but as you said we should separate it. FSDP is now applied to the vision encoder separately, treating the ViT as a single unit.

Comment thread torchtitan/models/qwen3_5/model.py Outdated
class Config(Module.Config):
    layer_type: str  # "full_attn" or "linear_attn"
    attention: Qwen35Attention.Config | None = None
    deltanet: GatedDeltaNet.Config | None = None

Contributor

Suggested change
- deltanet: GatedDeltaNet.Config | None = None
+ delta_net: GatedDeltaNet.Config | None = None

Contributor Author

applied.

Comment thread torchtitan/models/qwen3_5/model.py Outdated
if self.moe_enabled:
moe_out = self.moe(h)
if self.shared_expert_enabled:
shared_out = torch.sigmoid(self.shared_gate(h)) * self.shared_ffn(h)

Contributor

instead of doing this, can we extend https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/common/config_utils.py#L153 and use existing shared_expert inside MoE module?

Contributor Author

extended in common/config_utils.py. Currently only qwen3_5 uses this sigmoid gate, but it is a simple extension that can be used later.

Comment thread torchtitan/models/qwen3_5/model.py Outdated

@dataclass(kw_only=True, slots=True)
class Config(Module.Config):
    layer_type: str  # "full_attn" or "linear_attn"

Contributor

don't really need this? The config can be built so that a block has either attention or deltanet. Refer to how feed_forward vs. moe is selected.

Contributor Author

good point, this is redundant since attention and delta_net already indicate this.

for block in model.layers.values() # pyrefly: ignore [not-callable]
if block.layer_type == "full_attn" # pyrefly: ignore [missing-attribute]
]
if full_attn_inner_modules:

Contributor

I didn't see an "else" here -- how are you handling sharded activation on linear attention layers?

Contributor Author

We used Replicate() for that. But as discussed in a previous thread, the current CP is inefficient and defeats the purpose of supporting it. CP is removed for now.

# runs inside the local_map boundary on local tensors.
# Applies to full attention layers only — GatedDeltaNet is recurrent
# and allgathers the full sequence via cp=Replicate() in sharding.
if parallel_dims.cp_enabled:

Contributor

Since CP is non-trivial, let's just raise NotImplementedError
https://www.internalfb.com/metamate/M4978C

Contributor Author

agreed.
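
A fail-fast guard of the kind suggested here might look like the sketch below. The function name `check_context_parallel` and the error message are hypothetical; only the idea of raising `NotImplementedError` when CP is enabled comes from the thread.

```python
def check_context_parallel(cp_enabled: bool) -> None:
    """Fail fast when context parallelism is requested (hypothetical helper)."""
    if cp_enabled:
        raise NotImplementedError(
            "Context parallelism is not supported for this hybrid model yet."
        )


# With CP disabled, the check is a no-op.
check_context_parallel(False)
```

Raising early keeps an unsupported configuration from running with silently inefficient behavior.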

@felipemello1 changed the title from "[qwen3_5] evovle qwen3_vl to qwen3_5" to "[qwen3_5] evolve qwen3_vl to qwen3_5" on May 16, 2026

@shuhuayu left a comment

Contributor Author

Left a TODO on conv1d waiting for dtensor support in pytorch/pytorch#186129

Comment thread torchtitan/models/qwen3_5/model.py Outdated
Comment on lines +314 to +326
self.kernel = GatedDeltaKernel.Config(backend=config.fla_backend).build()

self.norm = RMSNormGated.Config(
dim=config.value_head_dim,
eps=config.norm_eps,
param_init=config.norm_init,
).build()
self.out_proj = Linear.Config(
in_features=value_dim,
out_features=config.dim,
bias=False,
param_init=config.out_proj_init,
).build()

Contributor Author

makes sense, submodule configs are moved to module.config.


LayerNorm = Module.from_nn_module(nn.LayerNorm)
GELU = Module.from_nn_module(nn.GELU)

_compiled_create_block_mask = torch.compile(create_block_mask)


def get_vision_block_mask_mod(num_patch: torch.Tensor, max_num_patch: int):

Contributor Author

yes, this was a bug.

Comment on lines +306 to +309
wq: Linear.Config,
wk: Linear.Config,
wv: Linear.Config,
proj: Linear.Config,

Contributor Author

refactored.

Comment on lines +372 to +374
self.norm1 = LayerNorm(dim, eps=layer_norm_eps)
self.norm2 = LayerNorm(dim, eps=layer_norm_eps)
- self.attn = VisionAttention(dim, n_heads, qkv=attn_qkv, proj=attn_proj)
+ self.attn = VisionAttention(

Contributor Author

refactored.

@shuhuayu force-pushed the modeldev branch 2 times, most recently from a0f6aed to d16d9e8 on June 4, 2026 00:19
Comment thread torchtitan/models/common/moe.py Outdated
router: TokenChoiceTopKRouter.Config
load_balance_coeff: float | None = 1e-3
shared_experts: FeedForward.Config | None = None
shared_expert_gate: Module.Config | None = None

Contributor

Suggested change
- shared_expert_gate: Module.Config | None = None
+ shared_experts_gate: Module.Config | None = None

Contributor Author

more accurate. the hf keys remain unchanged as shared_expert_gate.

enable_ep=enable_ep, enable_sp=enable_sp
)

if getattr(moe_cfg, "shared_expert_gate", None) is not None:

Contributor

why do we need getattr? It seems always existing (could be None)

Contributor Author

indeed. this one and a pre-existing getattr are removed.

Comment thread torchtitan/models/common/moe.py Outdated
Comment on lines +456 to +459
if self.shared_expert_gate is not None:
    shared_out_BLD = (
        torch.sigmoid(self.shared_expert_gate(x_BLD)) * shared_out_BLD
    )

Contributor

What's the behavior under TP?
We used to assume on TP mesh shared_out_BLD is Partial, now there will be more collectives??
If TP is not supposed to be used (DP, EP only) as it's not efficient, then in sharding annotation, don't annotate / support TP.

Contributor Author

you are right, when TP is on and shared experts are used, DTensor does not know we have already gathered from Shard(1) for the experts computation itself, so it gathers twice and wastes one collective. I redesigned the shared_experts module, which now inherits from FeedForward.

Contributor Author

actually, when TP is on, there are two duplicated all-gathers for w1 and w3, which seem unnecessary to me. I rewrote it so that one all-gather serves all three: w1, w3, and the optional gate.

Comment thread torchtitan/models/qwen3_5/sharding.py Outdated
def set_deltanet_conv1d_sharding(deltanet_module) -> None:
"""Set sharding on GatedDeltaNet sub-modules built inline.

Conv1d modules don't have Config fields, so their sharding must be

Contributor

Comment thread torchtitan/models/qwen3_5/model.py Outdated
Comment on lines +27 to +28
class _Conv1d(nn.Conv1d, Module):
pass

Contributor

Comment thread torchtitan/models/qwen3_5/model.py Outdated
Comment on lines +38 to +39
except ImportError:
_HAS_FLA = False

Contributor

I think it doesn't make sense to run this model with FLA. Let's put this in model specific requirements.txt and in CI.

Contributor Author

maybe you are saying it doesn't make sense to run it without FLA? added the dependency in .ci/docker/requirements-vlm.txt.

Contributor

since it's in the requirements, can we remove such check, or put the raise here -- if one wants to run qwen3_5, they need to install fla, regardless of if they intend to use native impl or not
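
The second option mentioned here, raising instead of recording a `_HAS_FLA` flag, could be sketched as follows. The helper `require_module` and its `install_hint` argument are hypothetical and not part of torchtitan; the snippet demonstrates the pattern with a standard-library module so that it runs anywhere.

```python
import importlib
from types import ModuleType


def require_module(name: str, install_hint: str) -> ModuleType:
    """Import `name`, or raise an ImportError that explains how to fix it."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(
            f"'{name}' is required to run this model. {install_hint}"
        ) from exc


# Demonstration with a module that is always available.
math_module = require_module("math", "It ships with Python.")
```

In the model file this would be called once at import time with the FLA module name, so a missing dependency fails immediately with an actionable message rather than later inside a kernel call.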

Comment thread torchtitan/models/qwen3_5/model.py Outdated
return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)


def _torch_naive_gated_delta(

Contributor

maybe native, not naive

Comment thread torchtitan/models/qwen3_5/model.py Outdated
Comment on lines +302 to +315
if isinstance(w, DTensor):
    w = w.to_local()
local_groups = w.size(0)
# pyrefly: ignore [no-matching-overload]
out = F.conv1d(
    x.to_local(),
    w,
    None,
    conv.stride,
    conv.padding,
    conv.dilation,
    local_groups,
)
x = DTensor.from_local(out, mesh, plc, run_check=False)

Contributor

use local_map, not to_local / from_local
specify gradient placement

Comment thread torchtitan/trainer.py Outdated
Comment on lines +604 to +606
l.attention
for l in self.model_config.layers
if getattr(l, "attention", None) is not None

Contributor

can be simplified... to just use getattr

Comment thread torchtitan/models/common/decoder.py Outdated
Comment on lines +114 to +116
l.attention
for l in self.layers
if getattr(l, "attention", None) is not None

Contributor

same

Contributor

given how frequently this is used, we probably should create a property in Decoder config to compute this.

@shuhuayu force-pushed the modeldev branch 3 times, most recently from 2adc33e to 2c48e0d on June 5, 2026 05:19

@tianyu-l left a comment

Contributor

not sure how popular the shared experts gate would be, so would like to stay conservative

Comment thread torchtitan/models/common/moe.py Outdated
)


class SharedExperts(FeedForward):

Contributor

Given the gate thing is very much qwen3_5 specific, I would put this in qwen3_5 folder for now, and all other models still use FeedForward.

Comment thread torchtitan/models/common/moe.py Outdated
router: TokenChoiceTopKRouter.Config
load_balance_coeff: float | None = 1e-3
- shared_experts: FeedForward.Config | None = None
+ shared_experts: SharedExperts.Config | None = None

Contributor

and since it inherits FeedForward, we can keep it unchanged.

non_blocking_capacity_factor=non_blocking_capacity_factor,
),
- shared_experts=make_ffn_config(
+ shared_experts=make_shared_experts_config(

Contributor

only do this to qwen3_5 shared experts

Comment thread torchtitan/models/qwen3_5/sharding.py Outdated
Comment on lines +42 to +46
_REPLICATE_PARAM = dense_param_placement(tp=Replicate())
_REPLICATE_STATE = ShardingConfig(
state_shardings={"weight": _REPLICATE_PARAM, "bias": _REPLICATE_PARAM}
)
_REPLICATE_ACT = dense_activation_placement(tp=Replicate())

Contributor

not sure if we should share reference among all usages
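
The concern about sharing one object across all usages can be illustrated with plain Python dictionaries. This is only a sketch of the aliasing hazard: the strings stand in for placement objects, the factory name `make_replicate_state` is hypothetical, and whether the real sharding objects are ever mutated is not stated in the thread.

```python
def make_replicate_state() -> dict[str, str]:
    """Build a fresh mapping on every call (hypothetical factory)."""
    return {"weight": "Replicate", "bias": "Replicate"}


# Shared reference: both users point at the same object.
shared = make_replicate_state()
user_a_shared, user_b_shared = shared, shared
user_a_shared["bias"] = "Shard(0)"      # a change intended for user A only
leaked_to_b = user_b_shared["bias"]     # user B now sees it as well

# Separate objects: each user gets its own copy from the factory.
user_a_own, user_b_own = make_replicate_state(), make_replicate_state()
user_a_own["bias"] = "Shard(0)"
isolated_b = user_b_own["bias"]         # user B is unaffected
```

If the shared objects are treated as immutable everywhere, sharing is harmless; a factory is the defensive choice when that cannot be guaranteed.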

Comment thread torchtitan/models/common/moe.py Outdated
out = super().forward(x)
if self.gate is not None:
    # TODO: make the gate activation configurable (e.g. softmax, silu)
    out = torch.sigmoid(self.gate(x)) * out

Contributor

self.gate is Replicate
x is sharded
self.gate(x) is sharded -> replicate
out is partial -> final outcome is Partial

sounds correct.
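
The placement reasoning above relies on multiplication distributing over the pending sum: scaling each rank's partial output by the same replicated gate value and reducing afterwards gives the same result as reducing first. The sketch below simulates two TP ranks with made-up scalars; it is an arithmetic illustration under that assumption, not DTensor code.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


# Per-rank partial outputs of a row-parallel projection (made-up numbers).
# The true output is their sum, which is what a Partial placement represents.
partial_out = [0.75, -0.25]  # rank 0, rank 1
full_out = sum(partial_out)

# Once the gate output is replicated, every rank holds the same gate value.
gate = sigmoid(0.4)

# Each rank scales its own partial; the later reduction sums the results.
gated_partials = [gate * p for p in partial_out]
reduced = sum(gated_partials)

# Reference: reduce first, then apply the gate.
reference = gate * full_out
```

Because the two orders agree up to rounding, the gated output can stay Partial and be reduced once, with no extra collective for the multiplication itself.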

Comment on lines +227 to +228
in_src_shardings={"x": dense_activation_placement(tp=shared_input_layout)},
in_dst_shardings={"x": dense_activation_placement(tp=Replicate())},

Contributor

This is worth fixing even if we split up FeedForward and SharedExperts

@shuhuayu force-pushed the modeldev branch 2 times, most recently from 6ca2b22 to 9b68f71 on June 9, 2026 09:23
@shuhuayu commented Jun 9, 2026

Contributor Author

Thanks for all the comments/suggestions! Some updates: 1) refactored mrope for per-layer rope, moved its position building into the dataloader, and added mrope_positions to the trainer's input_dict in the extra_kwargs part for all pp stages. 2) refactored to spmd types. 3) refactored SharedExperts so it now lives only in qwen3_5. 4) redid the numerical tests, which still passed.

@shuhuayu force-pushed the modeldev branch 2 times, most recently from 8cc31e7 to 72c923e on June 9, 2026 17:03
Comment thread torchtitan/trainer.py Outdated
extra_kwargs: dict[str, Any] = {}

positions = extra_inputs.pop("positions", None)
mrope_positions = extra_inputs.pop("mrope_positions", None)

Contributor

No, this is model detail, shouldn't be exposed in trainer.

Comment thread torchtitan/models/common/decoder.py Outdated
raise ValueError("Decoder config does not define RoPE max_seq_len.")

@property
def first_attn_config(self) -> BaseAttention.Config | None:

Contributor

Suggested change
- def first_attn_config(self) -> BaseAttention.Config | None:
+ def first_attention(self) -> BaseAttention.Config | None:

Comment thread torchtitan/models/common/decoder.py Outdated
Comment on lines +145 to +147
raise ValueError(
    "No layer with attention config found for TP validation."
)

Contributor

why raise, no Attention means all-good?

Comment thread torchtitan/models/common/decoder.py Outdated
Comment on lines +312 to +314
assert (
    attn_config is not None
), "get_attention_masks requires an attention layer"

Contributor

similar, no attention sounds fine? E.g. some single pipeline stage only has DeltaNet module

Comment on lines +163 to +166
logger.info("Applied fully_shard to the Qwen3.5 model")

if training.enable_cpu_offload:
logger.info("Applied CPU Offloading to the Qwen3.5 model")

Contributor

remove

)
# Vision encoder lives on the first stage alongside tok_embeddings
if hasattr(model, "vision_encoder") and model.vision_encoder is not None:
fqn_per_part[0].insert(0, "vision_encoder")

Contributor

not sure how heavy vision_encoder is; it may be worth investigating later whether we should adjust parallelism.pipeline_parallel_first_stage_less_layers

Contributor Author

added comments to reflect this forward-looking point.
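
To make the trade-off concrete, the sketch below shows one possible way a "first stage gets fewer layers" setting could rebalance decoder layers when the first stage also hosts the vision encoder. The function `split_layers` and its splitting rule are assumptions made for illustration; they do not describe how torchtitan actually interprets `pipeline_parallel_first_stage_less_layers`.

```python
def split_layers(
    n_layers: int, n_stages: int, first_stage_less_layers: int = 0
) -> list[int]:
    """Return the number of decoder layers per pipeline stage (hypothetical)."""
    if n_stages <= 0:
        raise ValueError("n_stages must be positive")
    # Pretend the first stage already holds `first_stage_less_layers` layers
    # (standing in for the vision encoder), split evenly, then remove them.
    padded = n_layers + first_stage_less_layers
    base, extra = divmod(padded, n_stages)
    counts = [base + (1 if stage < extra else 0) for stage in range(n_stages)]
    counts[0] -= first_stage_less_layers
    if counts[0] < 0:
        raise ValueError("first stage would receive a negative number of layers")
    return counts


even_split = split_layers(32, 4)
lighter_first = split_layers(32, 4, first_stage_less_layers=2)
```

How many layers the vision encoder is "worth" would have to be measured, which is the investigation the reviewer suggests deferring.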

config,
**kwargs,
) -> None:
Decoder.Config.update_from_config(self, config=config, **kwargs)

Contributor

move this to bottom after #3595, whichever lands first @wwwjn

global_valid_tokens,
params,
extra_inputs,
{},

Contributor

maybe kill this field

Comment thread torchtitan/trainer.py
# maskless backend (e.g. the SDPA config used by the graph_trainer
# tests) still receives positions for RoPE but no masks — it relies on
# is_causal instead.
if isinstance(self.model_config, Decoder.Config) and positions is not None:

Contributor

positions is not None lost

Comment thread torchtitan/trainer.py Outdated

Contributor

don't need this line

Comment thread torchtitan/trainer.py
(e.g. positions, attention_masks), forwarded to every
pipeline-parallel stage.
"""
inputs = input_dict["input"]

Contributor

inputs and labels are really not special and IMO not worth special handling, except for how labels is involved in loss computation. Can delay the general refactor to later.

Comment thread torchtitan/components/validate.py Outdated

Contributor

same

@shuhuayu

Contributor Author

thanks for the careful and sharp reviews! let's merge it to avoid more refactors and iterate later for bugs/features.

@shuhuayu merged commit fd712e8 into pytorch:main on Jun 10, 2026
25 of 26 checks passed
Labels: ciflow/rl, ciflow/8gpu, CLA Signed