-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
rbcd_attack.py
763 lines (644 loc) · 32.5 KB
/
rbcd_attack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
from collections import defaultdict
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm
from torch_geometric.utils import coalesce, to_undirected
# (predictions, labels, ids/mask) -> Tensor with one element
LOSS_TYPE = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]
class PRBCDAttack(torch.nn.Module):
r"""The Projected Randomized Block Coordinate Descent (PRBCD) adversarial
attack from the `Robustness of Graph Neural Networks at Scale
<https://www.cs.cit.tum.de/daml/robustness-of-gnns-at-scale>`_ paper.
This attack uses an efficient gradient based approach that (during the
attack) relaxes the discrete entries in the adjacency matrix
:math:`\{0, 1\}` to :math:`[0, 1]` and solely perturbs the adjacency matrix
(no feature perturbations). Thus, this attack supports all models that can
handle weighted graphs that are differentiable w.r.t. these edge weights,
*e.g.*, :class:`~torch_geometric.nn.conv.GCNConv` or
:class:`~torch_geometric.nn.conv.GraphConv`. For non-differentiable models
you might need modifications, e.g., see example for
:class:`~torch_geometric.nn.conv.GATConv`.
The memory overhead is driven by the additional edges (at most
:attr:`block_size`). For scalability reasons, the block is drawn with
replacement and then the index is made unique. Thus, the actual block size
is typically slightly smaller than specified.
This attack can be used for both global and local attacks as well as
test-time attacks (evasion) and training-time attacks (poisoning). Please
see the provided examples.
This attack is designed with a focus on node- or graph-classification,
however, to adapt to other tasks you most likely only need to provide an
appropriate loss and model. However, we currently do not support batching
out of the box (sampling needs to be adapted).
.. note::
For examples of using the PRBCD Attack, see
`examples/contrib/rbcd_attack.py
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
contrib/rbcd_attack.py>`_
for a test time attack (evasion) or
`examples/contrib/rbcd_attack_poisoning.py
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
contrib/rbcd_attack_poisoning.py>`_
for a training time (poisoning) attack.
Args:
model (torch.nn.Module): The GNN module to assess.
block_size (int): Number of randomly selected elements in the
adjacency matrix to consider.
epochs (int, optional): Number of epochs (aborts early if
:obj:`mode='greedy'` and budget is satisfied) (default: :obj:`125`)
epochs_resampling (int, optional): Number of epochs to resample the
random block. (default: obj:`100`)
loss (str or callable, optional): A loss to quantify the "strength" of
an attack. Note that this function must match the output format of
:attr:`model`. By default, it is assumed that the task is
classification and that the model returns raw predictions (*i.e.*,
no output activation) or uses :obj:`logsoftmax`. Moreover, and the
number of predictions should match the number of labels passed to
:attr:`attack`. Either pass a callable or one of: :obj:`'masked'`,
:obj:`'margin'`, :obj:`'prob_margin'`, :obj:`'tanh_margin'`.
(default: :obj:`'prob_margin'`)
metric (callable, optional): Second (potentially
non-differentiable) loss for monitoring or early stopping (if
:obj:`mode='greedy'`). (default: same as :attr:`loss`)
lr (float, optional): Learning rate for updating edge weights.
Additionally, it is heuristically corrected for :attr:`block_size`,
budget (see :attr:`attack`) and graph size. (default: :obj:`1_000`)
is_undirected (bool, optional): If :obj:`True` the graph is
assumed to be undirected. (default: :obj:`True`)
log (bool, optional): If set to :obj:`False`, will not log any learning
progress. (default: :obj:`True`)
"""
coeffs = {
'max_final_samples': 20,
'max_trials_sampling': 20,
'with_early_stopping': True,
'eps': 1e-7
}
def __init__(
self,
model: torch.nn.Module,
block_size: int,
epochs: int = 125,
epochs_resampling: int = 100,
loss: Optional[Union[str, LOSS_TYPE]] = 'prob_margin',
metric: Optional[Union[str, LOSS_TYPE]] = None,
lr: float = 1_000,
is_undirected: bool = True,
log: bool = True,
**kwargs,
):
super().__init__()
self.model = model
self.block_size = block_size
self.epochs = epochs
if isinstance(loss, str):
if loss == 'masked':
self.loss = self._masked_cross_entropy
elif loss == 'margin':
self.loss = partial(self._margin_loss, reduce='mean')
elif loss == 'prob_margin':
self.loss = self._probability_margin_loss
elif loss == 'tanh_margin':
self.loss = self._tanh_margin_loss
else:
raise ValueError(f'Unknown loss `{loss}`')
else:
self.loss = loss
self.is_undirected = is_undirected
self.log = log
self.metric = metric or self.loss
self.epochs_resampling = epochs_resampling
self.lr = lr
self.coeffs.update(kwargs)
def attack(
self,
x: Tensor,
edge_index: Tensor,
labels: Tensor,
budget: int,
idx_attack: Optional[Tensor] = None,
**kwargs,
) -> Tuple[Tensor, Tensor]:
"""Attack the predictions for the provided model and graph.
A subset of predictions may be specified with :attr:`idx_attack`. The
attack is allowed to flip (i.e. add or delete) :attr:`budget` edges and
will return the strongest perturbation it can find. It returns both the
resulting perturbed :attr:`edge_index` as well as the perturbations.
Args:
x (torch.Tensor): The node feature matrix.
edge_index (torch.Tensor): The edge indices.
labels (torch.Tensor): The labels.
budget (int): The number of allowed perturbations (i.e.
number of edges that are flipped at most).
idx_attack (torch.Tensor, optional): Filter for predictions/labels.
Shape and type must match that it can index :attr:`labels`
and the model's predictions.
**kwargs (optional): Additional arguments passed to the GNN module.
:rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`)
"""
self.model.eval()
self.device = x.device
assert kwargs.get('edge_weight') is None
edge_weight = torch.ones(edge_index.size(1), device=self.device)
self.edge_index = edge_index.cpu().clone()
self.edge_weight = edge_weight.cpu().clone()
self.num_nodes = x.size(0)
# For collecting attack statistics
self.attack_statistics = defaultdict(list)
# Prepare attack and define `self.iterable` to iterate over
step_sequence = self._prepare(budget)
# Loop over the epochs (Algorithm 1, line 5)
for step in tqdm(step_sequence, disable=not self.log, desc='Attack'):
loss, gradient = self._forward_and_gradient(
x, labels, idx_attack, **kwargs)
scalars = self._update(step, gradient, x, labels, budget,
idx_attack, **kwargs)
scalars['loss'] = loss.item()
self._append_statistics(scalars)
perturbed_edge_index, flipped_edges = self._close(
x, labels, budget, idx_attack, **kwargs)
assert flipped_edges.size(1) <= budget, (
f'# perturbed edges {flipped_edges.size(1)} '
f'exceeds budget {budget}')
return perturbed_edge_index, flipped_edges
def _prepare(self, budget: int) -> Iterable[int]:
"""Prepare attack."""
if self.block_size <= budget:
raise ValueError(
f'The search space size ({self.block_size}) must be '
f'greater than the number of permutations ({budget})')
# For early stopping (not explicitly covered by pseudo code)
self.best_metric = float('-Inf')
# Sample initial search space (Algorithm 1, line 3-4)
self._sample_random_block(budget)
steps = range(self.epochs)
return steps
@torch.no_grad()
def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor,
budget: int, idx_attack: Optional[Tensor] = None,
**kwargs) -> Dict[str, float]:
"""Update edge weights given gradient."""
# Gradient update step (Algorithm 1, line 7)
self.block_edge_weight = self._update_edge_weights(
budget, self.block_edge_weight, epoch, gradient)
# For monitoring
pmass_update = torch.clamp(self.block_edge_weight, 0, 1)
# Projection to stay within relaxed `L_0` budget
# (Algorithm 1, line 8)
self.block_edge_weight = self._project(budget, self.block_edge_weight,
self.coeffs['eps'])
# For monitoring
scalars = dict(
prob_mass_after_update=pmass_update.sum().item(),
prob_mass_after_update_max=pmass_update.max().item(),
prob_mass_after_projection=self.block_edge_weight.sum().item(),
prob_mass_after_projection_nonzero_weights=(
self.block_edge_weight > self.coeffs['eps']).sum().item(),
prob_mass_after_projection_max=self.block_edge_weight.max().item())
if not self.coeffs['with_early_stopping']:
return scalars
# Calculate metric after the current epoch (overhead
# for monitoring and early stopping)
topk_block_edge_weight = torch.zeros_like(self.block_edge_weight)
topk_block_edge_weight[torch.topk(self.block_edge_weight,
budget).indices] = 1
edge_index, edge_weight = self._get_modified_adj(
self.edge_index, self.edge_weight, self.block_edge_index,
topk_block_edge_weight)
prediction = self._forward(x, edge_index, edge_weight, **kwargs)
metric = self.metric(prediction, labels, idx_attack)
# Save best epoch for early stopping
# (not explicitly covered by pseudo code)
if metric > self.best_metric:
self.best_metric = metric
self.best_block = self.current_block.cpu().clone()
self.best_edge_index = self.block_edge_index.cpu().clone()
self.best_pert_edge_weight = self.block_edge_weight.cpu().clone()
# Resampling of search space (Algorithm 1, line 9-14)
if epoch < self.epochs_resampling - 1:
self._resample_random_block(budget)
elif epoch == self.epochs_resampling - 1:
# Retrieve best epoch if early stopping is active
# (not explicitly covered by pseudo code)
self.current_block = self.best_block.to(self.device)
self.block_edge_index = self.best_edge_index.to(self.device)
block_edge_weight = self.best_pert_edge_weight.clone()
self.block_edge_weight = block_edge_weight.to(self.device)
scalars['metric'] = metric.item()
return scalars
@torch.no_grad()
def _close(self, x: Tensor, labels: Tensor, budget: int,
idx_attack: Optional[Tensor] = None,
**kwargs) -> Tuple[Tensor, Tensor]:
"""Clean up and prepare return argument."""
# Retrieve best epoch if early stopping is active
# (not explicitly covered by pseudo code)
if self.coeffs['with_early_stopping']:
self.current_block = self.best_block.to(self.device)
self.block_edge_index = self.best_edge_index.to(self.device)
self.block_edge_weight = self.best_pert_edge_weight.to(self.device)
# Sample final discrete graph (Algorithm 1, line 16)
edge_index, flipped_edges = self._sample_final_edges(
x, labels, budget, idx_attack=idx_attack, **kwargs)
return edge_index, flipped_edges
def _forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor,
**kwargs) -> Tensor:
"""Forward model."""
return self.model(x, edge_index, edge_weight, **kwargs)
def _forward_and_gradient(self, x: Tensor, labels: Tensor,
idx_attack: Optional[Tensor] = None,
**kwargs) -> Tuple[Tensor, Tensor]:
"""Forward and update edge weights."""
self.block_edge_weight.requires_grad = True
# Retrieve sparse perturbed adjacency matrix `A \oplus p_{t-1}`
# (Algorithm 1, line 6 / Algorithm 2, line 7)
edge_index, edge_weight = self._get_modified_adj(
self.edge_index, self.edge_weight, self.block_edge_index,
self.block_edge_weight)
# Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7)
prediction = self._forward(x, edge_index, edge_weight, **kwargs)
# Calculate loss combining all each node
# (Algorithm 1, line 7 / Algorithm 2, line 8)
loss = self.loss(prediction, labels, idx_attack)
# Retrieve gradient towards the current block
# (Algorithm 1, line 7 / Algorithm 2, line 8)
gradient = torch.autograd.grad(loss, self.block_edge_weight)[0]
return loss, gradient
def _get_modified_adj(self, edge_index: Tensor, edge_weight: Tensor,
block_edge_index: Tensor,
block_edge_weight: Tensor) -> Tuple[Tensor, Tensor]:
"""Merges adjacency matrix with current block (incl. weights)."""
if self.is_undirected:
block_edge_index, block_edge_weight = to_undirected(
block_edge_index, block_edge_weight, num_nodes=self.num_nodes,
reduce='mean')
modified_edge_index = torch.cat(
(edge_index.to(self.device), block_edge_index), dim=-1)
modified_edge_weight = torch.cat(
(edge_weight.to(self.device), block_edge_weight))
modified_edge_index, modified_edge_weight = coalesce(
modified_edge_index, modified_edge_weight,
num_nodes=self.num_nodes, reduce='sum')
# Allow (soft) removal of edges
is_edge_in_clean_adj = modified_edge_weight > 1
modified_edge_weight[is_edge_in_clean_adj] = (
2 - modified_edge_weight[is_edge_in_clean_adj])
return modified_edge_index, modified_edge_weight
def _filter_self_loops_in_block(self, with_weight: bool):
is_not_sl = self.block_edge_index[0] != self.block_edge_index[1]
self.current_block = self.current_block[is_not_sl]
self.block_edge_index = self.block_edge_index[:, is_not_sl]
if with_weight:
self.block_edge_weight = self.block_edge_weight[is_not_sl]
def _sample_random_block(self, budget: int = 0):
for _ in range(self.coeffs['max_trials_sampling']):
num_possible_edges = self._num_possible_edges(
self.num_nodes, self.is_undirected)
self.current_block = torch.randint(num_possible_edges,
(self.block_size, ),
device=self.device)
self.current_block = torch.unique(self.current_block, sorted=True)
if self.is_undirected:
self.block_edge_index = self._linear_to_triu_idx(
self.num_nodes, self.current_block)
else:
self.block_edge_index = self._linear_to_full_idx(
self.num_nodes, self.current_block)
self._filter_self_loops_in_block(with_weight=False)
self.block_edge_weight = torch.full(self.current_block.shape,
self.coeffs['eps'],
device=self.device)
if self.current_block.size(0) >= budget:
return
raise RuntimeError('Sampling random block was not successful. '
'Please decrease `budget`.')
def _resample_random_block(self, budget: int):
# Keep at most half of the block (i.e. resample low weights)
sorted_idx = torch.argsort(self.block_edge_weight)
keep_above = (self.block_edge_weight
<= self.coeffs['eps']).sum().long()
if keep_above < sorted_idx.size(0) // 2:
keep_above = sorted_idx.size(0) // 2
sorted_idx = sorted_idx[keep_above:]
self.current_block = self.current_block[sorted_idx]
# Sample until enough edges were drawn
for _ in range(self.coeffs['max_trials_sampling']):
n_edges_resample = self.block_size - self.current_block.size(0)
num_possible_edges = self._num_possible_edges(
self.num_nodes, self.is_undirected)
lin_index = torch.randint(num_possible_edges, (n_edges_resample, ),
device=self.device)
current_block = torch.cat((self.current_block, lin_index))
self.current_block, unique_idx = torch.unique(
current_block, sorted=True, return_inverse=True)
if self.is_undirected:
self.block_edge_index = self._linear_to_triu_idx(
self.num_nodes, self.current_block)
else:
self.block_edge_index = self._linear_to_full_idx(
self.num_nodes, self.current_block)
# Merge existing weights with new edge weights
block_edge_weight_prev = self.block_edge_weight[sorted_idx]
self.block_edge_weight = torch.full(self.current_block.shape,
self.coeffs['eps'],
device=self.device)
self.block_edge_weight[
unique_idx[:sorted_idx.size(0)]] = block_edge_weight_prev
if not self.is_undirected:
self._filter_self_loops_in_block(with_weight=True)
if self.current_block.size(0) > budget:
return
raise RuntimeError('Sampling random block was not successful.'
'Please decrease `budget`.')
def _sample_final_edges(self, x: Tensor, labels: Tensor, budget: int,
idx_attack: Optional[Tensor] = None,
**kwargs) -> Tuple[Tensor, Tensor]:
best_metric = float('-Inf')
block_edge_weight = self.block_edge_weight
block_edge_weight[block_edge_weight <= self.coeffs['eps']] = 0
for i in range(self.coeffs['max_final_samples']):
if i == 0:
# In first iteration employ top k heuristic instead of sampling
sampled_edges = torch.zeros_like(block_edge_weight)
sampled_edges[torch.topk(block_edge_weight,
budget).indices] = 1
else:
sampled_edges = torch.bernoulli(block_edge_weight).float()
if sampled_edges.sum() > budget:
# Allowed budget is exceeded
continue
edge_index, edge_weight = self._get_modified_adj(
self.edge_index, self.edge_weight, self.block_edge_index,
sampled_edges)
prediction = self._forward(x, edge_index, edge_weight, **kwargs)
metric = self.metric(prediction, labels, idx_attack)
# Save best sample
if metric > best_metric:
best_metric = metric
self.block_edge_weight = sampled_edges.clone().cpu()
# Recover best sample
self.block_edge_weight = self.block_edge_weight.to(self.device)
flipped_edges = self.block_edge_index[:, self.block_edge_weight > 0]
edge_index, edge_weight = self._get_modified_adj(
self.edge_index, self.edge_weight, self.block_edge_index,
self.block_edge_weight)
edge_mask = edge_weight == 1
edge_index = edge_index[:, edge_mask]
return edge_index, flipped_edges
def _update_edge_weights(self, budget: int, block_edge_weight: Tensor,
epoch: int, gradient: Tensor) -> Tensor:
# The learning rate is refined heuristically, s.t. (1) it is
# independent of the number of perturbations (assuming an undirected
# adjacency matrix) and (2) to decay learning rate during fine-tuning
# (i.e. fixed search space).
lr = (budget / self.num_nodes * self.lr /
np.sqrt(max(0, epoch - self.epochs_resampling) + 1))
return block_edge_weight + lr * gradient
@staticmethod
def _project(budget: int, values: Tensor, eps: float = 1e-7) -> Tensor:
r"""Project :obj:`values`:
:math:`budget \ge \sum \Pi_{[0, 1]}(\text{values})`.
"""
if torch.clamp(values, 0, 1).sum() > budget:
left = (values - 1).min()
right = values.max()
miu = PRBCDAttack._bisection(values, left, right, budget)
values = values - miu
return torch.clamp(values, min=eps, max=1 - eps)
@staticmethod
def _bisection(edge_weights: Tensor, a: float, b: float, n_pert: int,
eps=1e-5, max_iter=1e3) -> Tensor:
"""Bisection search for projection."""
def shift(offset: float):
return (torch.clamp(edge_weights - offset, 0, 1).sum() - n_pert)
miu = a
for _ in range(int(max_iter)):
miu = (a + b) / 2
# Check if middle point is root
if (shift(miu) == 0.0):
break
# Decide the side to repeat the steps
if (shift(miu) * shift(a) < 0):
b = miu
else:
a = miu
if ((b - a) <= eps):
break
return miu
@staticmethod
def _num_possible_edges(n: int, is_undirected: bool) -> int:
"""Determine number of possible edges for graph."""
if is_undirected:
return n * (n - 1) // 2
else:
return int(n**2) # We filter self-loops later
@staticmethod
def _linear_to_triu_idx(n: int, lin_idx: Tensor) -> Tensor:
"""Linear index to upper triangular matrix without diagonal. This is
similar to
https://stackoverflow.com/questions/242711/algorithm-for-index-numbers-of-triangular-matrix-coefficients/28116498#28116498
with number nodes decremented and col index incremented by one.
"""
nn = n * (n - 1)
row_idx = n - 2 - torch.floor(
torch.sqrt(-8 * lin_idx.double() + 4 * nn - 7) / 2.0 - 0.5).long()
col_idx = 1 + lin_idx + row_idx - nn // 2 + torch.div(
(n - row_idx) * (n - row_idx - 1), 2, rounding_mode='floor')
return torch.stack((row_idx, col_idx))
@staticmethod
def _linear_to_full_idx(n: int, lin_idx: Tensor) -> Tensor:
"""Linear index to dense matrix including diagonal."""
row_idx = torch.div(lin_idx, n, rounding_mode='floor')
col_idx = lin_idx % n
return torch.stack((row_idx, col_idx))
@staticmethod
def _margin_loss(score: Tensor, labels: Tensor,
idx_mask: Optional[Tensor] = None,
reduce: Optional[str] = None) -> Tensor:
r"""Margin loss between true score and highest non-target score.
.. math::
m = - s_{y} + max_{y' \ne y} s_{y'}
where :math:`m` is the margin :math:`s` the score and :math:`y` the
labels.
Args:
score (Tensor): Some score (*e.g.*, logits) of shape
:obj:`[n_elem, dim]`.
labels (LongTensor): The labels of shape :obj:`[n_elem]`.
idx_mask (Tensor, optional): To select subset of `score` and
`labels` of shape :obj:`[n_select]`. Defaults to None.
reduce (str, optional): if :obj:`mean` the result is aggregated.
Otherwise, return element wise margin.
:rtype: (Tensor)
"""
if idx_mask is not None:
score = score[idx_mask]
labels = labels[idx_mask]
linear_idx = torch.arange(score.size(0), device=score.device)
true_score = score[linear_idx, labels]
score = score.clone()
score[linear_idx, labels] = float('-Inf')
best_non_target_score = score.amax(dim=-1)
margin_ = best_non_target_score - true_score
if reduce is None:
return margin_
return margin_.mean()
@staticmethod
def _tanh_margin_loss(prediction: Tensor, labels: Tensor,
idx_mask: Optional[Tensor] = None) -> Tensor:
"""Calculate tanh margin loss, a node-classification loss that focuses
on nodes next to decision boundary.
Args:
prediction (Tensor): Prediction of shape :obj:`[n_elem, dim]`.
labels (LongTensor): The labels of shape :obj:`[n_elem]`.
idx_mask (Tensor, optional): To select subset of `score` and
`labels` of shape :obj:`[n_select]`. Defaults to None.
:rtype: (Tensor)
"""
log_prob = F.log_softmax(prediction, dim=-1)
margin_ = GRBCDAttack._margin_loss(log_prob, labels, idx_mask)
loss = torch.tanh(margin_).mean()
return loss
@staticmethod
def _probability_margin_loss(prediction: Tensor, labels: Tensor,
idx_mask: Optional[Tensor] = None) -> Tensor:
"""Calculate probability margin loss, a node-classification loss that
focuses on nodes next to decision boundary. See `Are Defenses for
Graph Neural Networks Robust?
<https://www.cs.cit.tum.de/daml/are-gnn-defenses-robust>`_ for details.
Args:
prediction (Tensor): Prediction of shape :obj:`[n_elem, dim]`.
labels (LongTensor): The labels of shape :obj:`[n_elem]`.
idx_mask (Tensor, optional): To select subset of `score` and
`labels` of shape :obj:`[n_select]`. Defaults to None.
:rtype: (Tensor)
"""
prob = F.softmax(prediction, dim=-1)
margin_ = GRBCDAttack._margin_loss(prob, labels, idx_mask)
return margin_.mean()
@staticmethod
def _masked_cross_entropy(log_prob: Tensor, labels: Tensor,
idx_mask: Optional[Tensor] = None) -> Tensor:
"""Calculate masked cross entropy loss, a node-classification loss that
focuses on nodes next to decision boundary.
Args:
log_prob (Tensor): Log probabilities of shape :obj:`[n_elem, dim]`.
labels (LongTensor): The labels of shape :obj:`[n_elem]`.
idx_mask (Tensor, optional): To select subset of `score` and
`labels` of shape :obj:`[n_select]`. Defaults to None.
:rtype: (Tensor)
"""
if idx_mask is not None:
log_prob = log_prob[idx_mask]
labels = labels[idx_mask]
is_correct = log_prob.argmax(-1) == labels
if is_correct.any():
log_prob = log_prob[is_correct]
labels = labels[is_correct]
return F.nll_loss(log_prob, labels)
def _append_statistics(self, mapping: Dict[str, Any]):
for key, value in mapping.items():
self.attack_statistics[key].append(value)
def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
class GRBCDAttack(PRBCDAttack):
r"""The Greedy Randomized Block Coordinate Descent (GRBCD) adversarial
attack from the `Robustness of Graph Neural Networks at Scale
<https://www.cs.cit.tum.de/daml/robustness-of-gnns-at-scale>`_ paper.
GRBCD shares most of the properties and requirements with
:class:`PRBCDAttack`. It also uses an efficient gradient based approach.
However, it greedily flips edges based on the gradient towards the
adjacency matrix.
.. note::
For examples of using the GRBCD Attack, see
`examples/contrib/rbcd_attack.py
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
contrib/rbcd_attack.py>`_
for a test time attack (evasion).
Args:
model (torch.nn.Module): The GNN module to assess.
block_size (int): Number of randomly selected elements in the
adjacency matrix to consider.
epochs (int, optional): Number of epochs (aborts early if
:obj:`mode='greedy'` and budget is satisfied) (default: :obj:`125`)
loss (str or callable, optional): A loss to quantify the "strength" of
an attack. Note that this function must match the output format of
:attr:`model`. By default, it is assumed that the task is
classification and that the model returns raw predictions (*i.e.*,
no output activation) or uses :obj:`logsoftmax`. Moreover, and the
number of predictions should match the number of labels passed to
:attr:`attack`. Either pass Callable or one of: :obj:`'masked'`,
:obj:`'margin'`, :obj:`'prob_margin'`, :obj:`'tanh_margin'`.
(default: :obj:`'masked'`)
is_undirected (bool, optional): If :obj:`True` the graph is
assumed to be undirected. (default: :obj:`True`)
log (bool, optional): If set to :obj:`False`, will not log any learning
progress. (default: :obj:`True`)
"""
coeffs = {'max_trials_sampling': 20, 'eps': 1e-7}
def __init__(
self,
model: torch.nn.Module,
block_size: int,
epochs: int = 125,
loss: Optional[Union[str, LOSS_TYPE]] = 'masked',
is_undirected: bool = True,
log: bool = True,
**kwargs,
):
super().__init__(model, block_size, epochs, loss=loss,
is_undirected=is_undirected, log=log, **kwargs)
@torch.no_grad()
def _prepare(self, budget: int) -> List[int]:
"""Prepare attack."""
self.flipped_edges = self.edge_index.new_empty(2, 0).to(self.device)
# Determine the number of edges to be flipped in each attach step/epoch
step_size = budget // self.epochs
if step_size > 0:
steps = self.epochs * [step_size]
for i in range(budget % self.epochs):
steps[i] += 1
else:
steps = [1] * budget
# Sample initial search space (Algorithm 2, line 3-4)
self._sample_random_block(step_size)
return steps
@torch.no_grad()
def _update(self, step_size: int, gradient: Tensor, *args,
**kwargs) -> Dict[str, Any]:
"""Update edge weights given gradient."""
_, topk_edge_index = torch.topk(gradient, step_size)
flip_edge_index = self.block_edge_index[:, topk_edge_index]
flip_edge_weight = torch.ones_like(flip_edge_index[0],
dtype=torch.float32)
self.flipped_edges = torch.cat((self.flipped_edges, flip_edge_index),
axis=-1)
if self.is_undirected:
flip_edge_index, flip_edge_weight = to_undirected(
flip_edge_index, flip_edge_weight, num_nodes=self.num_nodes,
reduce='mean')
edge_index = torch.cat(
(self.edge_index.to(self.device), flip_edge_index.to(self.device)),
dim=-1)
edge_weight = torch.cat((self.edge_weight.to(self.device),
flip_edge_weight.to(self.device)))
edge_index, edge_weight = coalesce(edge_index, edge_weight,
num_nodes=self.num_nodes,
reduce='sum')
is_one_mask = torch.isclose(edge_weight, torch.tensor(1.))
self.edge_index = edge_index[:, is_one_mask]
self.edge_weight = edge_weight[is_one_mask]
# self.edge_weight = torch.ones_like(self.edge_weight)
assert self.edge_index.size(1) == self.edge_weight.size(0)
# Sample initial search space (Algorithm 2, line 3-4)
self._sample_random_block(step_size)
# Return debug information
scalars = {
'number_positive_entries_in_gradient': (gradient > 0).sum().item()
}
return scalars
def _close(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
"""Clean up and prepare return argument."""
return self.edge_index, self.flipped_edges