diff --git a/config/model/latent_attention.yaml b/config/model/latent_attention.yaml
index 67b678a..479e3aa 100644
--- a/config/model/latent_attention.yaml
+++ b/config/model/latent_attention.yaml
@@ -6,6 +6,7 @@
 encoder_num_layers: 4
 decoder_num_layers: 4
 widening_factor: 4
 dropout_p: 0.01
+latent_init: normal # Options: normal, xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal, truncated_normal, zeros
 
 module:
@@ -24,6 +25,7 @@ module:
   decoder_num_layers: ${model.decoder_num_layers}
   widening_factor: ${model.widening_factor}
   dropout_p: ${model.dropout_p}
+  latent_init: ${model.latent_init}
 
   in_keys:
diff --git a/environment.yaml b/environment.yaml
index 76fab2e..991d7f9 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -3,7 +3,7 @@ channels:
   - conda-forge
 dependencies:
   - python=3.12
-  - pytorch-gpu=2.7.0
+  - pytorch=2.7.0
   - tensordict=0.8.3
   - einops=0.8.1
   - lightning=2.5.1.post0
diff --git a/src/deepmuonreco/nn/models/latent_attention.py b/src/deepmuonreco/nn/models/latent_attention.py
index d8710f4..f7dd87e 100644
--- a/src/deepmuonreco/nn/models/latent_attention.py
+++ b/src/deepmuonreco/nn/models/latent_attention.py
@@ -31,10 +31,12 @@ def __init__(
         decoder_num_layers: int,
         dropout_p: float = 0.1,
         widening_factor: int = 4,
+        latent_init: str = "normal",
     ) -> None:
         """
         Args:
             latent_len: number of latent vectors in the encoder for muon detector system measurement embeddings
+            latent_init: Initialization method for PerceiverEncoder latent parameters
         """
         super().__init__()
@@ -52,6 +54,7 @@
             widening_factor=widening_factor,
             dropout_p=dropout_p,
             bias=True,
+            latent_init=latent_init,
         )
 
         self.muon_det_encoder = PerceiverEncoder(
@@ -62,6 +65,7 @@
             widening_factor=widening_factor,
             dropout_p=dropout_p,
             bias=True,
+            latent_init=latent_init,
         )
 
         self.encoder = TransformerDecoder(
diff --git a/src/deepmuonreco/nn/transformers/hip.py b/src/deepmuonreco/nn/transformers/hip.py
index 46ae26f..ed11a47 100644
--- a/src/deepmuonreco/nn/transformers/hip.py
+++ b/src/deepmuonreco/nn/transformers/hip.py
@@ -38,6 +38,7 @@ def __init__(
         processor_num_heads: int,
         processor_widening_factor: int,
         dropout_p: float = 0,
+        latent_init: str = "normal",
     ) -> None:
         """
         """
@@ -66,6 +67,7 @@
             widening_factor=encoder_widening_factor,
             input_dim=input_dim,
             dropout_p=dropout_p,
+            latent_init=latent_init,
         )
 
         processor = PerceiverProcessor(
@@ -148,6 +150,7 @@ def __init__(
         # NOTE: encoder
         encoder_num_heads: list[int] | None = None,
         encoder_widening_factor: list[int] | None = None,
+        latent_init: str = "normal",
     ) -> None:
 
         super().__init__()
@@ -178,7 +181,7 @@ def __init__(
 
         # HiP's encoder
         self.block_list = nn.ModuleList([
-            HiPBlock(**kwargs) for kwargs in kwargs_list
+            HiPBlock(**kwargs, latent_init=latent_init) for kwargs in kwargs_list
         ])
 
 
@@ -199,6 +202,7 @@ def __init__(
         encoder_num_heads: list[int] | None = None,
         encoder_widening_factor: list[int] | None = None,
         return_hidden: bool = True,
+        latent_init: str = "normal",
     ) -> None:
         """
         """
@@ -212,6 +216,7 @@ def __init__(
             processor_widening_factor,
             encoder_num_heads,
             encoder_widening_factor,
+            latent_init,
         )
 
         self.return_hidden = return_hidden
diff --git a/src/deepmuonreco/nn/transformers/perceiver.py b/src/deepmuonreco/nn/transformers/perceiver.py
index 87ada1f..475142d 100644
--- a/src/deepmuonreco/nn/transformers/perceiver.py
+++ b/src/deepmuonreco/nn/transformers/perceiver.py
@@ -30,12 +30,31 @@ def __init__(
         input_dim: int | None = None,
         dropout_p: float = 0,
         bias: bool = False,
+        latent_init: str = "normal",
     ) -> None:
         """
+        Args:
+            latent_len: Number of latent vectors
+            latent_dim: Dimension of each latent vector
+            num_heads: Number of attention heads
+            use_post_attention_residual: Whether to use a post-attention residual connection
+            widening_factor: MLP widening factor
+            input_dim: Input dimension (if different from latent_dim)
+            dropout_p: Dropout probability
+            bias: Whether to use bias in attention layers
+            latent_init: Initialization method for latent parameters. Options:
+                - "normal": Standard normal distribution (default, backward compatible)
+                - "xavier_uniform": Xavier/Glorot uniform initialization
+                - "xavier_normal": Xavier/Glorot normal initialization
+                - "kaiming_uniform": Kaiming/He uniform initialization
+                - "kaiming_normal": Kaiming/He normal initialization
+                - "truncated_normal": Truncated normal distribution (std=0.02)
+                - "zeros": Initialize to zeros
         """
         super().__init__()
 
-        self.latent = nn.Parameter(data=torch.randn(latent_len, latent_dim))
+        self.latent = nn.Parameter(data=torch.empty(latent_len, latent_dim))
+        self._initialize_latent(latent_init)
 
         self.attention = CrossAttentionBlock(
             embed_dim=latent_dim,
@@ -54,6 +73,32 @@ def __init__(
             dropout_p=dropout_p,
         )
 
+    def _initialize_latent(self, init_method: str) -> None:
+        """Initialize the latent parameter tensor using the specified method."""
+        with torch.no_grad():
+            if init_method == "normal":
+                # Standard normal distribution (backward compatible)
+                nn.init.normal_(self.latent, mean=0.0, std=1.0)
+            elif init_method == "xavier_uniform":
+                nn.init.xavier_uniform_(self.latent)
+            elif init_method == "xavier_normal":
+                nn.init.xavier_normal_(self.latent)
+            elif init_method == "kaiming_uniform":
+                nn.init.kaiming_uniform_(self.latent, mode='fan_in')
+            elif init_method == "kaiming_normal":
+                nn.init.kaiming_normal_(self.latent, mode='fan_in')
+            elif init_method == "truncated_normal":
+                # Truncated normal with smaller std for more stable training
+                nn.init.trunc_normal_(self.latent, mean=0.0, std=0.02, a=-2*0.02, b=2*0.02)
+            elif init_method == "zeros":
+                nn.init.zeros_(self.latent)
+            else:
+                raise ValueError(
+                    f"Unknown latent initialization method: {init_method}. "
+                    f"Supported methods: normal, xavier_uniform, xavier_normal, "
+                    f"kaiming_uniform, kaiming_normal, truncated_normal, zeros"
+                )
+
     def forward(
         self,
         input: Tensor,
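
Reviewer note: a minimal sketch for exercising the new option from a REPL. It assumes the class modified in perceiver.py is `PerceiverEncoder` (the hunks above do not show the class name) and that the required constructor arguments match the docstring added in this diff; adjust names and values to the real signature if they differ.

```python
# Sketch only: assumes PerceiverEncoder is the class touched in perceiver.py
# and that these are its required constructor arguments.
from deepmuonreco.nn.transformers.perceiver import PerceiverEncoder

encoder = PerceiverEncoder(
    latent_len=16,
    latent_dim=64,
    num_heads=4,
    use_post_attention_residual=True,
    widening_factor=4,
    latent_init="truncated_normal",  # default "normal" reproduces the old torch.randn behavior
)

# The latent std should now be on the order of 0.02 rather than ~1.0.
print(encoder.latent.detach().std())

# An unknown method fails fast in _initialize_latent with a ValueError.
try:
    PerceiverEncoder(
        latent_len=16,
        latent_dim=64,
        num_heads=4,
        use_post_attention_residual=True,
        widening_factor=4,
        latent_init="typo",
    )
except ValueError as e:
    print(e)
```

With the config change above, the same knob is exposed through Hydra as `model.latent_init` in config/model/latent_attention.yaml and is threaded down to every `PerceiverEncoder` built by the latent-attention model and the HiP blocks.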