[Bug] L1/L2 regularization is completely bypassed during training due to the use of .data in loss computation #19

@zhaoshuoxp

Description


Describe the bug

While tuning the L1 and L2 penalty hyperparameters to achieve the expected network sparsity (as described in the official documentation), I noticed that even extremely high penalty values (e.g., L1 = 0.5 or 10.0) had no effect whatsoever on the weight distribution or network sparsity.

Upon inspecting the training loops in Phase 1 and Phase 2, I found that the .data attribute is used to access the model weights when computing the regularization norms. In PyTorch, accessing a tensor via .data returns a tensor that shares the same memory but is detached from the computational graph (effectively requires_grad=False for that operation).

As a result, the autograd engine treats the entire loss_norm term as a disconnected constant. When total_loss.backward() is called, no gradients from the L1/L2 penalties ever reach the model parameters, rendering the regularization completely ineffective.
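The detachment is easy to verify in isolation. A minimal sketch (the parameter name here is illustrative, not one of the project's matrices): a norm computed through .data has requires_grad=False, while the same norm computed on the parameter itself backpropagates into it.

```python
import torch

# A single learnable weight matrix standing in for a decoder parameter.
w = torch.nn.Parameter(torch.randn(3, 3))

# Buggy pattern: `.data` detaches the tensor, so the norm is a plain constant.
# Added to a total loss, it contributes nothing to the gradients.
loss_buggy = torch.norm(w.data, p=1)
print(loss_buggy.requires_grad)  # False

# Fixed pattern: the norm stays in the autograd graph.
loss_fixed = torch.norm(w, p=1)
loss_fixed.backward()
print(w.grad.abs().sum() > 0)  # the L1 penalty now produces gradients
```

Note that in the real training loops total_loss.backward() does not raise, because the reconstruction terms still require gradients; the penalty term is silently dropped from the graph, which is why the bug goes unnoticed.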

To Reproduce

Location of the bug:

  1. train_scdori_phases and compute_eval_loss_scdori (Phase 1 training)
  2. train_model_grn and compute_eval_loss_grn (Phase 2 training)

Code snippet causing the issue:

# The `.data` attribute strips the tensor from the autograd graph
l1_norm_tf = torch.norm(model.topic_tf_decoder.data, p=1)
l2_norm_tf = torch.norm(model.topic_tf_decoder.data, p=2)
l1_norm_peak = torch.norm(model.topic_peak_decoder.data, p=1)
# ... and similarly for all other regularized matrices (gene_peak, grn_activator, grn_repressor)

Expected behavior
The regularization terms should be part of the computational graph so that their gradients can penalize dense weights during backpropagation, driving the network towards the intended sparse topology.

Proposed Solution
Remove the .data attribute when computing the norms for the loss function. This lets PyTorch track the operations and backpropagate the gradients correctly.

Corrected code:

l1_norm_tf = torch.norm(model.topic_tf_decoder, p=1)
l2_norm_tf = torch.norm(model.topic_tf_decoder, p=2)
l1_norm_peak = torch.norm(model.topic_peak_decoder, p=1)
# ... apply to all penalty calculations

Note: The use of .data for the in-place clamping later in the training loops (e.g., model.gene_peak_factor_learnt.data.clamp_(min=0)) is correct and should be kept: that is a direct weight modification deliberately performed outside gradient tracking. For the loss computation, however, .data must be removed.
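To make the distinction concrete, here is a minimal sketch of one training step combining both patterns (the model, optimizer, and loss below are illustrative stand-ins, not the project's actual API):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 4)

optimizer.zero_grad()
recon_loss = model(x).pow(2).mean()                # stand-in for the main loss
l1_penalty = 0.5 * torch.norm(model.weight, p=1)   # no .data: stays in the graph
total_loss = recon_loss + l1_penalty
total_loss.backward()                              # penalty gradients reach model.weight
optimizer.step()

# In-place clamping AFTER the optimizer step: here `.data` is appropriate,
# since this is a raw weight edit that autograd should not record.
model.weight.data.clamp_(min=0)
```

With the penalty in the graph, the L1 term adds a sign(w)-shaped component to the weight gradients each step, which is what actually drives weights toward zero.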

Impact
Fixing this issue allows the L1 penalties to actively prune the network during training, significantly reducing false-positive edges in the inferred Gene Regulatory Networks (GRNs) and making the model responsive to the l1_penalty_* hyperparameters defined in config.py.
