Add Qwen 3.5 and Nemotron 3 (nemotron_h) model support#48
SolshineCode wants to merge 2 commits into arcee-ai:main
Conversation
Extends DAM to support two new model architectures.

Qwen 3.5 (model_type: qwen3_5/qwen3_5_text -> mergedqwen3_5):
- Hybrid attention architecture with full (softmax) and linear (Gated Delta Network) attention layers, dispatched per-layer via layer_types
- Full attention uses output gating, QK-normalization, and partial RoPE
- Linear attention uses causal conv1d and recurrent delta rule updates
- RMSNorm uses the (1 + weight) formulation via a custom Qwen3_5DAMRMSNorm
- SwiGLU MLP with all projections as DAMLinearLayer

Nemotron-H (model_type: nemotron_h -> mergednemotron_h):
- Three-block-type hybrid: Mamba-2 SSM, standard GQA attention, and MoE
- Each block has a single mixer (not attention+MLP like standard transformers)
- Mamba-2 with a pure PyTorch recurrent forward (no mamba_ssm dependency)
- MoE with sigmoid routing, top-k expert selection, and a shared expert
- Non-gated MLP experts with ReLU-squared activation

Both implementations follow the existing DAM patterns:
- DAMLinearLayer for all nn.Linear projections
- DAMEmbeddingLayer for token embeddings (conditional on dam_embedding_layer)
- DAMRMSNorm for normalization (conditional on dam_layernorms)
- Proper tie_weights support for weight-tied embeddings
- Full ForCausalLM with prepare_inputs_for_generation

Updated files:
- dam/merge.py: fix_config() handles the new model types
- dam/model_preparation.py: AutoConfig/AutoModel registration
- dam/utils.py: find_norm_layers() detects Qwen3_5RMSNorm and NemotronHRMSNorm, with a graceful fallback for older transformers versions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
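The `(1 + weight)` RMSNorm formulation mentioned for Qwen 3.5 can be sketched as below. The class name, zero initialization, and `eps` default are illustrative assumptions for this example, not the actual `Qwen3_5DAMRMSNorm` code:

```python
import torch
import torch.nn as nn

class OnePlusWeightRMSNorm(nn.Module):
    """Sketch of the (1 + weight) RMSNorm formulation (names assumed).

    The effective scale is (1 + weight), so a zero-initialized weight
    yields plain RMS normalization at the start of training.
    """
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x):
        # Normalize by the root-mean-square over the hidden dimension,
        # then apply the (1 + weight) scale.
        variance = x.pow(2).mean(-1, keepdim=True)
        x_hat = x * torch.rsqrt(variance + self.eps)
        return x_hat * (1.0 + self.weight)

norm = OnePlusWeightRMSNorm(8)
y = norm(torch.randn(2, 5, 8))
```

With a zero weight, each token's output has unit RMS, which is why this variant stores the residual of the scale rather than the scale itself.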
Cursor Bugbot has reviewed your changes and found 3 potential issues.
dam/modeling/nemotron/modeling.py
Outdated
# Discretize
# dA: (B, num_heads, 1, 1)
dA = torch.exp(dt_t.unsqueeze(-1).unsqueeze(-1) * A.unsqueeze(-1).unsqueeze(0))
Wrong A tensor reshape causes Mamba-2 runtime crash
High Severity
A.unsqueeze(-1).unsqueeze(0) produces a 3D tensor of shape (1, num_heads, 1) instead of the required 4D (1, num_heads, 1, 1). When broadcast-multiplied with the 4D dt_t.unsqueeze(-1).unsqueeze(-1) of shape (B, num_heads, 1, 1), PyTorch prepends a dimension, making A effectively (1, 1, num_heads, 1). The result dA has shape (B, num_heads, num_heads, 1) instead of (B, num_heads, 1, 1). The subsequent state * dA then fails because state is (B, 128, 64, 128) and dA is (B, 128, 128, 1) — dimension 2 has an incompatible size mismatch (64 vs 128), causing a runtime error. The fix is A.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) or A[None, :, None, None].
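The shape mismatch can be reproduced in isolation. This is a minimal sketch with toy dimensions, independent of the actual Mamba-2 code:

```python
import torch

B, num_heads = 2, 4
dt_t = torch.randn(B, num_heads)  # per-head time step
A = torch.randn(num_heads)        # per-head decay parameter

# Buggy reshape: (num_heads,) -> (1, num_heads, 1). Broadcasting left-pads
# the 3D tensor to (1, 1, num_heads, 1), so the product blows up to
# (B, num_heads, num_heads, 1).
bad = torch.exp(dt_t.unsqueeze(-1).unsqueeze(-1) * A.unsqueeze(-1).unsqueeze(0))

# Fixed reshape: (1, num_heads, 1, 1) lines the head dimension up with
# dt_t's head dimension, giving the intended (B, num_heads, 1, 1).
good = torch.exp(dt_t.unsqueeze(-1).unsqueeze(-1) * A[None, :, None, None])
```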
dam/modeling/qwen3_5/modeling.py
Outdated
# g shape: (B, L, num_v_heads, 1) -- alpha was (B, L, num_v_heads), after unsqueeze(-1) + broadcast
# Actually alpha is (B, L, num_v_heads), dt_bias is (num_v_heads,), sum is (B, L, num_v_heads)
# softplus gives (B, L, num_v_heads), multiply gives (B, L, num_v_heads)
# We need to fix the shape computation:
Dead code from debugging left in decay computation
Medium Severity
Line 477 computes g with alpha.unsqueeze(-1) + self.dt_bias, which produces incorrect broadcasting ((B, L, num_v_heads, num_v_heads) instead of (B, L, num_v_heads)). Line 482 then overwrites g with the corrected computation. The first computation is dead code that wastes GPU memory by allocating a tensor quadratic in num_v_heads, along with the associated intermediate tensors from F.softplus and the multiplication. The surrounding comments confirm this was a debugging artifact that was never cleaned up.
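The quadratic blow-up from the dead computation is easy to see with toy shapes. A minimal sketch, independent of the actual Qwen 3.5 code:

```python
import torch
import torch.nn.functional as F

B, L, H = 2, 3, 4  # toy batch, sequence length, and value-head count
alpha = torch.randn(B, L, H)
dt_bias = torch.randn(H)

# Dead-code path: unsqueezing alpha to (B, L, H, 1) before adding the
# (H,)-shaped bias broadcasts to (B, L, H, H), quadratic in head count.
dead = alpha.unsqueeze(-1) + dt_bias

# Corrected path: (B, L, H) + (H,) broadcasts per head and stays (B, L, H).
g = F.softplus(alpha + dt_bias)
```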
dam/modeling/qwen3_5/modeling.py
Outdated
lambda size, eps: Qwen3_5DAMRMSNorm(size, eps=eps, num_models=config.num_merged_models)
if config.dam_layernorms
else lambda size, eps: Qwen3_5RMSNorm(size, eps=eps)
)
Unused NormClass variable defined but never referenced
Low Severity
NormClass is assigned as a lambda on lines 569–573 but is never used. The norm layers on lines 574–583 are constructed using inline conditional expressions instead. This appears to be a leftover from an earlier refactoring attempt and adds unnecessary clutter.
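Beyond being unused, the conditional-lambda pattern in the snippet has a precedence pitfall: without parentheses around each branch, the `else lambda ...` branch becomes the value returned by the outer lambda, so the factory would return a function instead of a norm instance. A hedged sketch of a factory that is actually used and correctly parenthesized (all class names here are placeholders, not the real Qwen 3.5 classes):

```python
import torch
import torch.nn as nn

class PlainNorm(nn.Module):
    """Placeholder for a plain norm layer."""
    def __init__(self, size, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(size, eps=eps)

    def forward(self, x):
        return self.norm(x)

class DAMNorm(PlainNorm):
    """Placeholder for a DAM-aware norm that tracks merged-model count."""
    def __init__(self, size, eps=1e-6, num_models=1):
        super().__init__(size, eps)
        self.num_models = num_models

use_dam, num_models = True, 3
# Parentheses around both branches ensure the conditional picks one of two
# complete lambdas, rather than nesting the else-branch inside the first.
make_norm = (
    (lambda size, eps: DAMNorm(size, eps=eps, num_models=num_models))
    if use_dam
    else (lambda size, eps: PlainNorm(size, eps=eps))
)
norm = make_norm(8, 1e-6)
```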
Cursor Bugbot fixes:
- Fix Mamba-2 dA tensor reshape crash (A[None, :, None, None] for correct 4D broadcast)
- Remove dead code in the Qwen 3.5 decay gate computation
- Remove the unused NormClass lambda in the decoder layer

Gemini review fix:
- Add a null-safety check on to_legacy_cache() in the Qwen 3.5 model

CPU smoke test fixes:
- Add **kwargs to tie_weights() for transformers 5.x compatibility
- Guard DynamicCache.from_legacy_cache with hasattr for forward compatibility

All 40 CPU smoke tests pass (imports, config, model instantiation, forward pass, loss computation, DAM layer verification, fix_config, norm detection, AutoConfig registration).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
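The `hasattr` guard mentioned in the smoke-test fixes is a general forward-compatibility pattern. A sketch, where `CacheStub` stands in for transformers' `DynamicCache` (which this example deliberately does not import), and the helper name is hypothetical:

```python
class CacheStub:
    """Minimal stand-in for a cache class exposing from_legacy_cache."""
    def __init__(self, layers=None):
        self.layers = layers or []

    @classmethod
    def from_legacy_cache(cls, legacy):
        # Wrap a legacy tuple-of-tuples cache in the new cache object.
        return cls(layers=list(legacy))

def promote_legacy_cache(past_key_values, cache_cls=CacheStub):
    """Convert a legacy tuple cache only when the classmethod exists.

    If a future library version removes from_legacy_cache, the hasattr
    check makes this degrade to passing the legacy cache through unchanged
    instead of raising AttributeError.
    """
    if past_key_values is None:
        return None
    if hasattr(cache_cls, "from_legacy_cache"):
        return cache_cls.from_legacy_cache(past_key_values)
    return past_key_values

cache = promote_legacy_cache(((1, 2), (3, 4)))
```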
Force-pushed from f0f7d0e to 7e3b5f0
Hey @shamanez @thomasgauthier @mryave
Both new architectures follow the existing Llama/Mistral DAM patterns exactly. All 40 CPU smoke tests pass (config, model instantiation, forward pass, loss computation, DAM layer verification, AutoConfig registration). Would appreciate a review when you get a chance. Happy to address any feedback. 🙏


Summary
Adds DAM (Differentiable Adaptive Merging) support for two new model architectures:
Qwen 3.5 (`qwen3_5`/`qwen3_5_text` → `mergedqwen3_5`)

- Hybrid attention with full and linear (Gated Delta Network) layers, dispatched per layer via the `layer_types` list
- Linear attention implemented in pure PyTorch (no `fla` dependency)
- RMSNorm uses the `(1 + weight)` formulation via a custom `Qwen3_5DAMRMSNorm` class
- All projections as `DAMLinearLayer`

Nemotron-H / Nemotron 3 (`nemotron_h` → `mergednemotron_h`)

- Three-block-type hybrid (Mamba-2 SSM, attention, MoE), dispatched via `layers_block_type`
- Pure PyTorch Mamba-2 recurrence (no `mamba_ssm` dependency required)
- All projections as `DAMLinearLayer`

Core Changes
- `dam/merge.py`: `fix_config()` handles `qwen3_5`, `qwen3_5_text`, and `nemotron_h` model types
- `dam/model_preparation.py`: AutoConfig/AutoModel registration for both new model types
- `dam/utils.py`: `find_norm_layers()` extended to detect `Qwen3_5RMSNorm` and `NemotronHRMSNorm`, with a graceful `ImportError` fallback for older transformers versions

Both implementations follow existing DAM conventions:
- `DAMLinearLayer` for all `nn.Linear` projections
- `DAMEmbeddingLayer` for token embeddings (conditional on `dam_embedding_layer`)
- `DAMRMSNorm` for normalization layers (conditional on `dam_layernorms`)
- `tie_weights` with DAM embedding support
- `ForCausalLM` with `prepare_inputs_for_generation`

New Files
Test plan
- `--use_base_model` flag for both architectures
- `find_norm_layers` correctly detects new norm types

Notes
- No `mamba_ssm` and `fla` libraries required (implementations are pure PyTorch)
- Requires `transformers >= 4.53` for Nemotron-H and `transformers >= 4.57` for Qwen 3.5 norm detection in `utils.py` (gracefully falls back if unavailable)

🤖 Generated with Claude Code
Note
Medium Risk
Adds two new model backends with substantial new modeling code (custom attention/SSM/MoE implementations and config wiring), which may affect correctness and generation behavior for these architectures. Existing Mistral/Llama paths are largely untouched aside from config/norm detection.
Overview
Adds DAM support for Qwen 3.5 and Nemotron-H by introducing new
`Merged*Config`/`Merged*ForCausalLM` implementations and wiring them into the merge/training flow. `dam/merge.py` now rewrites `config.json` for `qwen3_5`/`qwen3_5_text` and `nemotron_h` to the corresponding merged model types, and `dam/model_preparation.py` registers the new configs/models with `AutoConfig`/`AutoModelForCausalLM`.

Updates `dam/utils.py` so `find_norm_layers()` can also detect Qwen/Nemotron RMSNorm classes when available (with `ImportError` fallbacks), ensuring norm layers are merged for these architectures.

Written by Cursor Bugbot for commit 4f96a93.