-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathattack.py
More file actions
420 lines (340 loc) · 18 KB
/
attack.py
File metadata and controls
420 lines (340 loc) · 18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
from model import Model, ForwardMode
from transform import unnormalize_inplace, normalize_inplace
from loss import ce_loss, l2_loss, dlr_loss, dlr_loss_targeted, ce_loss_targeted
from perf.profiling import ProfileModelMemory, ProfileModelGradient
from contextlib import contextmanager
# =========================== Base ===========================
class Attack(ABC):
    """Abstract interface shared by every adversarial attack in this module."""

    @abstractmethod
    def perturb(self, x: torch.Tensor, y: torch.Tensor = None, emb_orig=None) -> torch.Tensor:
        """Produce an adversarial version of ``x``.

        Args:
            x: Inputs to perturb.
            y: Optional labels, used by the label-based losses.
            emb_orig: Optional reference embeddings, used by the embedding loss.

        Returns:
            The perturbed inputs. The base implementation does nothing and
            returns ``None``; concrete attacks must override it.
        """
class AttackModel(Model):
    """Wrapper that standardizes its input before delegating to a ``Model``.

    ``forward`` applies ``(x - mean) / std``, passes the result through the
    wrapped model's ``wrap_tensor`` and then calls the wrapped model.
    """

    def __init__(self, model: Model, mean=0, std=1):
        """Store the wrapped model and the normalization statistics.

        Args:
            model: Model that receives the standardized input.
            mean: Value subtracted from the input (default ``0``).
            std: Value the centered input is divided by (default ``1``).
        """
        super().__init__()
        self.model = model
        self.mean = mean
        self.std = std

    def forward(self, x, mode=ForwardMode.EMBEDDINGS):
        """Standardize ``x`` and run the wrapped model in the given ``mode``."""
        standardized = (x - self.mean) / self.std
        model_input = self.model.wrap_tensor(standardized)
        return self.model(model_input, mode)
# =========================== PGD ===========================
class PGDAttack(Attack):
    """Projected-gradient attack: ascend a loss, project back, clamp, repeat.

    Each step moves the current adversarial input along the gradient of the
    selected loss (sign step for ``linf``, norm-scaled step for ``l2``/``l1``),
    projects the result back into the ``epsilon``-ball around the clean input
    and clamps it to ``[clamp_min, clamp_max]``.

    ``y_target`` is initialised to ``None``. The targeted losses read it, so
    it has to be assigned on the instance before a targeted attack is run.
    """

    def __init__(self, logger, model: AttackModel,
                 epsilon=8 / 255, alpha=2 / 255, steps=10,
                 norm="linf", random_start=True,
                 clamp_min=0.0, clamp_max=1.0,
                 loss_type="ce"):
        """Configure the attack.

        Args:
            logger: Object exposing ``info`` and ``debug``.
            model: Model that is attacked.
            epsilon: Radius of the perturbation ball.
            alpha: Step size of a single iteration.
            steps: Number of iterations.
            norm: One of ``"linf"``, ``"l2"``, ``"l1"``.
            random_start: Start from a random point around ``x`` when true.
            clamp_min: Lower bound applied to the adversarial input.
            clamp_max: Upper bound applied to the adversarial input.
            loss_type: One of ``"ce"``, ``"l2"``, ``"dlr"``,
                ``"dlr-targeted"``, ``"ce-targeted"``.
        """
        assert norm in ("linf", "l2", "l1")
        assert loss_type in ("ce", "l2", "dlr", "dlr-targeted", "ce-targeted")
        self.logger = logger
        self.model = model
        self.epsilon = epsilon
        self.alpha = alpha
        self.steps = steps
        self.norm = norm
        self.random_start = random_start
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.loss_type = loss_type
        self.y_target = None

    def perturb(self, x, y=None, emb_orig=None):
        """Run the attack and return the adversarial inputs.

        Args:
            x: Clean inputs.
            y: Labels; required for every loss except ``"l2"``.
            emb_orig: Reference embeddings; required for the ``"l2"`` loss.

        Returns:
            The adversarial inputs after ``steps`` iterations.

        Raises:
            ValueError: If the labels or embeddings needed by the configured
                loss are missing.
        """
        needs_labels = self.loss_type in {"ce", "dlr", "ce-targeted", "dlr-targeted"}
        needs_embeddings = self.loss_type == "l2"

        # Model parameters never need gradients here; only the input does.
        with no_model_grads(self.model):
            x_adv = x.detach().clone().requires_grad_(True)

            # --- validate arguments and log the metric on the clean input
            if needs_labels:
                if y is None:
                    raise ValueError("This loss type requires labels.")
                with torch.no_grad():
                    clean_acc = self._acc_with_x(x, y)
                self.logger.info(f"[PGD] Initial accuracy: {clean_acc.item() * 100:.2f}%")
            elif needs_embeddings:
                if emb_orig is None:
                    raise ValueError("L2 loss requires original embeddings.")
                with torch.no_grad():
                    clean_sim = self._cos_sim_with_x(x, emb_orig)
                self.logger.info(f"[PGD] Initial cosine similarity: {clean_sim.item():.4f}")

            # --- optional random initialisation inside the epsilon-ball
            if self.random_start:
                self.logger.debug(f"[PGD] Applying random start for norm={self.norm}")
                initialisers = {
                    "linf": random_start_linf,
                    "l2": random_start_l2,
                    "l1": random_start_l1,
                }
                init_fn = initialisers[self.norm]
                x_adv = init_fn(x, self.epsilon, self.clamp_min, self.clamp_max)

            for step in range(self.steps):
                step_no = step + 1
                # Fresh leaf tensor so the gradient is taken w.r.t. this iterate only.
                x_adv = x_adv.detach().clone().requires_grad_(True)
                model_input = x_adv.clone()

                if self.loss_type == "ce":
                    with ProfileModelMemory(self.model, self.logger):
                        logits, _ = self.model(model_input, mode=ForwardMode.LOGITS)
                        loss = ce_loss(logits, y)
                elif self.loss_type == "ce-targeted":
                    logits, _ = self.model(model_input, mode=ForwardMode.LOGITS)
                    loss = ce_loss_targeted(logits, self.y_target)
                elif self.loss_type == "dlr":
                    logits, _ = self.model(model_input, mode=ForwardMode.LOGITS)
                    loss = dlr_loss(logits, y)
                elif self.loss_type == "dlr-targeted":
                    logits, _ = self.model(model_input, mode=ForwardMode.LOGITS)
                    loss = dlr_loss_targeted(logits, y, self.y_target)
                elif self.loss_type == "l2":
                    adv_emb = self.model(model_input, mode=ForwardMode.EMBEDDINGS)
                    loss = l2_loss(adv_emb, emb_orig)

                grad = torch.autograd.grad(loss, x_adv)[0]
                grad_norm = grad.norm().item()
                del loss

                x_adv = self._ascent_step(x_adv.detach(), grad, x)

                # === Log step-wise metrics ===
                self.logger.debug(f"[PGD][Step {step_no}/{self.steps}] Grad norm: {grad_norm:.4f}")
                if needs_embeddings:
                    with torch.no_grad():
                        step_sim = self._cos_sim_with_x(x_adv, emb_orig)
                    self.logger.debug(f"[PGD][Step {step_no}] Cosine similarity: {step_sim.item():.4f}")
                elif needs_labels:
                    with torch.no_grad():
                        step_acc = self._acc_with_x(x_adv, y)
                    self.logger.debug(f"[PGD][Step {step_no}] Accuracy: {step_acc.item() * 100:.2f}%")
                self.logger.debug(f"[PGD][Step {step_no}] Perturbed sample range: "
                                  f"[{x_adv.min().item():.4f}, {x_adv.max().item():.4f}]")
                del grad

            # === Final metric ===
            if needs_embeddings:
                with torch.no_grad():
                    last_sim = self._cos_sim_with_x(x_adv, emb_orig)
                self.logger.info(f"[PGD] Final cosine similarity: {last_sim.item():.4f}")
            elif needs_labels:
                with torch.no_grad():
                    last_acc = self._acc_with_x(x_adv, y)
                self.logger.info(f"[PGD] Final accuracy: {last_acc.item() * 100:.2f}%")

            self.logger.info("[PGD] Attack completed.")
            return x_adv

    def _ascent_step(self, x_adv, grad, x):
        """Take one gradient step, project around ``x`` and clamp the result."""
        if self.norm == "linf":
            x_adv = project_linf(x_adv + self.alpha * grad.sign(), x, self.epsilon)
        elif self.norm == "l2":
            direction = self.alpha * grad / (L2_norm(grad, keepdim=True) + 1e-12)
            x_adv = project_l2(x_adv + direction, x, self.epsilon)
        elif self.norm == "l1":
            direction = self.alpha * grad / (L1_norm(grad, keepdim=True) + 1e-12)
            x_adv = x + project_l1(x_adv + direction - x, self.epsilon)
        return x_adv.clamp(self.clamp_min, self.clamp_max)

    def _acc_with_x(self, x, y):
        """Return the mean accuracy of the model's logits for ``x`` against ``y``."""
        logits, _ = self.model(x, mode=ForwardMode.LOGITS)
        return self._acc_with_logits(logits, y)

    def _acc_with_logits(self, logits, y):
        """Return the fraction of rows whose arg-max over dim 1 equals ``y``."""
        hits = logits.argmax(dim=1) == y
        return hits.float().mean()

    def _cos_sim_with_x(self, x, original_emb):
        """Return the mean cosine similarity between ``x``'s embeddings and ``original_emb``."""
        x_emb = self.model(x, mode=ForwardMode.EMBEDDINGS)
        return self._cos_sim_with_emb(x_emb, original_emb)

    def _cos_sim_with_emb(self, emb, original_emb):
        """Return the mean cosine similarity (along dim 1) of two embedding batches."""
        per_sample = F.cosine_similarity(emb, original_emb, dim=1)
        return per_sample.mean()
# =========================== APGD ===========================
class APGDAttack(Attack):
    """Restart-based iterative attack with a step size equal to ``eps``.

    For every restart a random perturbation of size ``eps`` is drawn, then
    ``n_iter`` gradient steps of size ``eps`` are taken, each followed by a
    projection into the ``eps``-ball around the clean input and a clamp to
    ``[0, 1]``. The loss is averaged over ``eot_iter`` forward passes.

    ``y_target`` is initialised to ``None``. The targeted losses read it, so
    it has to be assigned on the instance before a targeted attack is run.

    NOTE(review): ``best_loss`` is stored but never read, and the returned
    tensor is simply the final iterate of the last restart - no selection by
    loss takes place.
    """

    def __init__(self, logger, model: AttackModel,
                 n_iter=100, norm="linf", n_restarts=1,
                 eps=8 / 255, loss_type="ce",
                 eot_iter=1, best_loss=True,
                 device=None):
        """Configure the attack.

        Args:
            logger: Object exposing ``info`` and ``debug``.
            model: Model that is attacked.
            n_iter: Iterations per restart.
            norm: One of ``"linf"``, ``"l2"``, ``"l1"``.
            n_restarts: Number of random restarts.
            eps: Perturbation radius, also used as step size.
            loss_type: One of ``"ce"``, ``"l2"``, ``"dlr"``,
                ``"dlr-targeted"``, ``"ce-targeted"``.
            eot_iter: Number of forward passes the loss is averaged over.
            best_loss: Stored on the instance; currently unused.
            device: Device the inputs are moved to (passed to ``Tensor.to``).
        """
        assert norm in ("linf", "l2", "l1")
        assert loss_type in ("ce", "l2", "dlr", "dlr-targeted", "ce-targeted")
        self.logger = logger
        self.model = model
        self.n_iter = n_iter
        self.n_restarts = n_restarts
        self.eps = eps
        self.norm = norm
        self.loss_type = loss_type
        self.eot_iter = eot_iter
        self.best_loss = best_loss
        self.device = device
        self.y_target = None

    def perturb(self, x, y=None, emb_orig=None):
        """Run the attack and return the adversarial inputs.

        Args:
            x: Clean inputs; a detached copy is moved to ``self.device``.
            y: Labels; required for every loss except ``"l2"``.
            emb_orig: Reference embeddings used by the ``"l2"`` loss.

        Returns:
            The final iterate of the last restart (or a copy of ``x`` when
            ``n_restarts`` is zero).

        Raises:
            ValueError: If a label-based loss is configured and ``y`` is ``None``.
        """
        label_losses = {"ce", "dlr", "ce-targeted", "dlr-targeted"}

        # Model parameters never need gradients here; only the input does.
        with no_model_grads(self.model):
            if y is None and self.loss_type in {"ce", "dlr", "dlr-targeted", "ce-targeted"}:
                raise ValueError("This attack requires labels.")

            x = x.detach().clone().to(self.device)
            y = y.to(self.device) if y is not None else None
            emb_orig = emb_orig.to(self.device) if emb_orig is not None else None

            best_adv = x.clone()
            for restart in range(self.n_restarts):
                self.logger.info(f"[APGD] Restart {restart + 1}/{self.n_restarts}")

                # Random start: uniform noise for linf, Gaussian otherwise,
                # rescaled by _normalize so the perturbation has size eps.
                delta = (2 * torch.rand_like(x) - 1) if self.norm == "linf" else torch.randn_like(x)
                delta = self._normalize(delta) * self.eps
                x_adv = (x + delta).clamp(0, 1)

                for i in range(self.n_iter):
                    # Fresh leaf tensor so the gradient is taken w.r.t. this iterate only.
                    x_adv = x_adv.detach().clone().requires_grad_(True)
                    loss = 0.0
                    with ProfileModelGradient(self.model, self.logger):
                        for _ in range(self.eot_iter):
                            if self.loss_type == "ce":
                                logits, _ = self.model(x_adv, ForwardMode.LOGITS)
                                loss += ce_loss(logits, y)
                            elif self.loss_type == "ce-targeted":
                                logits, _ = self.model(x_adv, ForwardMode.LOGITS)
                                loss += ce_loss_targeted(logits, self.y_target)
                            elif self.loss_type == "dlr":
                                logits, _ = self.model(x_adv, ForwardMode.LOGITS)
                                loss += dlr_loss(logits, y)
                            elif self.loss_type == "dlr-targeted":
                                logits, _ = self.model(x_adv, ForwardMode.LOGITS)
                                loss += dlr_loss_targeted(logits, y, self.y_target)
                            elif self.loss_type == "l2":
                                x_emb = self.model(x_adv, ForwardMode.EMBEDDINGS)
                                loss += l2_loss(x_emb, emb_orig)
                        loss /= self.eot_iter
                        grad = torch.autograd.grad(loss, x_adv)[0]
                    grad_norm = grad.norm().item()

                    x_adv = x_adv.detach()
                    if self.norm == "linf":
                        x_adv = project_linf(x_adv + self.eps * grad.sign(), x, self.eps)
                    elif self.norm == "l2":
                        step = grad / (L2_norm(grad, keepdim=True) + 1e-12)
                        x_adv = project_l2(x_adv + self.eps * step, x, self.eps)
                    elif self.norm == "l1":
                        step = grad / (L1_norm(grad, keepdim=True) + 1e-12)
                        x_adv = x + project_l1(x_adv + self.eps * step - x, self.eps)
                    x_adv = x_adv.clamp(0, 1)

                    # === Log step-wise metrics ===
                    self.logger.debug(f"[APGD][Restart {restart+1}][Iter {i+1}/{self.n_iter}] Grad norm: {grad_norm:.4f}")
                    if self.loss_type == "l2":
                        with torch.no_grad():
                            cos_sim_step = self._cos_sim_with_x(x_adv, emb_orig)
                        self.logger.debug(f"[APGD][Restart {restart+1}][Iter {i+1}] Cosine similarity: {cos_sim_step.item():.4f}")
                    elif self.loss_type in label_losses:
                        with torch.no_grad():
                            acc_step = self._acc_with_x(x_adv, y)
                        self.logger.debug(f"[APGD][Restart {restart+1}][Iter {i+1}] Accuracy: {acc_step.item() * 100:.2f}%")
                    self.logger.debug(f"[APGD][Restart {restart+1}][Iter {i+1}] Perturbed sample range: "
                                      f"[{x_adv.min().item():.4f}, {x_adv.max().item():.4f}]")
                    del grad

                # The last iterate of this restart replaces any earlier result.
                best_adv = x_adv

            # === Final metric ===
            if self.loss_type == "l2":
                with torch.no_grad():
                    final_cos_sim = self._cos_sim_with_x(best_adv, emb_orig)
                self.logger.info(f"[APGD] Final cosine similarity: {final_cos_sim.item():.4f}")
            elif self.loss_type in label_losses:
                with torch.no_grad():
                    final_acc = self._acc_with_x(best_adv, y)
                self.logger.info(f"[APGD] Final accuracy: {final_acc.item() * 100:.2f}%")

            self.logger.info("[APGD] Attack completed.")
            return best_adv

    def _normalize(self, x):
        """Scale each sample of ``x`` to unit size under ``self.norm``.

        ``linf`` returns the element-wise sign. ``l2`` and ``l1`` divide every
        sample (first dimension) by its own norm taken over all remaining
        dimensions. The per-sample norm is reshaped to ``[B, 1, ..., 1]`` so
        that it broadcasts along the batch axis for inputs of any rank; a
        ``[B, 1]`` divisor would instead align with the trailing dimensions.
        """
        if self.norm == "linf":
            return x.sign()

        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = x.reshape(x.size(0), -1)
        if self.norm == "l2":
            per_sample = flat.norm(p=2, dim=1)
        elif self.norm == "l1":
            per_sample = flat.abs().sum(dim=1)
        else:
            # Unreachable while the constructor assertion holds.
            return None
        per_sample = per_sample.view(-1, *[1] * (x.dim() - 1))
        return x / (per_sample + 1e-12)

    def _acc_with_x(self, x, y):
        """Return the mean accuracy of the model's logits for ``x`` against ``y``."""
        logits, _ = self.model(x, mode=ForwardMode.LOGITS)
        return (logits.argmax(dim=1) == y).float().mean()

    def _cos_sim_with_x(self, x, emb_orig):
        """Return the mean cosine similarity between ``x``'s embeddings and ``emb_orig``."""
        x_emb = self.model(x, mode=ForwardMode.EMBEDDINGS)
        return F.cosine_similarity(x_emb, emb_orig, dim=1).mean()
# =========================== Two-Stage Attack ===========================
def two_stage_attack(logger, model, inputs, labels, attack_stage1, attack_stage2, mean, std):
    """Attack every sample with stage 1, then re-attack the survivors with stage 2.

    Stage 1 runs on an unnormalized copy of ``inputs`` under CUDA bfloat16
    autocast. Samples the model still classifies correctly afterwards are
    attacked again, starting from the clean inputs, with ``attack_stage2``.

    Args:
        logger: Object exposing ``info``.
        model: Model used to evaluate the stage-1 result.
        inputs: Normalized clean inputs.
        labels: Labels for ``inputs``.
        attack_stage1: Attack applied to the whole batch.
        attack_stage2: Attack applied to samples stage 1 did not fool.
        mean: Normalization mean handed to the (un)normalize helpers.
        std: Normalization std handed to the (un)normalize helpers.

    Returns:
        Normalized adversarial inputs; stage-2 results replace the stage-1
        results at the re-attacked positions.
    """
    logger.info("Running two-stage attack...")

    # --- stage 1 on an unnormalized copy of the whole batch
    stage1_inputs = inputs.detach().clone()
    unnormalize_inplace(stage1_inputs, mean, std)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        adv_stage1 = attack_stage1.perturb(stage1_inputs, labels)
    normalize_inplace(adv_stage1, mean, std)

    # --- find the samples that are still classified correctly
    with torch.no_grad():
        logits_stage1, _ = model(model.wrap_tensor(adv_stage1), ForwardMode.LOGITS)
    still_correct = logits_stage1.argmax(dim=1) == labels
    keep_idx = still_correct.nonzero(as_tuple=True)[0]

    adv_final = adv_stage1.detach()
    if len(keep_idx) == 0:
        return adv_final

    # --- stage 2 restarts from the clean inputs of the surviving samples
    logger.info(f"Stage1 left {len(keep_idx)}/{inputs.size(0)} samples correct. Applying Stage2...")
    with torch.no_grad():
        stage2_inputs = inputs[keep_idx].detach().clone()
        unnormalize_inplace(stage2_inputs, mean, std)
    adv_stage2 = attack_stage2.perturb(stage2_inputs, labels[keep_idx])
    with torch.no_grad():
        normalize_inplace(adv_stage2, mean, std)
        adv_final[keep_idx] = adv_stage2
    return adv_final
def two_stage_attack_l2(logger, model, inputs, emb_orig, attack_stage1, attack_stage2, mean, std, cosine_threshold=0.2):
    """Two-stage embedding attack that refines samples still similar to the original.

    Stage 1 runs on an unnormalized copy of ``inputs`` under CUDA bfloat16
    autocast. Samples whose adversarial embedding keeps a cosine similarity of
    at least ``cosine_threshold`` to ``emb_orig`` are attacked again, starting
    from the clean inputs, with ``attack_stage2``.

    Args:
        logger: Object exposing ``info``.
        model: Model used to embed the stage-1 result.
        inputs: Normalized clean inputs.
        emb_orig: Reference embeddings of the clean inputs.
        attack_stage1: Attack applied to the whole batch.
        attack_stage2: Attack applied to the samples that need refinement.
        mean: Normalization mean handed to the (un)normalize helpers.
        std: Normalization std handed to the (un)normalize helpers.
        cosine_threshold: Similarity at or above which a sample is refined.

    Returns:
        Normalized adversarial inputs; stage-2 results replace the stage-1
        results at the refined positions.
    """
    logger.info("Running two-stage attack (L∞, L2-loss + cosine filtering)...")

    # --- stage 1 on an unnormalized copy of the whole batch (bfloat16 autocast)
    stage1_inputs = inputs.detach().clone()
    unnormalize_inplace(stage1_inputs, mean, std)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        adv_stage1 = attack_stage1.perturb(stage1_inputs, None, emb_orig)
    normalize_inplace(adv_stage1, mean, std)

    # --- measure how far the stage-1 embeddings moved
    with torch.no_grad():
        emb_stage1 = model(adv_stage1, mode=ForwardMode.EMBEDDINGS)
    # One value per sample along dim 0.
    cosine_sim = F.cosine_similarity(emb_stage1, emb_orig, dim=1)
    l2_dists = torch.norm(emb_stage1 - emb_orig, dim=1)
    logger.info(f"Stage1 Cosine similarity | min: {cosine_sim.min():.4f}, "
                f"mean: {cosine_sim.mean():.4f}, max: {cosine_sim.max():.4f}")
    logger.info(f"Stage1 L2 distances | min: {l2_dists.min():.4f}, "
                f"mean: {l2_dists.mean():.4f}, max: {l2_dists.max():.4f}")

    # Samples at or above the threshold are still too close to the original.
    too_similar = cosine_sim >= cosine_threshold
    keep_idx = too_similar.nonzero(as_tuple=True)[0]
    adv_final = adv_stage1.detach().clone()

    if len(keep_idx) == 0:
        logger.info("No samples required Stage 2 refinement.")
        return adv_final

    # --- stage 2 restarts from the clean inputs of the filtered samples
    logger.info(f"Refining {len(keep_idx)} samples in Stage 2 (cosine ≥ {cosine_threshold})")
    with torch.no_grad():
        stage2_inputs = inputs[keep_idx].detach().clone()
        unnormalize_inplace(stage2_inputs, mean, std)
    adv_stage2 = attack_stage2.perturb(stage2_inputs, None, emb_orig[keep_idx])
    with torch.no_grad():
        normalize_inplace(adv_stage2, mean, std)
        adv_final[keep_idx] = adv_stage2
    return adv_final
# =========================== Norm Helpers ===========================
def L2_norm(x, keepdim=False):
    """Return the per-sample L2 norm of ``x`` over all non-batch dimensions.

    Args:
        x: Tensor whose first dimension indexes samples.
        keepdim: When true, the result has shape ``[B, 1, ..., 1]`` (same rank
            as ``x``) so it broadcasts against ``x``; otherwise shape ``[B]``.

    Returns:
        Tensor with one norm per sample.
    """
    # reshape instead of view: view raises on non-contiguous inputs
    # (e.g. permuted or channels_last tensors), reshape copies when needed.
    norm = x.reshape(x.shape[0], -1).norm(p=2, dim=1)
    return norm.view(-1, *[1] * (x.dim() - 1)) if keepdim else norm
def L1_norm(x, keepdim=False):
    """Return the per-sample L1 norm of ``x`` over all non-batch dimensions.

    Args:
        x: Tensor whose first dimension indexes samples.
        keepdim: When true, the result has shape ``[B, 1, ..., 1]`` (same rank
            as ``x``) so it broadcasts against ``x``; otherwise shape ``[B]``.

    Returns:
        Tensor with one norm per sample.
    """
    # reshape instead of view: view raises on non-contiguous inputs
    # (e.g. permuted or channels_last tensors), reshape copies when needed.
    norm = x.reshape(x.shape[0], -1).abs().sum(dim=1)
    return norm.view(-1, *[1] * (x.dim() - 1)) if keepdim else norm
def project_linf(x_adv, x_orig, eps):
    """Clip ``x_adv`` element-wise into ``[x_orig - eps, x_orig + eps]``."""
    upper = x_orig + eps
    lower = x_orig - eps
    capped = torch.min(x_adv, upper)
    return torch.max(capped, lower)
def project_l2(x_adv, x_orig, eps):
    """Project ``x_adv`` onto the L2 ball of radius ``eps`` around ``x_orig``.

    The perturbation of every sample (first dimension) is shrunk so that its
    L2 norm does not exceed ``eps``; perturbations already inside the ball are
    left unchanged.

    Args:
        x_adv: Perturbed inputs.
        x_orig: Clean inputs, same shape as ``x_adv``.
        eps: Radius of the ball.

    Returns:
        ``x_orig`` plus the rescaled perturbation.
    """
    delta = x_adv - x_orig
    # reshape instead of view so non-contiguous perturbations are accepted.
    norm = delta.reshape(delta.size(0), -1).norm(p=2, dim=1, keepdim=True)
    # Never scale up: the factor is capped at 1.
    factor = torch.clamp(eps / (norm + 1e-12), max=1.0)
    delta = delta * factor.view(-1, *[1] * (delta.ndim - 1))
    return x_orig + delta
def project_l1(delta, eps):
    """Rescale ``delta`` so every sample's L1 norm is at most ``eps``.

    Each sample (first dimension) is multiplied by ``min(1, eps / ||delta||_1)``;
    samples already within the budget are returned unchanged.

    Args:
        delta: Perturbation tensor.
        eps: L1 budget per sample.

    Returns:
        Tensor with the shape of ``delta`` holding the rescaled perturbation.
    """
    # reshape instead of view so non-contiguous perturbations are accepted.
    flat = delta.reshape(delta.size(0), -1)
    norm = flat.abs().sum(dim=1, keepdim=True)
    # Never scale up: the factor is capped at 1.
    factor = (eps / (norm + 1e-12)).clamp(max=1.0)
    return (flat * factor).reshape(delta.shape)
def random_start_linf(x, eps, clamp_min, clamp_max):
    """Return ``x`` plus uniform noise in ``[-eps, eps)``, clamped to the given range."""
    unit_noise = torch.rand_like(x)
    delta = (2.0 * unit_noise - 1.0) * eps
    perturbed = x + delta
    return perturbed.clamp(clamp_min, clamp_max)
def random_start_l2(x, eps, clamp_min, clamp_max):
    """Return ``x`` plus Gaussian noise rescaled to L2 norm ``eps`` per sample.

    Args:
        x: Clean inputs; the first dimension indexes samples.
        eps: L2 norm given to every sample's perturbation before clamping.
        clamp_min: Lower bound of the returned tensor.
        clamp_max: Upper bound of the returned tensor.

    Returns:
        The perturbed inputs clamped to ``[clamp_min, clamp_max]``.
    """
    delta = torch.randn_like(x)
    # randn_like keeps the memory layout of x, so delta may be non-contiguous;
    # reshape handles that case where view would raise.
    norm = delta.reshape(delta.size(0), -1).norm(p=2, dim=1, keepdim=True)
    delta = delta * (eps / (norm + 1e-12)).view(-1, *[1] * (x.dim() - 1))
    return (x + delta).clamp(clamp_min, clamp_max)
def random_start_l1(x, eps, clamp_min, clamp_max):
    """Return ``x`` plus Gaussian noise passed through ``project_l1``, clamped to the given range."""
    noise = project_l1(torch.randn_like(x), eps)
    perturbed = x + noise
    return perturbed.clamp(clamp_min, clamp_max)
@contextmanager
def no_model_grads(model):
    """Temporarily switch off ``requires_grad`` on every parameter of ``model``.

    The previous flag of each named parameter is recorded on entry and written
    back on exit, including when the body raises.
    """
    saved_flags = {}
    for name, param in model.named_parameters():
        saved_flags[name] = param.requires_grad
    for param in model.parameters():
        param.requires_grad = False
    try:
        yield
    finally:
        for name, param in model.named_parameters():
            param.requires_grad = saved_flags[name]