Skip to content

Replace conv trunk with a configurable residual tower#14

Merged
cweill merged 1 commit into
mainfrom
feature/residual-network-tower
May 26, 2026
Merged

Replace conv trunk with a configurable residual tower#14
cweill merged 1 commit into
mainfrom
feature/residual-network-tower

Conversation

@cweill

@cweill cweill commented May 26, 2026

Copy link
Copy Markdown
Contributor

Summary

Swap AlphaZeroNet's two-layer conv trunk for a stem + N residual blocks, with channels and num_res_blocks constructor knobs (defaults 64 / 4). This is the trunk architecture real AlphaZero uses (minus normalization, see below) and lets the network scale depth/width per game.

Motivation

Residual connections let gradients flow straight through a deep trunk, enabling more representational capacity than the current 2-conv stack without vanishing-gradient problems. Depth/width are now tunable per game rather than hardcoded.

Design notes

  • No BatchNorm (intentional). The training loop can emit a size-1 final batch, which crashes BN in train mode. The skip connection — not BN — is the load-bearing idea, so the block is Conv→ReLU→Conv + skip → ReLU. BN can be added later alongside drop_last in train.py.
  • Signatures unchanged. forward/predict/predict_batch are identical, so MCTS and training need no changes; only the trunk internals differ.
  • Old conv-trunk checkpoints won't load into the new module (different layer names); there are no committed checkpoints.

Testing

  • pytest tests/test_network.py tests/test_train.py → 18 passed.
  • New tests: configurable depth/width, stem-only (num_res_blocks=0), skip-path gradient flow, and config validation.
  • Training smoke (tic-tac-toe, 39 self-play examples): loss decreases (3.19 → 3.16 over 8 epochs), no NaN/divergence — confirms gradients flow and the deeper net trains. Full convergence vs. the old trunk needs a real training run.

Review focus

  • Default num_res_blocks=4 / channels=64 — reasonable for small boards, or too heavy?
  • The no-BatchNorm decision.

Swap the two-layer conv trunk for a stem plus N residual blocks (skip
connections, no normalization), with `channels` and `num_res_blocks`
constructor knobs (defaults 64 / 4). Skip connections keep gradients
flowing through deeper trunks; BatchNorm is omitted so the block is safe
for any batch size, including the size-1 final training batch.

forward/predict/predict_batch signatures are unchanged, so MCTS and
training are untouched. Adds tests for configurability, the stem-only
(zero-block) case, skip-path gradient flow, and config validation.
@cweill cweill merged commit 291cc33 into main May 26, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant