39 changes: 39 additions & 0 deletions ark_nlp/factory/task/base/_sequence_classification.py
@@ -105,6 +105,14 @@ def fit(
# loss backward
loss = self._on_backward(inputs, outputs, logits, loss, **kwargs)

# adversarial training: FGM / PGD
if self.use_pgd:
    self._pgd_attack(inputs, gradient_accumulation_steps, **kwargs)
    torch.cuda.empty_cache()  # PGD uses a lot of GPU memory; empty the cache to avoid OOM

if self.use_fgm:
    self._fgm_attack(inputs, **kwargs)

if (step + 1) % gradient_accumulation_steps == 0:

# optimize
@@ -493,3 +501,34 @@ def _on_evaluate_end(

if self.ema_decay:
self.ema.restore(self.module.parameters())

def _pgd_attack(self, inputs, gradient_accumulation_steps, **kwargs):
    self.pgd.backup_grad()
    # adversarial training loop
    for t in range(self.K):
        # add an adversarial perturbation to the embeddings; back up param.data on the first attack
        self.pgd.attack(is_first_attack=(t == 0))
        if t != self.K - 1:
            self.module.zero_grad()
        else:
            self.pgd.restore_grad()
        outputs = self.module(**inputs)
        logits, loss_adv = self._get_train_loss(inputs, outputs, **kwargs)
        # average the loss across GPUs when more than one is used
        if self.n_gpu > 1:
            loss_adv = loss_adv.mean()
        # when using gradient accumulation, scale by the number of accumulation steps
        if gradient_accumulation_steps > 1:
            loss_adv = loss_adv / gradient_accumulation_steps
        loss_adv.backward()
    self.pgd.restore()  # restore the embedding parameters

def _fgm_attack(self, inputs, **kwargs):
    self.fgm.attack()  # add an adversarial perturbation to the embeddings
    outputs = self.module(**inputs)
    # compute the adversarial loss
    logits, loss_adv = self._get_train_loss(inputs, outputs, **kwargs)
    # average the loss across GPUs when more than one is used
    if self.n_gpu > 1:
        loss_adv = loss_adv.mean()
    loss_adv.backward()  # backpropagate, accumulating the adversarial gradients on top of the normal ones
    self.fgm.restore()  # restore the embedding parameters
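
The PGD and FGM helpers themselves are imported from ark_nlp.factory.utils.attack and do not appear in this diff. For context, below is a minimal sketch of what such helpers conventionally look like, following the widely used reference implementations (FGM after Miyato et al., PGD after Madry et al.). The defaults epsilon=1.0 and alpha=0.3 and the 'word_embeddings' name filter are assumptions that must match the backbone's actual embedding parameter names; the real ark_nlp classes may differ in detail.

import torch

class FGM(object):
    # Fast Gradient Method: a single perturbation step on the embedding
    # weights along the normalized gradient direction.
    def __init__(self, module):
        self.module = module
        self.backup = {}

    def attack(self, epsilon=1.0, emb_name='word_embeddings'):
        for name, param in self.module.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()  # back up clean embeddings
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(epsilon * param.grad / norm)

    def restore(self, emb_name='word_embeddings'):
        for name, param in self.module.named_parameters():
            if param.requires_grad and emb_name in name:
                param.data = self.backup[name]  # restore clean embeddings
        self.backup = {}

class PGD(object):
    # Projected Gradient Descent: K small attack steps, each projected back
    # into an epsilon-ball around the original embeddings.
    def __init__(self, module):
        self.module = module
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, epsilon=1.0, alpha=0.3, emb_name='word_embeddings',
               is_first_attack=False):
        for name, param in self.module.named_parameters():
            if param.requires_grad and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(alpha * param.grad / norm)
                    param.data = self.project(name, param.data, epsilon)

    def project(self, param_name, param_data, epsilon):
        # clip the accumulated perturbation back to the epsilon-ball
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def restore(self, emb_name='word_embeddings'):
        for name, param in self.module.named_parameters():
            if param.requires_grad and emb_name in name:
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def backup_grad(self):
        for name, param in self.module.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        for name, param in self.module.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = self.grad_backup[name]

The method names (backup_grad, attack, restore_grad, restore) mirror the calls made in _pgd_attack and _fgm_attack above.
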
11 changes: 10 additions & 1 deletion ark_nlp/factory/task/base/_task.py
@@ -21,7 +21,7 @@
from torch.utils.data._utils.collate import default_collate
from ark_nlp.factory.loss_function import get_loss
from ark_nlp.factory.utils.ema import EMA

from ark_nlp.factory.utils.attack import PGD, FGM

class Task(object):
"""
@@ -50,6 +50,8 @@ def __init__(
n_gpu=1,
device=None,
cuda_device=0,
use_pgd=False,
use_fgm=False,
ema_decay=None,
**kwargs
):
@@ -82,6 +84,13 @@ def __init__(
self.ema_decay = ema_decay
if self.ema_decay:
self.ema = EMA(self.module.parameters(), decay=self.ema_decay)

if use_pgd:
    self.K = 3  # number of PGD attack iterations per training step
    self.pgd = PGD(self.module)

if use_fgm:
    self.fgm = FGM(self.module)

def _train_collate_fn(self, batch):
return default_collate(batch)
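
With these hooks in place, enabling adversarial training is a constructor flag. A hypothetical usage sketch follows; the task class, module, optimizer, and dataset names are placeholders, not part of this PR — only use_fgm / use_pgd are introduced here:

# Hypothetical sketch: `my_module`, `my_optimizer`, and `train_dataset` are placeholders.
from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask  # assumed class name

task = SequenceClassificationTask(
    module=my_module,        # placeholder model
    optimizer=my_optimizer,  # placeholder optimizer
    loss_function='ce',
    use_fgm=True,            # one-step FGM adversarial training
    # use_pgd=True,          # or multi-step PGD (K is fixed to 3 in __init__)
)
task.fit(train_dataset)      # placeholder dataset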