[qwen3_5] evolve qwen3_vl to qwen3_5 #3371
| head_dims: int, | ||
| seq_len: int, | ||
| *, | ||
| num_full_attn: int | None = None, |
There was a problem hiding this comment.
can compute this from model_config right?
|
|
||
| End-to-end KL divergence against HuggingFace Transformers (4B, multimodal inputs): **~3e-7** average, with **100% top-1 and top-5 match**. | ||
|
|
||
| Parallelism correctness: bitwise identical logits across no-parallel, FSDP, FSDP+EP, FSDP+EP+TP, and FSDP+EP+TP+CP configs. |
There was a problem hiding this comment.
hmm, how could this be true? Different parallelisms have different reductions
There was a problem hiding this comment.
you are right. what the script checks is only that the logits are numerically near-identical, not bitwise identical.
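The point about reductions comes down to floating-point addition not being associative: splitting the same values across ranks differently changes the summation order and can change the last bits of the result. Below is a minimal pure-Python sketch of that effect. The helper `reduce_in_groups` and the two "layouts" are hypothetical stand-ins for per-rank partial sums, not code from this PR.

```python
import math


def reduce_in_groups(groups):
    """Sum each group locally, then sum the group results (a toy all-reduce)."""
    return sum(sum(group) for group in groups)


# The same three values, partitioned across two simulated ranks in two ways.
layout_a = reduce_in_groups([[0.1, 0.2], [0.3]])
layout_b = reduce_in_groups([[0.1], [0.2, 0.3]])

# Bitwise comparison fails, tolerance-based comparison passes.
bitwise_equal = layout_a == layout_b
numerically_close = math.isclose(layout_a, layout_b, rel_tol=1e-12)
```

This is why a tolerance-based comparison is the realistic check across parallelism configs, while a claim of bitwise identity would need matching reduction orders.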
| mesh, plc = x.device_mesh, x.placements | ||
| w = self.weight | ||
| if isinstance(w, DTensor): | ||
| w = w.to_local() |
There was a problem hiding this comment.
With spmd_types, hopefully we don't need to do this manual conversion.
For now, let's do to_local in the module, similar to GroupedExperts, and use LocalMapConfig to convert inputs, instead of patching forward.
There was a problem hiding this comment.
sounds great, refactored to the style used in groupedexperts.
| F.interpolate's decomposition uses _unsafe_index which doesn't support | ||
| DTensor. Since pos_embed is Replicate, to_local is a no-op for data. | ||
|
|
||
| TODO: Remove once F.interpolate on FSDP2-managed DTensors is fixed upstream. |
There was a problem hiding this comment.
If this can be fixed soon, let's wait.
| ) | ||
| edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) | ||
|
|
||
| apply_fsdp( |
There was a problem hiding this comment.
do we not need things like _apply_fsdp_to_vision_encoder any more?
There was a problem hiding this comment.
This was previously handled in fully_shard(model, **fsdp_config), but as you said we should separate it. Apply fsdp to vision encoder and treat vit as a single unit.
| class Config(Module.Config): | ||
| layer_type: str # "full_attn" or "linear_attn" | ||
| attention: Qwen35Attention.Config | None = None | ||
| deltanet: GatedDeltaNet.Config | None = None |
There was a problem hiding this comment.
| deltanet: GatedDeltaNet.Config | None = None | |
| delta_net: GatedDeltaNet.Config | None = None |
| if self.moe_enabled: | ||
| moe_out = self.moe(h) | ||
| if self.shared_expert_enabled: | ||
| shared_out = torch.sigmoid(self.shared_gate(h)) * self.shared_ffn(h) |
There was a problem hiding this comment.
instead of doing this, can we extend https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/common/config_utils.py#L153 and use existing shared_expert inside MoE module?
There was a problem hiding this comment.
extended in common/config_utils.py. currently only qwen3_5 uses this sigmoid gate, but it is a simple extension that can be reused later.
|
|
||
| @dataclass(kw_only=True, slots=True) | ||
| class Config(Module.Config): | ||
| layer_type: str # "full_attn" or "linear_attn" |
There was a problem hiding this comment.
don't really need this? The config can be built so that a block has either attention or deltanet. Refer to how feed_forward vs. moe is selected.
There was a problem hiding this comment.
good point, this is redundant since attention and delta_net already indicate this.
| for block in model.layers.values() # pyrefly: ignore [not-callable] | ||
| if block.layer_type == "full_attn" # pyrefly: ignore [missing-attribute] | ||
| ] | ||
| if full_attn_inner_modules: |
There was a problem hiding this comment.
I didn't see an "else" here -- how are you handling sharded activation on linear attention layers?
There was a problem hiding this comment.
We used Replicate() for that. but as discussed in a previous thread, the current cp is inefficient and defeats the purpose of supporting it. cp is removed for now.
| # runs inside the local_map boundary on local tensors. | ||
| # Applies to full attention layers only — GatedDeltaNet is recurrent | ||
| # and allgathers the full sequence via cp=Replicate() in sharding. | ||
| if parallel_dims.cp_enabled: |
There was a problem hiding this comment.
Since CP is non-trivial, let's just raise NotImplementedError
https://www.internalfb.com/metamate/M4978C
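A fail-fast guard of the kind requested here could look like the sketch below. The function name `validate_parallelism` and the message text are hypothetical. In the PR the flag comes from `parallel_dims.cp_enabled`.

```python
def validate_parallelism(*, cp_enabled: bool) -> None:
    """Fail fast for parallelism modes the model does not support yet."""
    if cp_enabled:
        raise NotImplementedError(
            "Context parallelism is not supported for this model yet."
        )
```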
shuhuayu
left a comment
There was a problem hiding this comment.
Left a TODO on conv1d waiting for dtensor support in pytorch/pytorch#186129
| self.kernel = GatedDeltaKernel.Config(backend=config.fla_backend).build() | ||
|
|
||
| self.norm = RMSNormGated.Config( | ||
| dim=config.value_head_dim, | ||
| eps=config.norm_eps, | ||
| param_init=config.norm_init, | ||
| ).build() | ||
| self.out_proj = Linear.Config( | ||
| in_features=value_dim, | ||
| out_features=config.dim, | ||
| bias=False, | ||
| param_init=config.out_proj_init, | ||
| ).build() |
There was a problem hiding this comment.
makes sense, submodule configs are moved to module.config.
|
|
||
| LayerNorm = Module.from_nn_module(nn.LayerNorm) | ||
| GELU = Module.from_nn_module(nn.GELU) | ||
|
|
||
| _compiled_create_block_mask = torch.compile(create_block_mask) | ||
|
|
||
|
|
||
| def get_vision_block_mask_mod(num_patch: torch.Tensor, max_num_patch: int): |
There was a problem hiding this comment.
yes, this was a bug.
| wq: Linear.Config, | ||
| wk: Linear.Config, | ||
| wv: Linear.Config, | ||
| proj: Linear.Config, |
| self.norm1 = LayerNorm(dim, eps=layer_norm_eps) | ||
| self.norm2 = LayerNorm(dim, eps=layer_norm_eps) | ||
| self.attn = VisionAttention(dim, n_heads, qkv=attn_qkv, proj=attn_proj) | ||
| self.attn = VisionAttention( |
| router: TokenChoiceTopKRouter.Config | ||
| load_balance_coeff: float | None = 1e-3 | ||
| shared_experts: FeedForward.Config | None = None | ||
| shared_expert_gate: Module.Config | None = None |
There was a problem hiding this comment.
| shared_expert_gate: Module.Config | None = None | |
| shared_experts_gate: Module.Config | None = None |
There was a problem hiding this comment.
more accurate. the hf keys remain unchanged as shared_expert_gate.
| enable_ep=enable_ep, enable_sp=enable_sp | ||
| ) | ||
|
|
||
| if getattr(moe_cfg, "shared_expert_gate", None) is not None: |
There was a problem hiding this comment.
why do we need getattr? It seems always existing (could be None)
There was a problem hiding this comment.
indeed. this one and a pre-existing getattr are removed.
| if self.shared_expert_gate is not None: | ||
| shared_out_BLD = ( | ||
| torch.sigmoid(self.shared_expert_gate(x_BLD)) * shared_out_BLD | ||
| ) |
There was a problem hiding this comment.
What's the behavior under TP?
We used to assume on TP mesh shared_out_BLD is Partial, now there will be more collectives??
If TP is not supposed to be used (DP, EP only) as it's not efficient, then in sharding annotation, don't annotate / support TP.
There was a problem hiding this comment.
you are right, when tp is on and shared experts are used, DTensor does not know we have already gathered from Shard(1) for the experts computation itself, so it gathers twice and wastes one collective. I redesigned the shared_experts module, which now inherits from FeedForward.
There was a problem hiding this comment.
actually when tp is on, there are two duplicated all-gathers for w1 and w3, which seem unnecessary to me. i rewrote it so that one all-gather serves all three: w1, w3, and the optional gate.
| def set_deltanet_conv1d_sharding(deltanet_module) -> None: | ||
| """Set sharding on GatedDeltaNet sub-modules built inline. | ||
|
|
||
| Conv1d modules don't have Config fields, so their sharding must be |
There was a problem hiding this comment.
Could you do similar things as in https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/common/nn_modules.py
| class _Conv1d(nn.Conv1d, Module): | ||
| pass |
There was a problem hiding this comment.
| except ImportError: | ||
| _HAS_FLA = False |
There was a problem hiding this comment.
I think it doesn't make sense to run this model with FLA. Let's put this in model specific requirements.txt and in CI.
There was a problem hiding this comment.
maybe you are saying it doesn't make sense to run it without FLA? added the dependency in .ci/docker/requirements-vlm.txt.
There was a problem hiding this comment.
since it's in the requirements, can we remove such check, or put the raise here -- if one wants to run qwen3_5, they need to install fla, regardless of if they intend to use native impl or not
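One way to replace the `_HAS_FLA` flag with an unconditional requirement is to raise at import time with an actionable message. This is a minimal sketch using only the standard library. `require_package` is a hypothetical helper, and the import name `fla` in the usage comment is an assumption about how the FLA package is imported.

```python
import importlib.util


def require_package(module_name: str, install_hint: str) -> None:
    """Raise a clear ImportError if top-level `module_name` is not installed."""
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"`{module_name}` is required to run this model. {install_hint}"
        )


# Hypothetical usage at module import time (import name assumed):
# require_package("fla", "See .ci/docker/requirements-vlm.txt.")
```

Raising early keeps the failure close to its cause, instead of surfacing later as a missing kernel when the native implementation is not selected.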
| return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) | ||
|
|
||
|
|
||
| def _torch_naive_gated_delta( |
| if isinstance(w, DTensor): | ||
| w = w.to_local() | ||
| local_groups = w.size(0) | ||
| # pyrefly: ignore [no-matching-overload] | ||
| out = F.conv1d( | ||
| x.to_local(), | ||
| w, | ||
| None, | ||
| conv.stride, | ||
| conv.padding, | ||
| conv.dilation, | ||
| local_groups, | ||
| ) | ||
| x = DTensor.from_local(out, mesh, plc, run_check=False) |
There was a problem hiding this comment.
use local_map, not to_local / from_local
specify gradient placement
| l.attention | ||
| for l in self.model_config.layers | ||
| if getattr(l, "attention", None) is not None |
There was a problem hiding this comment.
can be simplified... to just use getattr
| l.attention | ||
| for l in self.layers | ||
| if getattr(l, "attention", None) is not None |
There was a problem hiding this comment.
given how frequently this is used, we probably should create a property in Decoder config to compute this.
tianyu-l
left a comment
There was a problem hiding this comment.
not sure how popular the shared experts gate would be, so would like to stay conservative
| ) | ||
|
|
||
|
|
||
| class SharedExperts(FeedForward): |
There was a problem hiding this comment.
Given the gate thing is very much qwen3_5 specific, I would put this in qwen3_5 folder for now, and all other models still use FeedForward.
| router: TokenChoiceTopKRouter.Config | ||
| load_balance_coeff: float | None = 1e-3 | ||
| shared_experts: FeedForward.Config | None = None | ||
| shared_experts: SharedExperts.Config | None = None |
There was a problem hiding this comment.
and since it inherits FeedForward, we can keep it unchanged.
| non_blocking_capacity_factor=non_blocking_capacity_factor, | ||
| ), | ||
| shared_experts=make_ffn_config( | ||
| shared_experts=make_shared_experts_config( |
There was a problem hiding this comment.
only do this to qwen3_5 shared experts
| _REPLICATE_PARAM = dense_param_placement(tp=Replicate()) | ||
| _REPLICATE_STATE = ShardingConfig( | ||
| state_shardings={"weight": _REPLICATE_PARAM, "bias": _REPLICATE_PARAM} | ||
| ) | ||
| _REPLICATE_ACT = dense_activation_placement(tp=Replicate()) |
There was a problem hiding this comment.
not sure if we should share reference among all usages
| out = super().forward(x) | ||
| if self.gate is not None: | ||
| # TODO: make the gate activation configurable (e.g. softmax, silu) | ||
| out = torch.sigmoid(self.gate(x)) * out |
There was a problem hiding this comment.
self.gate is Replicate
x is sharded
self.gate(x) is sharded -> replicate
out is partial -> final outcome is Partial
sounds correct.
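The placement reasoning above relies on multiplication by a replicated value distributing over a pending sum: gating each rank's partial output locally and then reducing gives the same result as reducing first and gating once. Here is a toy scalar check in pure Python. The two "ranks" and their values are made up for illustration, and no DTensor API is involved.

```python
import math


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


# Two simulated TP ranks each hold a partial sum of the shared-expert output.
partial_out = [1.25, -0.5]  # the true output is their sum
gate_logit = 0.3            # gate(x): the same (replicated) value on every rank

g = sigmoid(gate_logit)

# Gate locally on each rank, then reduce across ranks.
gated_then_reduced = sum(g * p for p in partial_out)

# Reduce across ranks first, then gate the full output once.
reduced_then_gated = g * sum(partial_out)

same_result = math.isclose(gated_then_reduced, reduced_then_gated, rel_tol=1e-12)
```

Because the two orders agree, the gated output can stay Partial and be reduced once downstream, with no extra collective for the gate.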
| in_src_shardings={"x": dense_activation_placement(tp=shared_input_layout)}, | ||
| in_dst_shardings={"x": dense_activation_placement(tp=Replicate())}, |
There was a problem hiding this comment.
This is worth fixing even if we split up FeedForward and SharedExperts
|
Thanks for all the comments/suggestions! Some updates: 1) refactored |
| extra_kwargs: dict[str, Any] = {} | ||
|
|
||
| positions = extra_inputs.pop("positions", None) | ||
| mrope_positions = extra_inputs.pop("mrope_positions", None) |
There was a problem hiding this comment.
No, this is model detail, shouldn't be exposed in trainer.
| raise ValueError("Decoder config does not define RoPE max_seq_len.") | ||
|
|
||
| @property | ||
| def first_attn_config(self) -> BaseAttention.Config | None: |
There was a problem hiding this comment.
| def first_attn_config(self) -> BaseAttention.Config | None: | |
| def first_attention(self) -> BaseAttention.Config | None: |
| raise ValueError( | ||
| "No layer with attention config found for TP validation." | ||
| ) |
There was a problem hiding this comment.
why raise, no Attention means all-good?
| assert ( | ||
| attn_config is not None | ||
| ), "get_attention_masks requires an attention layer" |
There was a problem hiding this comment.
similar, no attention sounds fine? E.g. some single pipeline stage only has DeltaNet module
| logger.info("Applied fully_shard to the Qwen3.5 model") | ||
|
|
||
| if training.enable_cpu_offload: | ||
| logger.info("Applied CPU Offloading to the Qwen3.5 model") |
| ) | ||
| # Vision encoder lives on the first stage alongside tok_embeddings | ||
| if hasattr(model, "vision_encoder") and model.vision_encoder is not None: | ||
| fqn_per_part[0].insert(0, "vision_encoder") |
There was a problem hiding this comment.
not sure how heavy the vision_encoder is, maybe worth investigating if we should adjust parallelism.pipeline_parallel_first_stage_less_layers later
There was a problem hiding this comment.
added comments to reflect this forward-looking point.
| config, | ||
| **kwargs, | ||
| ) -> None: | ||
| Decoder.Config.update_from_config(self, config=config, **kwargs) |
| global_valid_tokens, | ||
| params, | ||
| extra_inputs, | ||
| {}, |
| # maskless backend (e.g. the SDPA config used by the graph_trainer | ||
| # tests) still receives positions for RoPE but no masks — it relies on | ||
| # is_causal instead. | ||
| if isinstance(self.model_config, Decoder.Config) and positions is not None: |
There was a problem hiding this comment.
positions is not None lost
| (e.g. positions, attention_masks), forwarded to every | ||
| pipeline-parallel stage. | ||
| """ | ||
| inputs = input_dict["input"] |
There was a problem hiding this comment.
inputs and labels are really not special and IMO not worth special handling, except for how labels is involved in loss computation. Can delay the general refactor to later.
thanks for the careful and sharp reviews! let's merge it to avoid more refactors and iterate later on bugs/features.
Qwen3.5 supersedes Qwen3-VL with a hybrid attention architecture: 75% GatedDeltaNet (linear attention) + 25% full attention with output gating and partial RoPE.
Model changes:
ColwiseParallel/RowwiseParallel)
Parallelisms: fsdp, tp+sp, ep, pp, verified identical logits via numerical tests (scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py).
Numerical parity: kl ~3e-7 against hf models (4b, multimodal) and 100% top-1/top-5 match (scripts/checkpoint_conversion/numerical_tests_qwen3_5.py).
Many thanks to @gali-leilei for initiating the effort of enabling qwen3.5 decoder in torchtitan in #2545; some components are reused in this pr.
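For readers unfamiliar with the parity metrics quoted in the description, the sketch below computes a per-position KL divergence and a top-k match between a reference and a test logit vector in pure Python. It is not the PR's test script. The logit values are made up, and treating "top-k match" as equality of the top-k index sets is an assumption.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(ref_logits, test_logits):
    """KL(p_ref || p_test) for a single token position."""
    p, q = softmax(ref_logits), softmax(test_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def top_k_ids(logits, k):
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return set(order[:k])


ref = [2.0, 0.5, -1.0, 0.1]         # made-up reference logits
test = [2.0001, 0.4999, -1.0, 0.1]  # made-up logits with a small perturbation

kl = kl_divergence(ref, test)
top1_match = top_k_ids(ref, 1) == top_k_ids(test, 1)
top3_match = top_k_ids(ref, 3) == top_k_ids(test, 3)
```

Averaging `kl` over all positions and reporting the fraction of positions where the top-k sets agree yields summary numbers of the kind quoted above.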