
Add Qwen 3.5 and Nemotron 3 (nemotron_h) model support#48

Open
SolshineCode wants to merge 2 commits into arcee-ai:main from SolshineCode:feature/nemotron3-qwen3.5-support

Conversation


SolshineCode commented Apr 1, 2026

Summary

Adds DAM (Differentiable Adaptive Merging) support for two new model architectures:

Qwen 3.5 (qwen3_5 / qwen3_5_text → mergedqwen3_5)

  • Hybrid attention architecture: Each layer uses either full softmax attention or linear attention (Gated Delta Network), configured via layer_types list
  • Full attention features: Output gating (Q projects to 2x, split into query + sigmoid gate), QK-normalization, partial RoPE (25% of head_dim), GQA
  • Linear attention (Gated Delta Net): Causal conv1d → recurrent delta rule with exponential decay gating, L2-normalized QK, gated RMSNorm output — pure PyTorch implementation (no external fla dependency)
  • RMSNorm: Uses Qwen 3.5's (1 + weight) formulation via custom Qwen3_5DAMRMSNorm class
  • MLP: Standard SwiGLU (gate_proj, up_proj, down_proj) — all DAMLinearLayer
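For reviewers, the (1 + weight) RMSNorm formulation mentioned above can be sketched as follows. This is toy code to illustrate the math, not the PR's actual `Qwen3_5DAMRMSNorm`; the class name and zero-initialization here are illustrative assumptions.

```python
import torch

class Qwen35StyleRMSNorm(torch.nn.Module):
    """Illustrative sketch of the (1 + weight) RMSNorm variant."""
    def __init__(self, size, eps=1e-6):
        super().__init__()
        # Zero-init means the scale starts at exactly 1.0 (identity scaling)
        self.weight = torch.nn.Parameter(torch.zeros(size))
        self.eps = eps

    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        # Qwen-style: the learned scale is (1 + weight), not weight alone
        return (1.0 + self.weight) * x
```

With a freshly zero-initialized weight this reduces to plain RMSNorm, which is why checkpoints trained with the (1 + weight) convention cannot be loaded into a standard `weight * x` RMSNorm without shifting the weights.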

Nemotron-H / Nemotron 3 (nemotron_h → mergednemotron_h)

  • Three block types: Mamba-2 SSM, standard GQA attention, and Mixture-of-Experts — dispatched per-layer via layers_block_type
  • Single-mixer blocks: Each block has only one mixer (not attention+MLP), which is architecturally different from standard transformers
  • Mamba-2: Pure PyTorch recurrent forward with in_proj → conv1d → SSM scan → gated RMSNorm → out_proj (no mamba_ssm dependency required)
  • MoE: Sigmoid routing with top-k expert selection, normalized weights, scaling factor, and always-on shared expert
  • MLP: Non-gated (up_proj → ReLU² → down_proj) — all expert and shared expert projections are DAMLinearLayer
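The MoE routing described above (sigmoid scores, top-k selection, weight normalization, scaling factor, always-on shared expert) can be sketched roughly like this. Toy sizes and a scaling factor of 1.0 are assumptions; the real implementation in the PR will differ in detail.

```python
import torch

torch.manual_seed(0)
T, H, E, K = 3, 8, 4, 2           # tokens, hidden size, experts, top-k

router = torch.nn.Linear(H, E, bias=False)
experts = torch.nn.ModuleList(torch.nn.Linear(H, H) for _ in range(E))
shared = torch.nn.Linear(H, H)    # shared expert, applied to every token
x = torch.randn(T, H)

with torch.no_grad():
    scores = torch.sigmoid(router(x))        # sigmoid routing (not softmax)
    w, idx = scores.topk(K, dim=-1)          # top-k expert selection per token
    w = w / w.sum(dim=-1, keepdim=True)      # normalize the selected weights
    w = w * 1.0                              # scaling factor (assumed 1.0 here)

    out = shared(x)                          # always-on shared expert
    for t in range(T):
        for j in range(K):
            out[t] += w[t, j] * experts[int(idx[t, j])](x[t])
```

The token-by-token loop is for clarity only; a real implementation would gather tokens per expert for batched matmuls.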

Core Changes

  • dam/merge.py: fix_config() handles qwen3_5, qwen3_5_text, and nemotron_h model types
  • dam/model_preparation.py: AutoConfig/AutoModel registration for both new model types
  • dam/utils.py: find_norm_layers() extended to detect Qwen3_5RMSNorm and NemotronHRMSNorm with graceful ImportError fallback for older transformers versions

Both implementations follow existing DAM conventions:

  • DAMLinearLayer for all nn.Linear projections
  • DAMEmbeddingLayer for token embeddings (conditional on dam_embedding_layer)
  • DAMRMSNorm for normalization layers (conditional on dam_layernorms)
  • Proper tie_weights with DAM embedding support
  • Full ForCausalLM with prepare_inputs_for_generation
  • Architecture-specific layers (conv1d, SSM parameters, MoE routing) remain as standard PyTorch modules
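As background for reviewers, the coefficient-merging idea behind DAMLinearLayer can be sketched roughly as follows. This is a toy illustration; the real class in this repo likely differs in coefficient granularity (e.g. per-feature rather than per-model) and initialization.

```python
import torch

class ToyDAMLinear(torch.nn.Module):
    """Toy sketch: merge N source models' weights via trainable coefficients."""
    def __init__(self, weights):                  # list of (out, in) weight tensors
        super().__init__()
        # Frozen source weights stacked along a new model dimension
        self.register_buffer("weights", torch.stack(weights))
        # One trainable merging coefficient per source model, uniform init
        self.coeffs = torch.nn.Parameter(
            torch.full((len(weights),), 1.0 / len(weights))
        )

    def forward(self, x):
        # Weighted sum over the model dimension gives the merged weight matrix
        merged = (self.coeffs[:, None, None] * self.weights).sum(0)
        return torch.nn.functional.linear(x, merged)
```

During DAM training only the coefficients receive gradients, so the optimizer searches over convex-like combinations of the source models rather than over full weight matrices.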

New Files

dam/modeling/qwen3_5/
├── __init__.py
├── config.py      (MergedQwen3_5Config)
└── modeling.py     (11 classes, ~1200 lines)

dam/modeling/nemotron/
├── __init__.py
├── config.py      (MergedNemotronHConfig)
└── modeling.py     (13 classes, ~1275 lines)

Test plan

  • Verify syntax compilation passes for all new files (done locally)
  • Test merge workflow with Qwen 3.5 fine-tuned model variants
  • Test merge workflow with Nemotron-H fine-tuned model variants
  • Verify DAM coefficient training converges on both architectures
  • Test with --use_base_model flag for both architectures
  • Validate that find_norm_layers correctly detects new norm types

Notes

  • Both implementations use eager attention only (no flash_attention_2 or SDPA variants) — these can be added in follow-up PRs
  • Mamba-2 and Gated Delta Net use pure PyTorch fallback implementations to avoid hard dependencies on mamba_ssm and fla libraries
  • Requires transformers >= 4.53 for Nemotron-H and transformers >= 4.57 for Qwen 3.5 norm detection in utils.py (gracefully falls back if unavailable)

🤖 Generated with Claude Code


Note

Medium Risk
Adds two new model backends with substantial new modeling code (custom attention/SSM/MoE implementations and config wiring), which may affect correctness and generation behavior for these architectures. Existing Mistral/Llama paths are largely untouched aside from config/norm detection.

Overview
Adds DAM support for Qwen 3.5 and Nemotron-H by introducing new Merged*Config/Merged*ForCausalLM implementations and wiring them into the merge/training flow.

dam/merge.py now rewrites config.json for qwen3_5/qwen3_5_text and nemotron_h to the corresponding merged model types, and dam/model_preparation.py registers the new configs/models with AutoConfig/AutoModelForCausalLM.

Updates dam/utils.py so find_norm_layers() can also detect Qwen/Nemotron RMSNorm classes when available (with ImportError fallbacks), ensuring norm layers are merged for these architectures.

Written by Cursor Bugbot for commit 4f96a93.

Extends DAM to support two new model architectures:

Qwen 3.5 (model_type: qwen3_5/qwen3_5_text -> mergedqwen3_5):
- Hybrid attention architecture with full (softmax) and linear (Gated
  Delta Network) attention layers, dispatched per-layer via layer_types
- Full attention uses output gating, QK-normalization, and partial RoPE
- Linear attention uses causal conv1d and recurrent delta rule updates
- RMSNorm uses (1+weight) formulation via custom Qwen3_5DAMRMSNorm
- SwiGLU MLP with all projections as DAMLinearLayer

Nemotron-H (model_type: nemotron_h -> mergednemotron_h):
- Three-block-type hybrid: Mamba-2 SSM, standard GQA attention, and MoE
- Each block has a single mixer (not attention+MLP like standard transformers)
- Mamba-2 with pure PyTorch recurrent forward (no mamba_ssm dependency)
- MoE with sigmoid routing, top-k expert selection, and shared expert
- Non-gated MLP experts with ReLU squared activation

Both implementations follow the existing DAM patterns:
- DAMLinearLayer for all nn.Linear projections
- DAMEmbeddingLayer for token embeddings (conditional on dam_embedding_layer)
- DAMRMSNorm for normalization (conditional on dam_layernorms)
- Proper tie_weights support for weight-tied embeddings
- Full ForCausalLM with prepare_inputs_for_generation

Updated files:
- dam/merge.py: fix_config() handles new model types
- dam/model_preparation.py: AutoConfig/AutoModel registration
- dam/utils.py: find_norm_layers() detects Qwen3_5RMSNorm and NemotronHRMSNorm
  with graceful fallback for older transformers versions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
SolshineCode requested a review from a team as a code owner April 1, 2026 06:03

cursor bot left a comment
Cursor Bugbot has reviewed your changes and found 3 potential issues.



# Discretize
# dA: (B, num_heads, 1, 1)
dA = torch.exp(dt_t.unsqueeze(-1).unsqueeze(-1) * A.unsqueeze(-1).unsqueeze(0))

Wrong A tensor reshape causes Mamba-2 runtime crash

High Severity

A.unsqueeze(-1).unsqueeze(0) produces a 3D tensor of shape (1, num_heads, 1) instead of the required 4D (1, num_heads, 1, 1). When broadcast-multiplied with the 4D dt_t.unsqueeze(-1).unsqueeze(-1) of shape (B, num_heads, 1, 1), PyTorch prepends a dimension, making A effectively (1, 1, num_heads, 1). The result dA has shape (B, num_heads, num_heads, 1) instead of (B, num_heads, 1, 1). The subsequent state * dA then fails because state is (B, 128, 64, 128) and dA is (B, 128, 128, 1) — dimension 2 has an incompatible size mismatch (64 vs 128), causing a runtime error. The fix is A.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) or A[None, :, None, None].
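The broadcasting failure described above is easy to reproduce with toy shapes (B=2, num_heads=4 here are illustrative):

```python
import torch

B, H = 2, 4                            # batch, num_heads (toy sizes)
dt_t = torch.rand(B, H)
A = -torch.rand(H)                     # per-head decay parameter

dt4 = dt_t.unsqueeze(-1).unsqueeze(-1)  # (B, H, 1, 1)

wrong_A = A.unsqueeze(-1).unsqueeze(0)  # (1, H, 1): only 3D
dA_wrong = torch.exp(dt4 * wrong_A)     # broadcast prepends a dim -> (B, H, H, 1)

right_A = A[None, :, None, None]        # (1, H, 1, 1): explicit 4D
dA_right = torch.exp(dt4 * right_A)     # (B, H, 1, 1) as intended
```

Because broadcasting aligns trailing dimensions, the 3D tensor's `H` lands in the wrong axis, which is why the failure only surfaces downstream at the `state * dA` multiply.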


# g shape: (B, L, num_v_heads, 1) -- alpha was (B, L, num_v_heads), after unsqueeze(-1) + broadcast
# Actually alpha is (B, L, num_v_heads), dt_bias is (num_v_heads,), sum is (B, L, num_v_heads)
# softplus gives (B, L, num_v_heads), multiply gives (B, L, num_v_heads)
# We need to fix the shape computation:
Dead code from debugging left in decay computation

Medium Severity

Line 477 computes g with alpha.unsqueeze(-1) + self.dt_bias, which produces incorrect broadcasting ((B, L, num_v_heads, num_v_heads) instead of (B, L, num_v_heads)). Line 482 then overwrites g with the corrected computation. The first computation is dead code that wastes GPU memory by allocating a tensor quadratic in num_v_heads, along with the associated intermediate tensors from F.softplus and the multiplication. The surrounding comments confirm this was a debugging artifact that was never cleaned up.
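The unwanted quadratic broadcast can be demonstrated with toy shapes (values here are illustrative):

```python
import torch
import torch.nn.functional as F

B, L, Hv = 2, 5, 4                 # batch, seq len, num_v_heads (toy sizes)
alpha = torch.randn(B, L, Hv)
dt_bias = torch.randn(Hv)

# Dead-code version: (B, L, Hv, 1) + (Hv,) broadcasts to (B, L, Hv, Hv),
# allocating a tensor quadratic in num_v_heads
g_bad = F.softplus(alpha.unsqueeze(-1) + dt_bias)

# Corrected version: (B, L, Hv) + (Hv,) stays (B, L, Hv)
g_good = F.softplus(alpha + dt_bias)
```

At real model sizes (num_v_heads in the tens) the dead allocation is pure waste, so deleting the first computation is the right fix.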


lambda size, eps: Qwen3_5DAMRMSNorm(size, eps=eps, num_models=config.num_merged_models)
if config.dam_layernorms
else lambda size, eps: Qwen3_5RMSNorm(size, eps=eps)
)
Unused NormClass variable defined but never referenced

Low Severity

NormClass is assigned as a lambda on lines 569–573 but is never used. The norm layers on lines 574–583 are constructed using inline conditional expressions instead. This appears to be a leftover from an earlier refactoring attempt and adds unnecessary clutter.


SolshineCode marked this pull request as draft April 1, 2026 18:45
Cursor Bugbot fixes:
- Fix Mamba-2 dA tensor reshape crash (A[None,:,None,None] for correct 4D broadcast)
- Remove dead code in Qwen 3.5 decay gate computation
- Remove unused NormClass lambda in decoder layer

Gemini review fix:
- Add null safety check on to_legacy_cache() in Qwen 3.5 model

CPU smoke test fixes:
- Add **kwargs to tie_weights() for transformers 5.x compatibility
- Guard DynamicCache.from_legacy_cache with hasattr for forward compat

All 40 CPU smoke tests pass (imports, config, model instantiation,
forward pass, loss computation, DAM layer verification, fix_config,
norm detection, AutoConfig registration).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
SolshineCode force-pushed the feature/nemotron3-qwen3.5-support branch from f0f7d0e to 7e3b5f0 on April 1, 2026 23:52

SolshineCode commented Apr 2, 2026

Hey @shamanez @thomasgauthier @mryave
This PR adds DAM support for two new model architectures:

  • Qwen 3.5 (hybrid full + linear attention via Gated Delta Network)
  • Nemotron-H / Nemotron 3 (hybrid Mamba-2 + GQA attention + MoE)

Both follow the existing Llama/Mistral DAM patterns exactly. All 40 CPU smoke tests pass (config, model instantiation, forward pass, loss computation, DAM layer verification, AutoConfig registration).

Would appreciate a review when you get a chance. Happy to address any feedback. 🙏

SolshineCode marked this pull request as ready for review April 2, 2026 00:18