While studying your code, I noticed that the pseudo labels generated during training are not saved and reused in a separate self-training stage. Instead, the latter part of training uses only the LAT loss (p2_lwt in the code) + p2_cons + p2_reg. So, as I understand it, your method is trained end to end. The relevant code:
if process <= division:
    if mask is None:
        # pseudo = minMaxNorm(uphw(attn_pos, size=size)).gt(0.5).float()
        pseudo = self.crf(uphw(minMaxNorm(x), size=size), minMaxNorm(uphw(attn_pos.detach(), size=size)), iters=self.cfg.crf_round).gt(0.5).float()
    else:
        loss = {} ## clear
        pseudo = uphw(mask, size=size)
    bce_pr_wrt_ps = F.binary_cross_entropy_with_logits(pred, pseudo, reduction="none").reshape(n, -1).mean(dim=-1)
    loss["p1_bce"] = (bce_pr_wrt_ps * torch.softmax(self.cfg.rew_kai / (bce_pr_wrt_ps + 1.0), dim=-1)).sum()
    loss["p1_cons"] = F.l1_loss(torch.sigmoid(pred[0:n]), torch.sigmoid(pred[n::]))
if process > division:
    loss["p2_lwt"] = self.lwt(torch.sigmoid(pred), minMaxNorm(x), margin=0.5)
    loss["p2_cons"] = F.l1_loss(torch.sigmoid(pred[0:n]), torch.sigmoid(pred[n::]))
    loss["p2_reg"] = 0.5 - torch.abs(torch.sigmoid(pred) - 0.5).mean()
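To check that I am reading the schedule correctly, here is a minimal pure-Python sketch of my understanding. It is not your implementation: active_losses, p2_reg and p1_weights are hypothetical helper names, plain lists of floats stand in for tensors, the mask branch and the CRF and lwt calls are left out, and the kai value is made up because I do not know what cfg.rew_kai is set to.

```python
import math


def sigmoid(z):
    """Logistic function for a single float."""
    return 1.0 / (1.0 + math.exp(-z))


def active_losses(process, division):
    """Loss terms I believe are active at a given training progress.

    Assumption: process and division are comparable scalars, as in the
    quoted snippet; their units are not known to me.
    """
    if process <= division:
        # Phase 1: pseudo labels are rebuilt on the fly each step, never saved.
        return ["p1_bce", "p1_cons"]
    # Phase 2: no pseudo labels are used at all.
    return ["p2_lwt", "p2_cons", "p2_reg"]


def p2_reg(logits):
    """0.5 - mean(|sigmoid(logit) - 0.5|), mirroring the p2_reg line.

    Equals 0.5 when every prediction is 0.5 and approaches 0 as the
    predictions move toward 0 or 1.
    """
    probs = [sigmoid(z) for z in logits]
    return 0.5 - sum(abs(p - 0.5) for p in probs) / len(probs)


def p1_weights(bce_per_sample, kai):
    """softmax(kai / (bce + 1)) over the batch, mirroring the p1_bce line.

    For kai > 0, samples with a larger BCE get a smaller weight.
    """
    scores = [kai / (b + 1.0) for b in bce_per_sample]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    print(active_losses(process=0.3, division=0.5))
    print(active_losses(process=0.8, division=0.5))
    print(p2_reg([0.0, 0.0]), p2_reg([10.0, -10.0]))
    print(p1_weights([0.1, 2.0], kai=1.0))  # kai=1.0 is a made-up value
```

If this matches your intent, then the second phase never sees a pseudo label, which is why I describe the pipeline as end to end.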