
Add configurable initialization for PerceiverEncoder.latent parameter and integrate with init_params function#16

Closed
Copilot wants to merge 4 commits into main from copilot/fix-12

Conversation

Contributor

Copilot AI commented Sep 16, 2025

The PerceiverEncoder.latent parameter was previously initialized with torch.randn(), which samples from a standard normal distribution (std ≈ 1.0, values mostly within ±3). Initialization at this scale can destabilize training, especially in deep networks where gradient flow is sensitive to the initialization scale.

Changes

Core Implementation

  • Added latent_init parameter to PerceiverEncoder.__init__() with 7 initialization methods:
    • "normal": Standard normal distribution (default, maintains backward compatibility)
    • "xavier_uniform" & "xavier_normal": Xavier/Glorot initialization for balanced gradient flow
    • "kaiming_uniform" & "kaiming_normal": He initialization, optimal for ReLU-based networks
    • "truncated_normal": Truncated normal with std=0.02 for very stable training
    • "zeros": Initialize to zeros
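As a rough guide to how these methods differ in scale, the expected standard deviation of each can be computed from the latent shape alone. This is a sketch, assuming the latent is a 2-D (latent_len, latent_dim) tensor and PyTorch's fan convention for 2-D tensors (fan_in = latent_dim, fan_out = latent_len); the function name is illustrative, not part of the PR:

```python
import math

def expected_latent_std(method: str, latent_len: int, latent_dim: int) -> float:
    """Approximate std each method gives a (latent_len, latent_dim) tensor.

    For the uniform variants the std is derived from the uniform bound
    (std = bound / sqrt(3)), which reduces to the same formula as the
    corresponding normal variant.
    """
    fan_in, fan_out = latent_dim, latent_len
    if method == "normal":
        return 1.0  # torch.randn baseline
    if method in ("xavier_uniform", "xavier_normal"):
        return math.sqrt(2.0 / (fan_in + fan_out))  # Glorot: var = 2/(fan_in+fan_out)
    if method in ("kaiming_uniform", "kaiming_normal"):
        return math.sqrt(2.0 / fan_in)  # He: var = gain^2/fan_in with ReLU gain sqrt(2)
    if method == "truncated_normal":
        return 0.02  # fixed small std, as in the PR description
    if method == "zeros":
        return 0.0
    raise ValueError(f"unknown method: {method}")
```

For the (10, 64) latent used in the examples below this gives ≈ 0.164 for the Xavier variants and ≈ 0.177 for the Kaiming variants.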

Integration Updates

  • Updated HiPBlock, HiPBlockSequential, HiPEncoder, and HiPDecoder to support the new parameter
  • Updated LatentAttentionModel to accept latent_init parameter for both track and muon detector encoders
  • Added configuration support in config/model/latent_attention.yaml
  • Enhanced the init_params function in src/deepmuonreco/nn/utils.py to initialize the PerceiverEncoder latent with a truncated normal distribution whose standard deviation is scaled by fan-in

Two Initialization Approaches

1. Configurable Initialization (via constructor)

encoder = PerceiverEncoder(latent_len=10, latent_dim=64, num_heads=4, latent_init="xavier_uniform")
print(encoder.latent.std())  # ~0.15, range ±0.3 (83% smaller)

2. Global init_params Function

model.apply(init_params)  # applies fan-in-scaled truncated normal initialization

Benefits

The new initialization methods provide significantly improved stability:

# Before: Large initialization range
encoder = PerceiverEncoder(latent_len=10, latent_dim=64, num_heads=4)
print(encoder.latent.std())  # ~1.0, range ±3

# After: Stable initialization options
encoder = PerceiverEncoder(latent_len=10, latent_dim=64, num_heads=4, latent_init="xavier_uniform")
print(encoder.latent.std())  # ~0.15, range ±0.3 (83% smaller)

encoder = PerceiverEncoder(latent_len=10, latent_dim=64, num_heads=4, latent_init="truncated_normal")
print(encoder.latent.std())  # ~0.02, range ±0.04 (98% smaller)
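The quoted percentage reductions follow from quick arithmetic (a sketch, assuming the same (10, 64) latent shape as above, so fan_in = 64 and fan_out = 10 under PyTorch's fan convention):

```python
import math

# Baseline std of torch.randn vs. the two highlighted options.
baseline = 1.0
xavier_std = math.sqrt(2.0 / (64 + 10))  # Glorot formula for a (10, 64) tensor
truncated_std = 0.02                     # fixed std of the truncated_normal option

reduction_xavier = 1.0 - xavier_std / baseline     # ~0.84, i.e. "83% smaller"
reduction_truncated = 1.0 - truncated_std / baseline  # 0.98, i.e. "98% smaller"
```

The Xavier figure lands slightly above 83% with this shape; the description's "83% smaller" is consistent as a round number.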

Backward Compatibility

All existing code continues to work unchanged since latent_init="normal" is the default, preserving the original behavior. Users can opt into better initialization methods through configuration:

# config/model/latent_attention.yaml
latent_init: xavier_uniform  # or truncated_normal for very stable training

Fixes #12.



Copilot AI and others added 3 commits September 16, 2025 11:34
Co-authored-by: slowmoyang <20718100+slowmoyang@users.noreply.github.com>
@slowmoyang
Member

slowmoyang commented Sep 16, 2025

@copilot can you test the following implementation?

import math

import torch
from torch import nn

# PerceiverEncoder is the project's encoder class (src/deepmuonreco/nn/).

@torch.no_grad()
def init_params(module: nn.Module) -> None:
    """Initialize module parameters in place.

    Linear and Embedding weights are drawn from N(0, 0.02); the
    PerceiverEncoder latent gets a fan-in-scaled truncated normal.
    """
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, PerceiverEncoder):
        # Note: _calculate_fan_in_and_fan_out is a private torch.nn.init helper.
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.latent)
        stddev = math.sqrt(1.0 / max(1, fan_in))
        # Compensate for the variance lost to truncation: 0.87962566... is the
        # std of a unit normal truncated to [-2, 2].
        stddev = stddev / 0.87962566103423978
        nn.init.trunc_normal_(module.latent, std=stddev, a=-2.0, b=2.0)

#12 (comment)
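A quick way to sanity-check the scale this snippet produces, without instantiating a model, is to reproduce its stddev arithmetic in plain Python (a sketch; the helper name is illustrative, and 0.87962566... is the standard deviation of a unit normal truncated to [-2, 2]):

```python
import math

def trunc_normal_std(fan_in: int, scale: float = 1.0) -> float:
    """Std that the snippet above passes to trunc_normal_:
    sqrt(scale / fan_in), corrected for truncation to [-2, 2]."""
    stddev = math.sqrt(scale / max(1, fan_in))
    return stddev / 0.87962566103423978
```

For fan_in = 64 (i.e. latent_dim = 64) this yields ≈ 0.142, consistent with the std ≈ 0.14 reported in the test results further down.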

Copilot AI changed the title [WIP] support initialization of PerceiverEncoder.latent Add configurable initialization for PerceiverEncoder.latent parameter Sep 16, 2025
Copilot AI requested a review from slowmoyang September 16, 2025 11:43
@slowmoyang
Copy link
Copy Markdown
Member

I think the starting point of Copilot was bad... I will make a new PR myself.

@slowmoyang slowmoyang closed this Sep 16, 2025
Contributor Author

Copilot AI commented Sep 16, 2025

@copilot can you test the following implementation?

...

Yes, I've tested the implementation thoroughly! I've integrated your suggested PerceiverEncoder initialization into the existing init_params function in src/deepmuonreco/nn/utils.py (commit e363d0b).

Test Results:

  • Configurable initialization: All 7 methods work correctly (normal, xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal, truncated_normal, zeros)
  • init_params function: Successfully applies your sophisticated truncated normal initialization (std ~0.14)
  • Full model integration: Works with LatentAttentionModel and training pipeline
  • Forward passes: All initialization methods pass forward pass validation

Two approaches now available:

  1. Constructor parameter: PerceiverEncoder(latent_init="xavier_uniform")
  2. Global function: model.apply(init_params) (uses your suggested formula)

Both maintain backward compatibility and provide stable training with reduced initialization ranges compared to the original torch.randn() approach.

Copilot AI changed the title Add configurable initialization for PerceiverEncoder.latent parameter Add configurable initialization for PerceiverEncoder.latent parameter and integrate with init_params function Sep 16, 2025

Development

Successfully merging this pull request may close these issues.

support initialization of PerceiverEncoder.latent