[Bug] L1/L2 regularization is completely bypassed during training due to the use of .data in loss computation
Describe the bug
While tuning the L1 and L2 penalty hyperparameters to achieve the expected network sparsity (as described in the official documentation), I noticed that even extremely high penalty values (e.g., L1 = 0.5 or 10.0) had absolutely zero effect on the weight distribution or network sparsity.
Upon inspecting the training loops in Phase 1 and Phase 2, I found that the .data attribute is being used to access the model weights when computing the regularization norms. In PyTorch, accessing a tensor via .data returns a tensor that shares the same memory but is detached from the computational graph (effectively setting requires_grad=False for that operation).
As a result, the autograd engine treats the entire loss_norm term as a disconnected constant. When total_loss.backward() is called, no gradients from the L1/L2 penalties ever reach the model parameters, rendering the regularization completely ineffective.
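This behavior is easy to confirm in isolation. The following is a minimal standalone sketch (a toy parameter, not scDoRI code) showing that a norm computed through `.data` carries no gradient, while the same norm computed on the parameter directly does:

```python
import torch

# Toy parameter standing in for a decoder weight matrix
w = torch.nn.Parameter(torch.randn(4, 4))

loss_detached = torch.norm(w.data, p=1)  # `.data` leaves the autograd graph
loss_attached = torch.norm(w, p=1)       # stays in the graph

print(loss_detached.requires_grad)  # False: backward() gets nothing from this term
print(loss_attached.requires_grad)  # True

loss_attached.backward()
print(w.grad.abs().sum() > 0)  # gradients from the L1 term reach the parameter
```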
To Reproduce
Location of the bug:
- `train_scdori_phases` and `compute_eval_loss_scdori` (Phase 1 training)
- `train_model_grn` and `compute_eval_loss_grn` (Phase 2 training)
Code snippet causing the issue:

```python
# The `.data` attribute strips the tensor from the autograd graph
l1_norm_tf = torch.norm(model.topic_tf_decoder.data, p=1)
l2_norm_tf = torch.norm(model.topic_tf_decoder.data, p=2)
l1_norm_peak = torch.norm(model.topic_peak_decoder.data, p=1)
# ... and similarly for all other regularized matrices
# (gene_peak, grn_activator, grn_repressor)
```

Expected behavior
The regularization terms should be part of the computational graph so that their gradients can penalize dense weights during backpropagation, driving the network towards the intended sparse topology.
Proposed Solution
Simply remove the .data attribute when calculating the norms for the loss function. This allows PyTorch to track the operations and backpropagate the gradients correctly.
Corrected code:

```python
l1_norm_tf = torch.norm(model.topic_tf_decoder, p=1)
l2_norm_tf = torch.norm(model.topic_tf_decoder, p=2)
l1_norm_peak = torch.norm(model.topic_peak_decoder, p=1)
# ... apply to all penalty calculations
```

Note: The use of `.data` for the in-place clamping operation later in the training loops (e.g., `model.gene_peak_factor_learnt.data.clamp_(min=0)`) is correct and should be kept, as it is a direct weight modification outside of the gradient calculation. For the loss computation, however, it must be removed.
Impact
Fixing this issue allows the L1 penalties to actively prune the network during training, significantly reducing false-positive edges in the inferred Gene Regulatory Networks (GRNs) and making the model responsive to the l1_penalty_* hyperparameters defined in config.py.