From 48663cfdbcd447c495b00306c61ba49a24f1731d Mon Sep 17 00:00:00 2001
From: liushu <1554987494@qq.com>
Date: Tue, 15 Nov 2022 15:07:17 +0800
Subject: [PATCH] Add FGM/PGD adversarial training
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../task/base/_sequence_classification.py    | 39 ++++++++++++++++++++
 ark_nlp/factory/task/base/_task.py           | 13 ++++++++-
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/ark_nlp/factory/task/base/_sequence_classification.py b/ark_nlp/factory/task/base/_sequence_classification.py
index c8e5ede..495d5a6 100644
--- a/ark_nlp/factory/task/base/_sequence_classification.py
+++ b/ark_nlp/factory/task/base/_sequence_classification.py
@@ -105,6 +105,14 @@ def fit(
                 # loss backward
                 loss = self._on_backward(inputs, outputs, logits, loss, **kwargs)
 
+                # FGM / PGD adversarial training
+                if self.use_pgd:
+                    self._pgd_attack(inputs, gradient_accumulation_steps, **kwargs)
+                    torch.cuda.empty_cache()  # PGD uses a lot of GPU memory; empty the cache to avoid OOM
+
+                if self.use_fgm:
+                    self._fgm_attack(inputs, **kwargs)
+
                 if (step + 1) % gradient_accumulation_steps == 0:
 
                     # optimize
@@ -493,3 +501,34 @@ def _on_evaluate_end(
 
         if self.ema_decay:
             self.ema.restore(self.module.parameters())
+
+    def _pgd_attack(self, inputs, gradient_accumulation_steps, **kwargs):
+        self.pgd.backup_grad()
+        # adversarial training: run K perturbation steps
+        for t in range(self.K):
+            self.pgd.attack(is_first_attack=(t == 0))  # perturb the embeddings; back up param.data on the first attack
+            if t != self.K - 1:
+                self.module.zero_grad()
+            else:
+                self.pgd.restore_grad()
+            outputs = self.module(**inputs)
+            logits, loss_adv = self._get_train_loss(inputs, outputs, **kwargs)
+            # average the loss when more than one GPU is used
+            if self.n_gpu > 1:
+                loss_adv = loss_adv.mean()
+            # with gradient accumulation, scale the loss by the number of accumulation steps
+            if gradient_accumulation_steps > 1:
+                loss_adv = loss_adv / gradient_accumulation_steps
+            loss_adv.backward()
+        self.pgd.restore()  # restore the original embedding parameters
+
+    def _fgm_attack(self, inputs, **kwargs):
+        self.fgm.attack()  # perturb the embeddings
+        outputs = self.module(**inputs)
+        # compute the adversarial loss
+        logits, loss_adv = self._get_train_loss(inputs, outputs, **kwargs)
+        # average the loss when more than one GPU is used
+        if self.n_gpu > 1:
+            loss_adv = loss_adv.mean()
+        loss_adv.backward()  # backpropagate, accumulating the adversarial gradients on top of the normal ones
+        self.fgm.restore()  # restore the original embedding parameters
diff --git a/ark_nlp/factory/task/base/_task.py b/ark_nlp/factory/task/base/_task.py
index c5c3a88..5f78834 100644
--- a/ark_nlp/factory/task/base/_task.py
+++ b/ark_nlp/factory/task/base/_task.py
@@ -21,7 +21,7 @@
 from torch.utils.data._utils.collate import default_collate
 from ark_nlp.factory.loss_function import get_loss
 from ark_nlp.factory.utils.ema import EMA
-
+from ark_nlp.factory.utils.attack import PGD, FGM
 
 class Task(object):
     """
@@ -50,6 +50,8 @@ def __init__(
         n_gpu=1,
         device=None,
         cuda_device=0,
+        use_pgd=False,
+        use_fgm=False,
         ema_decay=None,
         **kwargs
     ):
@@ -82,6 +84,15 @@ def __init__(
         self.ema_decay = ema_decay
         if self.ema_decay:
             self.ema = EMA(self.module.parameters(), decay=self.ema_decay)
+
+        self.use_pgd = use_pgd
+        if self.use_pgd:
+            self.K = 3  # number of PGD attack steps
+            self.pgd = PGD(self.module)
+
+        self.use_fgm = use_fgm
+        if self.use_fgm:
+            self.fgm = FGM(self.module)
 
     def _train_collate_fn(self, batch):
         return default_collate(batch)
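
Note: the patch relies on the `FGM` and `PGD` helpers imported from
`ark_nlp.factory.utils.attack`, whose implementation is not part of this
diff. For review context, here is a minimal sketch of what such helpers
typically look like, matching only the interface actually used above
(`attack(is_first_attack=...)`, `restore`, `backup_grad`, `restore_grad`);
the `epsilon`, `alpha`, and `emb_name` parameters and their defaults are
illustrative assumptions, not necessarily the module's real API:

import torch


class FGM(object):
    # Illustrative sketch, not ark_nlp's actual implementation.
    # Fast Gradient Method: one perturbation step on the embedding weights.

    def __init__(self, module):
        self.module = module
        self.backup = {}

    def attack(self, epsilon=1.0, emb_name='word_embeddings'):
        # Move each embedding parameter along its normalized gradient.
        for name, param in self.module.named_parameters():
            if param.requires_grad and param.grad is not None and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(epsilon * param.grad / norm)

    def restore(self, emb_name='word_embeddings'):
        # Put the original embedding weights back after the adversarial backward pass.
        for name, param in self.module.named_parameters():
            if param.requires_grad and emb_name in name and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}


class PGD(object):
    # Illustrative sketch, not ark_nlp's actual implementation.
    # Projected Gradient Descent: K small steps, each projected back onto
    # an epsilon-ball around the original embedding weights.

    def __init__(self, module):
        self.module = module
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, epsilon=1.0, alpha=0.3, emb_name='word_embeddings', is_first_attack=False):
        for name, param in self.module.named_parameters():
            if param.requires_grad and param.grad is not None and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(alpha * param.grad / norm)
                    param.data = self.project(name, param.data, epsilon)

    def restore(self, emb_name='word_embeddings'):
        for name, param in self.module.named_parameters():
            if param.requires_grad and emb_name in name and name in self.emb_backup:
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        # Clip the accumulated perturbation so it stays inside the epsilon-ball.
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        # Save the clean gradients before the K attack iterations overwrite them.
        for name, param in self.module.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        # Restore the clean gradients so the final adversarial backward accumulates onto them.
        for name, param in self.module.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = self.grad_backup[name]

With helpers like these, use_fgm=True adds one extra forward/backward pass
per training step and use_pgd=True adds K of them, which is why the PGD
branch in fit() also calls torch.cuda.empty_cache().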