From 9bdaf7a21eb077a84aa3a7be090264b41c4646ab Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 28 Nov 2022 10:08:55 +0800 Subject: [PATCH 01/15] RBCD --- greatx/attack/untargeted/rbcd.py | 500 +++++++++++++++++++++++++++++++ greatx/functional/__init__.py | 18 +- greatx/functional/losses.py | 104 +++++++ 3 files changed, 619 insertions(+), 3 deletions(-) create mode 100644 greatx/attack/untargeted/rbcd.py create mode 100644 greatx/functional/losses.py diff --git a/greatx/attack/untargeted/rbcd.py b/greatx/attack/untargeted/rbcd.py new file mode 100644 index 0000000..d61b9d9 --- /dev/null +++ b/greatx/attack/untargeted/rbcd.py @@ -0,0 +1,500 @@ +from collections import defaultdict +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.autograd import grad +from torch_geometric.utils import coalesce, to_undirected +from tqdm import tqdm +from tqdm.auto import tqdm + +from greatx.attack.untargeted.untargeted_attacker import UntargetedAttacker +from greatx.functional import ( + margin_loss, + masked_cross_entropy, + probability_margin_loss, + tanh_margin_loss, +) +from greatx.nn.models.surrogate import Surrogate +from greatx.utils import singleton_mask + +# (predictions, labels, ids/mask) -> Tensor with one element +LOSS_TYPE = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor] + + +class PRBCDAttack(UntargetedAttacker, Surrogate): + # FGAttack can conduct feature attack + _allow_feature_attack: bool = True + is_undirected_graph: bool = True # TODO + + coeffs: Dict[str, Any] = { + 'max_final_samples': 20, + 'max_trials_sampling': 20, + 'with_early_stopping': True, + 'eps': 1e-7 + } + + def setup_surrogate(self, surrogate: torch.nn.Module, victim_nodes: Tensor, + victim_labels: Optional[Tensor] = None, *, + eps: float = 1.0): + + Surrogate.setup_surrogate(self, surrogate=surrogate, eps=eps, + freeze=True) + + if victim_nodes.dtype == torch.bool: + victim_nodes = victim_nodes.nonzero().view(-1) + self.victim_nodes = victim_nodes.to(self.device) + + if victim_labels is None: + victim_labels = self.label[victim_nodes] + self.victim_labels = victim_labels.to(self.device) + return self + + def reset(self): + super().reset() + self.current_block = None + self.block_edge_index = None + self.block_edge_weight = None + return self + + def attack(self, num_budgets: Union[int, float] = 0.05, + block_size: int = 250_000, epochs: int = 125, + epochs_resampling: int = 100, + loss: Optional[Union[str, LOSS_TYPE]] = 'prob_margin', + metric: Optional[Union[str, + LOSS_TYPE]] = None, lr: float = 1_000, *, + structure_attack: bool = True, feature_attack: bool = False, + disable: bool = False, **kwargs) -> "PRBCDAttack": + + super().attack(num_budgets=num_budgets, + structure_attack=structure_attack, + feature_attack=feature_attack) + + self.block_size = block_size + self.epochs = epochs + + if isinstance(loss, str): + if loss == 'masked': + self.loss = masked_cross_entropy + elif loss == 'margin': + self.loss = margin_loss + elif loss == 'prob_margin': + self.loss = probability_margin_loss + elif loss == 'tanh_margin': + self.loss = tanh_margin_loss + else: + raise ValueError(f'Unknown loss `{loss}`') + else: + self.loss = loss + + if metric is None: + self.metric = self.loss + else: + self.metric = metric + + self.epochs_resampling = epochs_resampling + self.lr = lr + + # self.coeffs.update(**kwargs) # TODO + self.edge_weights = torch.ones(self.edge_index.size(1), + 
device=self.device) + + # For collecting attack statistics + self.attack_statistics = defaultdict(list) + + # Prepare attack and define `self.iterable` to iterate over + budget = self.budget + + self.best_metric = float('-Inf') + # Sample initial search space (Algorithm 1, line 3-4) + self._sample_random_block(budget) + + # Loop over the epochs (Algorithm 1, line 5) + for step in tqdm(range(self.budgets), desc='Peturbing graph...', + disable=disable): + loss, gradient = self._forward_and_gradient( + self.feat, self.label, self.victim_nodes, **kwargs) + + scalars = self._update(step, gradient, self.feat, self.label, + budget, self.victim_nodes, **kwargs) + + scalars['loss'] = loss.item() + self._append_statistics(scalars) + + self._close(self.feat, self.label, budget, self.victim_nodes, **kwargs) + + return self + + @torch.no_grad() + def _prepare(self, budget: int) -> Iterable[int]: + """Prepare attack.""" + # For early stopping (not explicitly covered by pseudo code) + self.best_metric = float('-Inf') + + # Sample initial search space (Algorithm 1, line 3-4) + self._sample_random_block(budget) + + steps = range(self.epochs) + return steps + + @torch.no_grad() + def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor, + budget: int, idx_attack: Optional[Tensor] = None, + **kwargs) -> Dict[str, float]: + """Update edge weights given gradient.""" + # Gradient update step (Algorithm 1, line 7) + self._update_edge_weights(budget, epoch, gradient) + # For monitoring + pmass_update = torch.clamp(self.block_edge_weight, 0, 1) + # Projection to stay within relaxed `L_0` budget + # (Algorithm 1, line 8) + self.block_edge_weight = self._project(budget, self.block_edge_weight, + self.coeffs['eps']) + + # For monitoring + scalars = dict( + prob_mass_after_update=pmass_update.sum().item(), + prob_mass_after_update_max=pmass_update.max().item(), + prob_mass_after_projection=self.block_edge_weight.sum().item(), + prob_mass_after_projection_nonzero_weights=( + self.block_edge_weight > self.coeffs['eps']).sum().item(), + prob_mass_after_projection_max=self.block_edge_weight.max().item()) + if not self.coeffs['with_early_stopping']: + return scalars + + # Calculate metric after the current epoch (overhead + # for monitoring and early stopping) + topk_block_edge_weight = torch.zeros_like(self.block_edge_weight) + topk_block_edge_weight[torch.topk(self.block_edge_weight, + budget).indices] = 1 + edge_index, edge_weight = self._get_modified_adj( + self.edge_index, self.edge_weights, self.block_edge_index, + topk_block_edge_weight) + prediction = self._forward(x, edge_index, edge_weight, **kwargs) + metric = self.metric(prediction, labels, idx_attack) + + # Save best epoch for early stopping + # (not explicitly covered by pseudo code) + if metric > self.best_metric: + self.best_metric = metric + self.best_block = self.current_block.cpu().clone() + self.best_edge_index = self.block_edge_index.cpu().clone() + self.best_pert_edge_weight = self.block_edge_weight.cpu().detach() + + # Resampling of search space (Algorithm 1, line 9-14) + if epoch < self.epochs_resampling - 1: + self._resample_random_block(budget) + elif epoch == self.epochs_resampling - 1: + # Retrieve best epoch if early stopping is active + # (not explicitly covered by pseudo code) + self.current_block = self.best_block.to(self.device) + self.block_edge_index = self.best_edge_index.to(self.device) + block_edge_weight = self.best_pert_edge_weight.clone() + self.block_edge_weight = block_edge_weight.to(self.device) + + 
scalars['metric'] = metric.item() + return scalars + + @torch.no_grad() + def _close(self, x: Tensor, labels: Tensor, budget: int, + idx_attack: Optional[Tensor] = None, + **kwargs) -> Tuple[Tensor, Tensor]: + """Clean up and prepare return argument.""" + # Retrieve best epoch if early stopping is active + # (not explicitly covered by pseudo code) + if self.coeffs['with_early_stopping']: + self.current_block = self.best_block.to(self.device) + self.block_edge_index = self.best_edge_index.to(self.device) + self.block_edge_weight = self.best_pert_edge_weight.to(self.device) + + # Sample final discrete graph (Algorithm 1, line 16) + edge_index, flipped_edges = self._sample_final_edges( + x, labels, budget, idx_attack=idx_attack, **kwargs) + + assert flipped_edges.size(1) <= self.budget, ( + f'# perturbed edges {flipped_edges.size(1)} ' + f'exceeds budget {self.budget}') + + row, col = flipped_edges + # TODO: zip* + for it, (u, v) in enumerate(zip(row.tolist(), col.tolist())): + if self.adj[u, v] > 0: + self.remove_edge(u, v, it) + else: + self.add_edge(u, v, it) + + def _forward_and_gradient(self, x: Tensor, labels: Tensor, + victim_nodes: Optional[Tensor] = None, + **kwargs) -> Tuple[Tensor, Tensor]: + """Forward and update edge weights.""" + self.block_edge_weight.requires_grad_() + + # Retrieve sparse perturbed adjacency matrix `A \oplus p_{t-1}` + # (Algorithm 1, line 6 / Algorithm 2, line 7) + edge_index, edge_weight = self._get_modified_adj( + self.edge_index, self.edge_weights, self.block_edge_index, + self.block_edge_weight) + + # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7) + prediction = self.model(x, edge_index, edge_weight, **kwargs) + # Calculate loss combining all each node + # (Algorithm 1, line 7 / Algorithm 2, line 8) + loss = self.loss(prediction, labels, victim_nodes) + # Retrieve gradient towards the current block + # (Algorithm 1, line 7 / Algorithm 2, line 8) + gradient = torch.autograd.grad(loss, self.block_edge_weight)[0] + + return loss, gradient + + def _get_modified_adj( + self, + edge_index: Tensor, + edge_weight: Tensor, + block_edge_index: Tensor, + block_edge_weight: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Merges adjacency matrix with current block (incl. 
weights)""" + if self.is_undirected_graph: + block_edge_index, block_edge_weight = to_undirected( + block_edge_index, block_edge_weight, num_nodes=self.num_nodes, + reduce='mean') + + modified_edge_index = torch.cat((edge_index, block_edge_index), dim=-1) + modified_edge_weight = torch.cat((edge_weight, block_edge_weight)) + + modified_edge_index, modified_edge_weight = coalesce( + modified_edge_index, modified_edge_weight, + num_nodes=self.num_nodes, reduce='sum') + + # Allow (soft) removal of edges + is_edge_in_clean_adj = modified_edge_weight > 1 + modified_edge_weight[is_edge_in_clean_adj] = ( + 2 - modified_edge_weight[is_edge_in_clean_adj]) + + return modified_edge_index, modified_edge_weight + + def _filter_self_loops_in_block(self, with_weight: bool): + is_not_sl = self.block_edge_index[0] != self.block_edge_index[1] + self.current_block = self.current_block[is_not_sl] + self.block_edge_index = self.block_edge_index[:, is_not_sl] + if with_weight: + self.block_edge_weight = self.block_edge_weight[is_not_sl] + + @torch.no_grad() + def _sample_random_block(self, budget: int = 0): + for _ in range(self.coeffs['max_trials_sampling']): + num_possible_edges = self._num_possible_edges( + self.num_nodes, self.is_undirected_graph) + self.current_block = torch.randint(num_possible_edges, + (self.block_size, ), + device=self.device) + self.current_block = torch.unique(self.current_block, sorted=True) + if self.is_undirected_graph: + self.block_edge_index = self._linear_to_triu_idx( + self.num_nodes, self.current_block) + else: + self.block_edge_index = self._linear_to_full_idx( + self.num_nodes, self.current_block) + self._filter_self_loops_in_block(with_weight=False) + + self.block_edge_weight = torch.full(self.current_block.shape, + self.coeffs['eps'], + device=self.device) + if self.current_block.size(0) >= budget: + return + raise RuntimeError('Sampling random block was not successful. ' + 'Please decrease `budget`.') + + def _resample_random_block(self, budget: int): + # Keep at most half of the block (i.e. 
resample low weights) + sorted_idx = torch.argsort(self.block_edge_weight) + keep_above = (self.block_edge_weight <= + self.coeffs['eps']).sum().long() + if keep_above < sorted_idx.size(0) // 2: + keep_above = sorted_idx.size(0) // 2 + sorted_idx = sorted_idx[keep_above:] + + self.current_block = self.current_block[sorted_idx] + + # Sample until enough edges were drawn + for _ in range(self.coeffs['max_trials_sampling']): + n_edges_resample = self.block_size - self.current_block.size(0) + num_possible_edges = self._num_possible_edges( + self.num_nodes, self.is_undirected_graph) + lin_index = torch.randint(num_possible_edges, (n_edges_resample, ), + device=self.device) + + current_block = torch.cat((self.current_block, lin_index)) + self.current_block, unique_idx = torch.unique( + current_block, sorted=True, return_inverse=True) + + if self.is_undirected_graph: + self.block_edge_index = self._linear_to_triu_idx( + self.num_nodes, self.current_block) + else: + self.block_edge_index = self._linear_to_full_idx( + self.num_nodes, self.current_block) + + # Merge existing weights with new edge weights + block_edge_weight_prev = self.block_edge_weight[sorted_idx] + self.block_edge_weight = torch.full(self.current_block.shape, + self.coeffs['eps'], + device=self.device) + self.block_edge_weight[ + unique_idx[:sorted_idx.size(0)]] = block_edge_weight_prev + + if not self.is_undirected_graph: + self._filter_self_loops_in_block(with_weight=True) + + if self.current_block.size(0) > budget: + return + raise RuntimeError('Sampling random block was not successful.' + 'Please decrease `budget`.') + + def _sample_final_edges(self, x: Tensor, labels: Tensor, budget: int, + idx_attack: Optional[Tensor] = None, + **kwargs) -> Tuple[Tensor, Tensor]: + best_metric = float('-Inf') + block_edge_weight = self.block_edge_weight + block_edge_weight[block_edge_weight <= self.coeffs['eps']] = 0 + + for i in range(self.coeffs['max_final_samples']): + if i == 0: + # In first iteration employ top k heuristic instead of sampling + sampled_edges = torch.zeros_like(block_edge_weight) + sampled_edges[torch.topk(block_edge_weight, + budget).indices] = 1 + else: + sampled_edges = torch.bernoulli(block_edge_weight).float() + + if sampled_edges.sum() > budget: + # Allowed budget is exceeded + continue + self.block_edge_weight = sampled_edges + + edge_index, edge_weight = self._get_modified_adj( + self.edge_index, self.edge_weights, self.block_edge_index, + self.block_edge_weight) + prediction = self._forward(x, edge_index, edge_weight, **kwargs) + metric = self.metric(prediction, labels, idx_attack) + + # Save best sample + if metric > best_metric: + best_metric = metric + best_edge_weight = self.block_edge_weight.clone().cpu() + + # Recover best sample + self.block_edge_weight = best_edge_weight.to(self.device) + flipped_edges = self.block_edge_index[:, + torch.where(best_edge_weight)[0]] + + edge_index, edge_weight = self._get_modified_adj( + self.edge_index, self.edge_weights, self.block_edge_index, + self.block_edge_weight) + edge_mask = edge_weight == 1 + edge_index = edge_index[:, edge_mask] + + return edge_index, flipped_edges + + def _update_edge_weights(self, budget: int, epoch: int, gradient: Tensor): + # The learning rate is refined heuristically, s.t. (1) it is + # independent of the number of perturbations (assuming an undirected + # adjacency matrix) and (2) to decay learning rate during fine-tuning + # (i.e. fixed search space). 
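+        # E.g. for budget=250, num_nodes=2500 and lr=1000, the effective
+        # step size is 250 / 2500 * 1000 = 100 while the search space is
+        # still being resampled; afterwards it decays with
+        # 1 / sqrt(epoch - epochs_resampling + 1).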
+ lr = (budget / self.num_nodes * self.lr / + np.sqrt(max(0, epoch - self.epochs_resampling) + 1)) + self.block_edge_weight.data.add_(lr * gradient) + + @staticmethod + def _project(budget: int, values: Tensor, eps: float = 1e-7) -> Tensor: + r"""Project :obj:`values`: + :math:`budget \ge \sum \Pi_{[0, 1]}(\text{values})`.""" + if torch.clamp(values, 0, 1).sum() > budget: + left = (values - 1).min() + right = values.max() + miu = PRBCDAttack._bisection(values, left, right, budget) + values = values - miu + return torch.clamp(values, min=eps, max=1 - eps) + + @staticmethod + def _bisection(edge_weights: Tensor, a: float, b: float, n_pert: int, + eps=1e-5, max_iter=1e3) -> Tensor: + """Bisection search for projection.""" + def shift(offset: float): + return (torch.clamp(edge_weights - offset, 0, 1).sum() - n_pert) + + miu = a + for _ in range(int(max_iter)): + miu = (a + b) / 2 + # Check if middle point is root + if (shift(miu) == 0.0): + break + # Decide the side to repeat the steps + if (shift(miu) * shift(a) < 0): + b = miu + else: + a = miu + if ((b - a) <= eps): + break + return miu + + @staticmethod + def _num_possible_edges(n: int, is_undirected_graph: bool) -> int: + """Determine number of possible edges for graph.""" + if is_undirected_graph: + return n * (n - 1) // 2 + else: + return int(n**2) # We filter self-loops later + + @staticmethod + def _linear_to_triu_idx(n: int, lin_idx: Tensor) -> Tensor: + """Linear index to upper triangular matrix without diagonal. This is + similar to + https://stackoverflow.com/questions/242711/algorithm-for-index-numbers-of-triangular-matrix-coefficients/28116498#28116498 + with number nodes decremented and col index incremented by one.""" + nn = n * (n - 1) + row_idx = n - 2 - torch.floor( + torch.sqrt(-8 * lin_idx.double() + 4 * nn - 7) / 2.0 - 0.5).long() + col_idx = 1 + lin_idx + row_idx - nn // 2 + torch.div( + (n - row_idx) * (n - row_idx - 1), 2, rounding_mode='floor') + return torch.stack((row_idx, col_idx)) + + @staticmethod + def _linear_to_full_idx(n: int, lin_idx: Tensor) -> Tensor: + """Linear index to dense matrix including diagonal.""" + row_idx = torch.div(lin_idx, n, rounding_mode='floor') + col_idx = lin_idx % n + return torch.stack((row_idx, col_idx)) + + def structure_score(self, modified_adj, adj_grad): + score = adj_grad * (1 - 2 * modified_adj) + score -= score.min() + score = torch.triu(score, diagonal=1) + if not self._allow_singleton: + # Set entries to 0 that could lead to singleton nodes. 
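+            # (a singleton node is a node that would be left with no
+            # incident edges after the perturbation)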
+ score *= singleton_mask(modified_adj) + return score.view(-1) + + def feature_score(self, modified_feat, feat_grad): + score = feat_grad * (1 - 2 * modified_feat) + score -= score.min() + return score.view(-1) + + def compute_gradients(self, modified_adj, modified_feat, victim_nodes, + victim_labels): + + logit = self.surrogate(modified_feat, + modified_adj)[victim_nodes] / self.eps + loss = F.cross_entropy(logit, victim_labels) + + if self.structure_attack and self.feature_attack: + return grad(loss, [modified_adj, modified_feat], + create_graph=False) + + if self.structure_attack: + return grad(loss, modified_adj, create_graph=False)[0], None + + if self.feature_attack: + return None, grad(loss, modified_feat, create_graph=False)[0] diff --git a/greatx/functional/__init__.py b/greatx/functional/__init__.py index 0a486f5..09ec6c5 100644 --- a/greatx/functional/__init__.py +++ b/greatx/functional/__init__.py @@ -1,7 +1,19 @@ from .dropouts import drop_edge, drop_node, drop_path from .spmm import spmm from .transform import to_dense_adj, to_sparse_adj, to_sparse_tensor +from .losses import (margin_loss, tanh_margin_loss, probability_margin_loss, + masked_cross_entropy) -classes = __all__ = ['to_sparse_tensor', 'to_dense_adj', 'to_sparse_adj', - 'spmm', - 'drop_edge', 'drop_node', 'drop_path'] +classes = __all__ = [ + 'to_sparse_tensor', + 'to_dense_adj', + 'to_sparse_adj', + 'spmm', + 'drop_edge', + 'drop_node', + 'drop_path', + 'margin_loss', + 'tanh_margin_loss', + 'probability_margin_loss', + 'masked_cross_entropy', +] diff --git a/greatx/functional/losses.py b/greatx/functional/losses.py new file mode 100644 index 0000000..837b744 --- /dev/null +++ b/greatx/functional/losses.py @@ -0,0 +1,104 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def margin_loss(score: Tensor, labels: Tensor, + mask: Optional[Tensor] = None) -> Tensor: + r"""Margin loss between true score and highest non-target score: + + .. math:: + m = - s_{y} + max_{y' \ne y} s_{y'} + + where :math:`m` is the margin :math:`s` the score and :math:`y` the + labels. + + Args: + score (Tensor): Some score (e.g. logits) of shape + :obj:`[n_elem, dim]`. + labels (LongTensor): The labels of shape :obj:`[n_elem]`. + mask (Tensor, optional): To select subset of `score` and + `labels` of shape :obj:`[n_select]`. Defaults to None. + + :rtype: (Tensor) + """ + if mask is not None: + score = score[mask] + labels = labels[mask] + + linear_idx = torch.arange(score.size(0), device=score.device) + true_score = score[linear_idx, labels] + + score = score.clone() + score[linear_idx, labels] = float('-Inf') + best_non_target_score = score.amax(dim=-1) + + margin = best_non_target_score - true_score + return margin + + +def tanh_margin_loss(prediction: Tensor, labels: Tensor, + mask: Optional[Tensor] = None) -> Tensor: + """Calculate tanh margin loss, a node-classification loss that focuses + on nodes next to decision boundary. + + Args: + prediction (Tensor): Prediction of shape :obj:`[n_elem, dim]`. + labels (LongTensor): The labels of shape :obj:`[n_elem]`. + mask (Tensor, optional): To select subset of `score` and + `labels` of shape :obj:`[n_select]`. Defaults to None. 
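+
+    The tanh saturates large (absolute) margins, so nodes that are far
+    from the decision boundary contribute little gradient and the loss
+    concentrates on nodes close to the boundary.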
+ + :rtype: (Tensor) + """ + log_logits = F.log_softmax(prediction, dim=-1) + margin = margin_loss(log_logits, labels, mask) + loss = torch.tanh(margin).mean() + return loss + + +def probability_margin_loss(prediction: Tensor, labels: Tensor, + mask: Optional[Tensor] = None) -> Tensor: + """Calculate probability margin loss, a node-classification loss that + focuses on nodes next to decision boundary. See `Are Defenses for + Graph Neural Networks Robust? + `_ for details. + + Args: + prediction (Tensor): Prediction of shape :obj:`[n_elem, dim]`. + labels (LongTensor): The labels of shape :obj:`[n_elem]`. + mask (Tensor, optional): To select subset of `score` and + `labels` of shape :obj:`[n_select]`. Defaults to None. + + :rtype: (Tensor) + """ + logits = F.softmax(prediction, dim=-1) + margin = margin_loss(logits, labels, mask) + return margin.mean() + + +def masked_cross_entropy(log_logits: Tensor, labels: Tensor, + mask: Optional[Tensor] = None) -> Tensor: + """Calculate masked cross entropy loss, a node-classification loss that + focuses on nodes next to decision boundary. + + Args: + log_logits (Tensor): Log logits of shape :obj:`[n_elem, dim]`. + labels (LongTensor): The labels of shape :obj:`[n_elem]`. + mask (Tensor, optional): To select subset of `score` and + `labels` of shape :obj:`[n_select]`. Defaults to None. + + :rtype: (Tensor) + """ + if mask is not None: + log_logits = log_logits[mask] + labels = labels[mask] + + is_correct = log_logits.argmax(-1) == labels + if is_correct.any(): + log_logits = log_logits[is_correct] + labels = labels[is_correct] + + loss = F.cross_entropy(log_logits, labels) + return loss From a49bf6d7ae751d685a302ce75949c509731a96ad Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 28 Nov 2022 10:24:01 +0800 Subject: [PATCH 02/15] import --- greatx/attack/untargeted/__init__.py | 11 +++++++++-- greatx/attack/untargeted/{rbcd.py => rbcd_attack.py} | 0 2 files changed, 9 insertions(+), 2 deletions(-) rename greatx/attack/untargeted/{rbcd.py => rbcd_attack.py} (100%) diff --git a/greatx/attack/untargeted/__init__.py b/greatx/attack/untargeted/__init__.py index a7eefa2..7bc5c04 100644 --- a/greatx/attack/untargeted/__init__.py +++ b/greatx/attack/untargeted/__init__.py @@ -5,8 +5,15 @@ from .pgd_attack import PGDAttack from .random_attack import RandomAttack from .untargeted_attacker import UntargetedAttacker +from .rbcd_attack import PRBCDAttack classes = __all__ = [ - 'UntargetedAttacker', 'RandomAttack', 'DICEAttack', 'FGAttack', 'IGAttack', - 'Metattack', 'PGDAttack' + 'UntargetedAttacker', + 'RandomAttack', + 'DICEAttack', + 'FGAttack', + 'IGAttack', + 'Metattack', + 'PGDAttack', + 'PRBCDAttack', ] diff --git a/greatx/attack/untargeted/rbcd.py b/greatx/attack/untargeted/rbcd_attack.py similarity index 100% rename from greatx/attack/untargeted/rbcd.py rename to greatx/attack/untargeted/rbcd_attack.py From 862c4f0abef18471bba81ff419178ae8affebf9f Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 28 Nov 2022 16:14:58 +0800 Subject: [PATCH 03/15] update --- greatx/attack/untargeted/rbcd_attack.py | 180 ++++++++++-------------- 1 file changed, 71 insertions(+), 109 deletions(-) diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py index d61b9d9..4216e77 100644 --- a/greatx/attack/untargeted/rbcd_attack.py +++ b/greatx/attack/untargeted/rbcd_attack.py @@ -1,13 +1,10 @@ from collections import defaultdict -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union +from typing import Any, 
Callable, Dict, Optional, Tuple, Union import numpy as np import torch -import torch.nn.functional as F from torch import Tensor -from torch.autograd import grad from torch_geometric.utils import coalesce, to_undirected -from tqdm import tqdm from tqdm.auto import tqdm from greatx.attack.untargeted.untargeted_attacker import UntargetedAttacker @@ -18,7 +15,6 @@ tanh_margin_loss, ) from greatx.nn.models.surrogate import Surrogate -from greatx.utils import singleton_mask # (predictions, labels, ids/mask) -> Tensor with one element LOSS_TYPE = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor] @@ -59,14 +55,21 @@ def reset(self): self.block_edge_weight = None return self - def attack(self, num_budgets: Union[int, float] = 0.05, - block_size: int = 250_000, epochs: int = 125, - epochs_resampling: int = 100, - loss: Optional[Union[str, LOSS_TYPE]] = 'prob_margin', - metric: Optional[Union[str, - LOSS_TYPE]] = None, lr: float = 1_000, *, - structure_attack: bool = True, feature_attack: bool = False, - disable: bool = False, **kwargs) -> "PRBCDAttack": + def attack( + self, + num_budgets: Union[int, float] = 0.05, + *, + block_size: int = 250_000, + epochs: int = 125, + epochs_resampling: int = 100, + loss: Optional[Union[str, LOSS_TYPE]] = 'prob_margin', + metric: Optional[Union[str, LOSS_TYPE]] = None, + lr: float = 2_000, + structure_attack: bool = True, + feature_attack: bool = False, + disable: bool = False, + **kwargs, + ) -> "PRBCDAttack": super().attack(num_budgets=num_budgets, structure_attack=structure_attack, @@ -104,18 +107,18 @@ def attack(self, num_budgets: Union[int, float] = 0.05, # For collecting attack statistics self.attack_statistics = defaultdict(list) - # Prepare attack and define `self.iterable` to iterate over - budget = self.budget + budget = self.num_budgets self.best_metric = float('-Inf') # Sample initial search space (Algorithm 1, line 3-4) - self._sample_random_block(budget) + self.sample_random_block(budget) # Loop over the epochs (Algorithm 1, line 5) - for step in tqdm(range(self.budgets), desc='Peturbing graph...', + for step in tqdm(range(self.num_budgets), desc='Peturbing graph...', disable=disable): - loss, gradient = self._forward_and_gradient( - self.feat, self.label, self.victim_nodes, **kwargs) + loss, gradient = self.compute_gradients(self.feat, self.label, + self.victim_nodes, + **kwargs) scalars = self._update(step, gradient, self.feat, self.label, budget, self.victim_nodes, **kwargs) @@ -127,40 +130,30 @@ def attack(self, num_budgets: Union[int, float] = 0.05, return self - @torch.no_grad() - def _prepare(self, budget: int) -> Iterable[int]: - """Prepare attack.""" - # For early stopping (not explicitly covered by pseudo code) - self.best_metric = float('-Inf') - - # Sample initial search space (Algorithm 1, line 3-4) - self._sample_random_block(budget) - - steps = range(self.epochs) - return steps - @torch.no_grad() def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor, budget: int, idx_attack: Optional[Tensor] = None, **kwargs) -> Dict[str, float]: """Update edge weights given gradient.""" # Gradient update step (Algorithm 1, line 7) - self._update_edge_weights(budget, epoch, gradient) + self.update_edge_weights(budget, epoch, gradient) # For monitoring pmass_update = torch.clamp(self.block_edge_weight, 0, 1) # Projection to stay within relaxed `L_0` budget # (Algorithm 1, line 8) - self.block_edge_weight = self._project(budget, self.block_edge_weight, - self.coeffs['eps']) + self.block_edge_weight = self.project(budget, 
self.block_edge_weight,
+                                              self.coeffs['eps'])
 
         # For monitoring
         scalars = dict(
             prob_mass_after_update=pmass_update.sum().item(),
             prob_mass_after_update_max=pmass_update.max().item(),
-            prob_mass_after_projection=self.block_edge_weight.sum().item(),
-            prob_mass_after_projection_nonzero_weights=(
+            prob_mass_after_projection=self.block_edge_weight.sum().item(),
+            prob_mass_after_projection_nonzero_weights=(
                 self.block_edge_weight > self.coeffs['eps']).sum().item(),
-            prob_mass_after_projection_max=self.block_edge_weight.max().item())
+            prob_mass_after_projection_max=self.block_edge_weight.max().item(),
+        )
+
         if not self.coeffs['with_early_stopping']:
             return scalars
 
@@ -169,10 +162,10 @@ def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor,
         topk_block_edge_weight = torch.zeros_like(self.block_edge_weight)
         topk_block_edge_weight[torch.topk(self.block_edge_weight,
                                           budget).indices] = 1
-        edge_index, edge_weight = self._get_modified_adj(
+        edge_index, edge_weight = self.get_modified_graph(
             self.edge_index, self.edge_weights, self.block_edge_index,
             topk_block_edge_weight)
-        prediction = self._forward(x, edge_index, edge_weight, **kwargs)
+        prediction = self.surrogate(x, edge_index, edge_weight, **kwargs)
         metric = self.metric(prediction, labels, idx_attack)
 
         # Save best epoch for early stopping
@@ -185,7 +178,7 @@ def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor,
 
         # Resampling of search space (Algorithm 1, line 9-14)
         if epoch < self.epochs_resampling - 1:
-            self._resample_random_block(budget)
+            self.resample_random_block(budget)
         elif epoch == self.epochs_resampling - 1:
             # Retrieve best epoch if early stopping is active
             # (not explicitly covered by pseudo code)
@@ -210,35 +203,34 @@ def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor,
         self.block_edge_weight = self.best_pert_edge_weight.to(self.device)
 
         # Sample final discrete graph (Algorithm 1, line 16)
-        edge_index, flipped_edges = self._sample_final_edges(
+        flipped_edges, edge_weight = self.sample_final_edges(
             x, labels, budget, idx_attack=idx_attack, **kwargs)
 
-        assert flipped_edges.size(1) <= self.budget, (
+        assert flipped_edges.size(1) <= self.num_budgets, (
             f'# perturbed edges {flipped_edges.size(1)} '
-            f'exceeds budget {self.budget}')
+            f'exceeds budget {self.num_budgets}')
 
-        row, col = flipped_edges
-        # TODO: zip*
-        for it, (u, v) in enumerate(zip(row.tolist(), col.tolist())):
-            if self.adj[u, v] > 0:
+        row, col = flipped_edges.tolist()
+        for it, (u, v, w) in enumerate(zip(row, col, edge_weight.tolist())):
+            if w > 0:
                 self.remove_edge(u, v, it)
             else:
                 self.add_edge(u, v, it)
 
-    def _forward_and_gradient(self, x: Tensor, labels: Tensor,
-                              victim_nodes: Optional[Tensor] = None,
-                              **kwargs) -> Tuple[Tensor, Tensor]:
+    def compute_gradients(self, x: Tensor, labels: Tensor,
+                          victim_nodes: Optional[Tensor] = None,
+                          **kwargs) -> Tuple[Tensor, Tensor]:
         """Forward and update edge weights."""
         self.block_edge_weight.requires_grad_()
 
         # Retrieve sparse perturbed adjacency matrix `A \oplus p_{t-1}`
         # (Algorithm 1, line 6 / Algorithm 2, line 7)
-        edge_index, edge_weight = self._get_modified_adj(
+        edge_index, edge_weight = self.get_modified_graph(
             self.edge_index, self.edge_weights, self.block_edge_index,
             self.block_edge_weight)
 
         # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7)
-        prediction = self.model(x, edge_index, edge_weight, **kwargs)
+        prediction = self.surrogate(x, edge_index, edge_weight, **kwargs)
         # Calculate loss combining all each node
         # (Algorithm 1, line 7 / Algorithm 2, line 8)
         loss
= self.loss(prediction, labels, victim_nodes) @@ -248,7 +240,7 @@ def _forward_and_gradient(self, x: Tensor, labels: Tensor, return loss, gradient - def _get_modified_adj( + def get_modified_graph( self, edge_index: Tensor, edge_weight: Tensor, @@ -270,8 +262,8 @@ def _get_modified_adj( # Allow (soft) removal of edges is_edge_in_clean_adj = modified_edge_weight > 1 - modified_edge_weight[is_edge_in_clean_adj] = ( - 2 - modified_edge_weight[is_edge_in_clean_adj]) + modified_edge_weight[is_edge_in_clean_adj] = 2 - modified_edge_weight[ + is_edge_in_clean_adj] return modified_edge_index, modified_edge_weight @@ -283,7 +275,7 @@ def _filter_self_loops_in_block(self, with_weight: bool): self.block_edge_weight = self.block_edge_weight[is_not_sl] @torch.no_grad() - def _sample_random_block(self, budget: int = 0): + def sample_random_block(self, budget: int = 0): for _ in range(self.coeffs['max_trials_sampling']): num_possible_edges = self._num_possible_edges( self.num_nodes, self.is_undirected_graph) @@ -291,12 +283,14 @@ def _sample_random_block(self, budget: int = 0): (self.block_size, ), device=self.device) self.current_block = torch.unique(self.current_block, sorted=True) + if self.is_undirected_graph: self.block_edge_index = self._linear_to_triu_idx( self.num_nodes, self.current_block) else: self.block_edge_index = self._linear_to_full_idx( self.num_nodes, self.current_block) + self._filter_self_loops_in_block(with_weight=False) self.block_edge_weight = torch.full(self.current_block.shape, @@ -304,10 +298,11 @@ def _sample_random_block(self, budget: int = 0): device=self.device) if self.current_block.size(0) >= budget: return + raise RuntimeError('Sampling random block was not successful. ' 'Please decrease `budget`.') - def _resample_random_block(self, budget: int): + def resample_random_block(self, budget: int): # Keep at most half of the block (i.e. resample low weights) sorted_idx = torch.argsort(self.block_edge_weight) keep_above = (self.block_edge_weight <= @@ -342,6 +337,7 @@ def _resample_random_block(self, budget: int): self.block_edge_weight = torch.full(self.current_block.shape, self.coeffs['eps'], device=self.device) + self.block_edge_weight[ unique_idx[:sorted_idx.size(0)]] = block_edge_weight_prev @@ -353,9 +349,10 @@ def _resample_random_block(self, budget: int): raise RuntimeError('Sampling random block was not successful.' 
'Please decrease `budget`.') - def _sample_final_edges(self, x: Tensor, labels: Tensor, budget: int, - idx_attack: Optional[Tensor] = None, - **kwargs) -> Tuple[Tensor, Tensor]: + @torch.no_grad() + def sample_final_edges(self, x: Tensor, labels: Tensor, budget: int, + idx_attack: Optional[Tensor] = None, + **kwargs) -> Tuple[Tensor, Tensor]: best_metric = float('-Inf') block_edge_weight = self.block_edge_weight block_edge_weight[block_edge_weight <= self.coeffs['eps']] = 0 @@ -372,12 +369,13 @@ def _sample_final_edges(self, x: Tensor, labels: Tensor, budget: int, if sampled_edges.sum() > budget: # Allowed budget is exceeded continue + self.block_edge_weight = sampled_edges - edge_index, edge_weight = self._get_modified_adj( + edge_index, edge_weight = self.get_modified_graph( self.edge_index, self.edge_weights, self.block_edge_index, self.block_edge_weight) - prediction = self._forward(x, edge_index, edge_weight, **kwargs) + prediction = self.surrogate(x, edge_index, edge_weight, **kwargs) metric = self.metric(prediction, labels, idx_attack) # Save best sample @@ -385,20 +383,11 @@ def _sample_final_edges(self, x: Tensor, labels: Tensor, budget: int, best_metric = metric best_edge_weight = self.block_edge_weight.clone().cpu() - # Recover best sample - self.block_edge_weight = best_edge_weight.to(self.device) flipped_edges = self.block_edge_index[:, torch.where(best_edge_weight)[0]] + return flipped_edges, best_edge_weight - edge_index, edge_weight = self._get_modified_adj( - self.edge_index, self.edge_weights, self.block_edge_index, - self.block_edge_weight) - edge_mask = edge_weight == 1 - edge_index = edge_index[:, edge_mask] - - return edge_index, flipped_edges - - def _update_edge_weights(self, budget: int, epoch: int, gradient: Tensor): + def update_edge_weights(self, budget: int, epoch: int, gradient: Tensor): # The learning rate is refined heuristically, s.t. (1) it is # independent of the number of perturbations (assuming an undirected # adjacency matrix) and (2) to decay learning rate during fine-tuning @@ -408,19 +397,19 @@ def _update_edge_weights(self, budget: int, epoch: int, gradient: Tensor): self.block_edge_weight.data.add_(lr * gradient) @staticmethod - def _project(budget: int, values: Tensor, eps: float = 1e-7) -> Tensor: + def project(budget: int, values: Tensor, eps: float = 1e-7) -> Tensor: r"""Project :obj:`values`: :math:`budget \ge \sum \Pi_{[0, 1]}(\text{values})`.""" if torch.clamp(values, 0, 1).sum() > budget: left = (values - 1).min() right = values.max() - miu = PRBCDAttack._bisection(values, left, right, budget) + miu = PRBCDAttack.bisection(values, left, right, budget) values = values - miu return torch.clamp(values, min=eps, max=1 - eps) @staticmethod - def _bisection(edge_weights: Tensor, a: float, b: float, n_pert: int, - eps=1e-5, max_iter=1e3) -> Tensor: + def bisection(edge_weights: Tensor, a: float, b: float, n_pert: int, + eps=1e-5, max_iter=1e3) -> Tensor: """Bisection search for projection.""" def shift(offset: float): return (torch.clamp(edge_weights - offset, 0, 1).sum() - n_pert) @@ -450,8 +439,8 @@ def _num_possible_edges(n: int, is_undirected_graph: bool) -> int: @staticmethod def _linear_to_triu_idx(n: int, lin_idx: Tensor) -> Tensor: - """Linear index to upper triangular matrix without diagonal. This is - similar to + """Linear index to upper triangular matrix without diagonal. 
+ This is similar to https://stackoverflow.com/questions/242711/algorithm-for-index-numbers-of-triangular-matrix-coefficients/28116498#28116498 with number nodes decremented and col index incremented by one.""" nn = n * (n - 1) @@ -468,33 +457,6 @@ def _linear_to_full_idx(n: int, lin_idx: Tensor) -> Tensor: col_idx = lin_idx % n return torch.stack((row_idx, col_idx)) - def structure_score(self, modified_adj, adj_grad): - score = adj_grad * (1 - 2 * modified_adj) - score -= score.min() - score = torch.triu(score, diagonal=1) - if not self._allow_singleton: - # Set entries to 0 that could lead to singleton nodes. - score *= singleton_mask(modified_adj) - return score.view(-1) - - def feature_score(self, modified_feat, feat_grad): - score = feat_grad * (1 - 2 * modified_feat) - score -= score.min() - return score.view(-1) - - def compute_gradients(self, modified_adj, modified_feat, victim_nodes, - victim_labels): - - logit = self.surrogate(modified_feat, - modified_adj)[victim_nodes] / self.eps - loss = F.cross_entropy(logit, victim_labels) - - if self.structure_attack and self.feature_attack: - return grad(loss, [modified_adj, modified_feat], - create_graph=False) - - if self.structure_attack: - return grad(loss, modified_adj, create_graph=False)[0], None - - if self.feature_attack: - return None, grad(loss, modified_feat, create_graph=False)[0] + def _append_statistics(self, mapping: Dict[str, Any]): + for key, value in mapping.items(): + self.attack_statistics[key].append(value) From 21f32b1cc08e0265fb23e25966ea7db26693c628 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 28 Nov 2022 16:23:36 +0800 Subject: [PATCH 04/15] update --- greatx/attack/untargeted/rbcd_attack.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py index 4216e77..15daff5 100644 --- a/greatx/attack/untargeted/rbcd_attack.py +++ b/greatx/attack/untargeted/rbcd_attack.py @@ -32,11 +32,16 @@ class PRBCDAttack(UntargetedAttacker, Surrogate): 'eps': 1e-7 } - def setup_surrogate(self, surrogate: torch.nn.Module, victim_nodes: Tensor, - victim_labels: Optional[Tensor] = None, *, - eps: float = 1.0): + def setup_surrogate( + self, + surrogate: torch.nn.Module, + victim_nodes: Tensor, + victim_labels: Optional[Tensor] = None, + *, + tau: float = 1.0, + ) -> "PRBCDAttack": - Surrogate.setup_surrogate(self, surrogate=surrogate, eps=eps, + Surrogate.setup_surrogate(self, surrogate=surrogate, tau=tau, freeze=True) if victim_nodes.dtype == torch.bool: @@ -48,7 +53,7 @@ def setup_surrogate(self, surrogate: torch.nn.Module, victim_nodes: Tensor, self.victim_labels = victim_labels.to(self.device) return self - def reset(self): + def reset(self) -> "PRBCDAttack": super().reset() self.current_block = None self.block_edge_index = None @@ -116,6 +121,7 @@ def attack( # Loop over the epochs (Algorithm 1, line 5) for step in tqdm(range(self.num_budgets), desc='Peturbing graph...', disable=disable): + loss, gradient = self.compute_gradients(self.feat, self.label, self.victim_nodes, **kwargs) From dec61bcc6b0182ff996914b85cb8f63de822ca5f Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 28 Nov 2022 18:24:07 +0800 Subject: [PATCH 05/15] update --- greatx/attack/untargeted/rbcd_attack.py | 197 +++++++++++++++--------- greatx/functional/losses.py | 123 +++++++-------- 2 files changed, 183 insertions(+), 137 deletions(-) diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py index 
15daff5..2b1359b 100644
--- a/greatx/attack/untargeted/rbcd_attack.py
+++ b/greatx/attack/untargeted/rbcd_attack.py
@@ -9,7 +9,6 @@
 from greatx.attack.untargeted.untargeted_attacker import UntargetedAttacker
 from greatx.functional import (
-    margin_loss,
     masked_cross_entropy,
     probability_margin_loss,
     tanh_margin_loss,
@@ -17,13 +15,15 @@
 from greatx.nn.models.surrogate import Surrogate
 
 # (predictions, labels, ids/mask) -> Tensor with one element
-LOSS_TYPE = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]
+METRIC = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]
 
 
 class PRBCDAttack(UntargetedAttacker, Surrogate):
-    # FGAttack can conduct feature attack
-    _allow_feature_attack: bool = True
-    is_undirected_graph: bool = True  # TODO
+
+    # TODO: Although PRBCDAttack accepts directed graphs,
+    # we currently don't explicitly support directed graphs.
+    # This should be made available in the future.
+    is_undirected_graph: bool = True
 
     coeffs: Dict[str, Any] = {
         'max_final_samples': 20,
@@ -36,21 +37,48 @@ class PRBCDAttack(UntargetedAttacker, Surrogate):
     def setup_surrogate(
         self,
         surrogate: torch.nn.Module,
         victim_nodes: Tensor,
-        victim_labels: Optional[Tensor] = None,
+        ground_truth: bool = True,
         *,
         tau: float = 1.0,
+        freeze: bool = True,
     ) -> "PRBCDAttack":
+        r"""Setup the surrogate model for adversarial attack.
+
+        Parameters
+        ----------
+        surrogate : torch.nn.Module
+            the surrogate model
+        victim_nodes : Tensor
+            the victim node set
+        ground_truth : bool, optional
+            whether to use ground-truth labels for victim nodes,
+            if False, the node labels are estimated by the surrogate model,
+            by default True
+        tau : float, optional
+            the temperature of softmax activation, by default 1.0
+        freeze : bool, optional
+            whether to freeze the surrogate model to avoid
+            gradient accumulation, by default True
+
+        Returns
+        -------
+        PRBCDAttack
+            the attacker itself
+        """
         Surrogate.setup_surrogate(self, surrogate=surrogate, tau=tau,
-                                  freeze=True)
+                                  freeze=freeze)
 
         if victim_nodes.dtype == torch.bool:
            victim_nodes = victim_nodes.nonzero().view(-1)
         self.victim_nodes = victim_nodes.to(self.device)
 
-        if victim_labels is None:
-            victim_labels = self.label[victim_nodes]
-        self.victim_labels = victim_labels.to(self.device)
+        if ground_truth:
+            self.victim_labels = self.label[victim_nodes]
+        else:
+            self.victim_labels = self.estimate_self_training_labels(
+                victim_nodes)
+
         return self
 
     def reset(self) -> "PRBCDAttack":
@@ -58,6 +86,17 @@ def reset(self) -> "PRBCDAttack":
         self.current_block = None
         self.block_edge_index = None
         self.block_edge_weight = None
+        self.loss = None
+        self.metric = None
+
+        # NOTE: `edge_weight` denotes the edge weight of the original graph
+        # it is None by default, so here we need to name it as `edge_weights`
+        self.edge_weights = torch.ones(self.num_edges, device=self.device)
+        self.best_metric = float('-Inf')
+
+        # For collecting attack statistics
+        self.attack_statistics = defaultdict(list)
+
         return self
 
     def attack(
@@ -67,8 +106,8 @@ def attack(
         block_size: int = 250_000,
         epochs: int = 125,
         epochs_resampling: int = 100,
-        loss: Optional[Union[str, LOSS_TYPE]] = 'prob_margin',
-        metric: Optional[Union[str, LOSS_TYPE]] = None,
+        loss: Optional[str] = 'tanh_margin',
+        metric: Optional[Union[str, METRIC]] = None,
         lr: float = 2_000,
         structure_attack: bool = True,
         feature_attack: bool = False,
@@ -83,19 +122,13 @@ def attack(
         self.block_size = block_size
         self.epochs = epochs
 
-        if isinstance(loss, str):
-            if loss == 'masked':
-                self.loss = masked_cross_entropy
-            elif loss == 'margin':
-                self.loss = margin_loss
-            elif loss ==
'prob_margin': - self.loss = probability_margin_loss - elif loss == 'tanh_margin': - self.loss = tanh_margin_loss - else: - raise ValueError(f'Unknown loss `{loss}`') + assert loss in ['mce', 'prob_margin', 'tanh_margin'] + if loss == 'mce': + self.loss = masked_cross_entropy + elif loss == 'prob_margin': + self.loss = probability_margin_loss else: - self.loss = loss + self.loss = tanh_margin_loss if metric is None: self.metric = self.loss @@ -105,49 +138,43 @@ def attack( self.epochs_resampling = epochs_resampling self.lr = lr - # self.coeffs.update(**kwargs) # TODO - self.edge_weights = torch.ones(self.edge_index.size(1), - device=self.device) - - # For collecting attack statistics - self.attack_statistics = defaultdict(list) + self.coeffs.update(**kwargs) - budget = self.num_budgets + num_budgets = self.num_budgets - self.best_metric = float('-Inf') # Sample initial search space (Algorithm 1, line 3-4) - self.sample_random_block(budget) + self.sample_random_block(num_budgets) # Loop over the epochs (Algorithm 1, line 5) - for step in tqdm(range(self.num_budgets), desc='Peturbing graph...', + for step in tqdm(range(num_budgets), desc='Peturbing graph...', disable=disable): - loss, gradient = self.compute_gradients(self.feat, self.label, - self.victim_nodes, - **kwargs) + loss, gradient = self.compute_gradients(self.feat, + self.victim_labels, + self.victim_nodes) - scalars = self._update(step, gradient, self.feat, self.label, - budget, self.victim_nodes, **kwargs) + scalars = self.update(step, gradient, num_budgets) scalars['loss'] = loss.item() self._append_statistics(scalars) - self._close(self.feat, self.label, budget, self.victim_nodes, **kwargs) + self.close() return self @torch.no_grad() - def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor, - budget: int, idx_attack: Optional[Tensor] = None, - **kwargs) -> Dict[str, float]: + def update(self, epoch: int, gradient: Tensor, + num_budgets: int) -> Dict[str, float]: """Update edge weights given gradient.""" # Gradient update step (Algorithm 1, line 7) - self.update_edge_weights(budget, epoch, gradient) + self.update_edge_weights(num_budgets, epoch, gradient) + # For monitoring pmass_update = torch.clamp(self.block_edge_weight, 0, 1) # Projection to stay within relaxed `L_0` budget # (Algorithm 1, line 8) - self.block_edge_weight = self.project(budget, self.block_edge_weight, + self.block_edge_weight = self.project(num_budgets, + self.block_edge_weight, self.coeffs['eps']) # For monitoring @@ -165,14 +192,18 @@ def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor, # Calculate metric after the current epoch (overhead # for monitoring and early stopping) + topk_block_edge_weight = torch.zeros_like(self.block_edge_weight) topk_block_edge_weight[torch.topk(self.block_edge_weight, - budget).indices] = 1 + num_budgets).indices] = 1 + edge_index, edge_weight = self.get_modified_graph( self.edge_index, self.edge_weights, self.block_edge_index, topk_block_edge_weight) - prediction = self.surrogate(x, edge_index, edge_weight, **kwargs) - metric = self.metric(prediction, labels, idx_attack) + + prediction = self.surrogate(self.feat, edge_index, + edge_weight)[self.victim_nodes] + metric = self.metric(prediction, self.victim_labels) # Save best epoch for early stopping # (not explicitly covered by pseudo code) @@ -184,7 +215,7 @@ def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor, # Resampling of search space (Algorithm 1, line 9-14) if epoch < self.epochs_resampling - 1: - 
self.resample_random_block(budget) + self.resample_random_block(num_budgets) elif epoch == self.epochs_resampling - 1: # Retrieve best epoch if early stopping is active # (not explicitly covered by pseudo code) @@ -197,10 +228,9 @@ def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor, return scalars @torch.no_grad() - def _close(self, x: Tensor, labels: Tensor, budget: int, - idx_attack: Optional[Tensor] = None, - **kwargs) -> Tuple[Tensor, Tensor]: + def close(self): """Clean up and prepare return argument.""" + # Retrieve best epoch if early stopping is active # (not explicitly covered by pseudo code) if self.coeffs['with_early_stopping']: @@ -210,7 +240,11 @@ def _close(self, x: Tensor, labels: Tensor, budget: int, # Sample final discrete graph (Algorithm 1, line 16) flipped_edges, edge_weight = self.sample_final_edges( - x, labels, budget, idx_attack=idx_attack, **kwargs) + self.feat, + self.num_budgets, + self.victim_nodes, + self.victim_labels, + ) assert flipped_edges.size(1) <= self.num_budgets, ( f'# perturbed edges {flipped_edges.size(1)} ' @@ -223,9 +257,12 @@ def _close(self, x: Tensor, labels: Tensor, budget: int, else: self.add_edge(u, v, it) - def compute_gradients(self, x: Tensor, labels: Tensor, - victim_nodes: Optional[Tensor] = None, - **kwargs) -> Tuple[Tensor, Tensor]: + def compute_gradients( + self, + feat: Tensor, + victim_labels: Tensor, + victim_nodes: Tensor, + ) -> Tuple[Tensor, Tensor]: """Forward and update edge weights.""" self.block_edge_weight.requires_grad_() @@ -236,10 +273,11 @@ def compute_gradients(self, x: Tensor, labels: Tensor, self.block_edge_weight) # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7) - prediction = self.surrogate(x, edge_index, edge_weight, **kwargs) + prediction = self.surrogate(feat, edge_index, + edge_weight)[victim_nodes] # Calculate loss combining all each node # (Algorithm 1, line 7 / Algorithm 2, line 8) - loss = self.loss(prediction, labels, victim_nodes) + loss = self.loss(prediction, victim_labels) # Retrieve gradient towards the current block # (Algorithm 1, line 7 / Algorithm 2, line 8) gradient = torch.autograd.grad(loss, self.block_edge_weight)[0] @@ -267,19 +305,11 @@ def get_modified_graph( num_nodes=self.num_nodes, reduce='sum') # Allow (soft) removal of edges - is_edge_in_clean_adj = modified_edge_weight > 1 - modified_edge_weight[is_edge_in_clean_adj] = 2 - modified_edge_weight[ - is_edge_in_clean_adj] + mask = modified_edge_weight > 1 + modified_edge_weight[mask] = 2 - modified_edge_weight[mask] return modified_edge_index, modified_edge_weight - def _filter_self_loops_in_block(self, with_weight: bool): - is_not_sl = self.block_edge_index[0] != self.block_edge_index[1] - self.current_block = self.current_block[is_not_sl] - self.block_edge_index = self.block_edge_index[:, is_not_sl] - if with_weight: - self.block_edge_weight = self.block_edge_weight[is_not_sl] - @torch.no_grad() def sample_random_block(self, budget: int = 0): for _ in range(self.coeffs['max_trials_sampling']): @@ -356,9 +386,13 @@ def resample_random_block(self, budget: int): 'Please decrease `budget`.') @torch.no_grad() - def sample_final_edges(self, x: Tensor, labels: Tensor, budget: int, - idx_attack: Optional[Tensor] = None, - **kwargs) -> Tuple[Tensor, Tensor]: + def sample_final_edges( + self, + feat: Tensor, + num_budgets: int, + victim_nodes: Tensor, + victim_labels: Tensor, + ) -> Tuple[Tensor, Tensor]: best_metric = float('-Inf') block_edge_weight = self.block_edge_weight block_edge_weight[block_edge_weight 
<= self.coeffs['eps']] = 0 @@ -368,11 +402,11 @@ def sample_final_edges(self, x: Tensor, labels: Tensor, budget: int, # In first iteration employ top k heuristic instead of sampling sampled_edges = torch.zeros_like(block_edge_weight) sampled_edges[torch.topk(block_edge_weight, - budget).indices] = 1 + num_budgets).indices] = 1 else: sampled_edges = torch.bernoulli(block_edge_weight).float() - if sampled_edges.sum() > budget: + if sampled_edges.sum() > num_budgets: # Allowed budget is exceeded continue @@ -381,8 +415,9 @@ def sample_final_edges(self, x: Tensor, labels: Tensor, budget: int, edge_index, edge_weight = self.get_modified_graph( self.edge_index, self.edge_weights, self.block_edge_index, self.block_edge_weight) - prediction = self.surrogate(x, edge_index, edge_weight, **kwargs) - metric = self.metric(prediction, labels, idx_attack) + prediction = self.surrogate(feat, edge_index, + edge_weight)[victim_nodes] + metric = self.metric(prediction, victim_labels) # Save best sample if metric > best_metric: @@ -393,12 +428,13 @@ def sample_final_edges(self, x: Tensor, labels: Tensor, budget: int, torch.where(best_edge_weight)[0]] return flipped_edges, best_edge_weight - def update_edge_weights(self, budget: int, epoch: int, gradient: Tensor): + def update_edge_weights(self, num_budgets: int, epoch: int, + gradient: Tensor): # The learning rate is refined heuristically, s.t. (1) it is # independent of the number of perturbations (assuming an undirected # adjacency matrix) and (2) to decay learning rate during fine-tuning # (i.e. fixed search space). - lr = (budget / self.num_nodes * self.lr / + lr = (num_budgets / self.num_nodes * self.lr / np.sqrt(max(0, epoch - self.epochs_resampling) + 1)) self.block_edge_weight.data.add_(lr * gradient) @@ -435,6 +471,13 @@ def shift(offset: float): break return miu + def _filter_self_loops_in_block(self, with_weight: bool): + mask = self.block_edge_index[0] != self.block_edge_index[1] + self.current_block = self.current_block[mask] + self.block_edge_index = self.block_edge_index[:, mask] + if with_weight: + self.block_edge_weight = self.block_edge_weight[mask] + @staticmethod def _num_possible_edges(n: int, is_undirected_graph: bool) -> int: """Determine number of possible edges for graph.""" diff --git a/greatx/functional/losses.py b/greatx/functional/losses.py index 837b744..7f578fe 100644 --- a/greatx/functional/losses.py +++ b/greatx/functional/losses.py @@ -1,104 +1,107 @@ -from typing import Optional - import torch import torch.nn.functional as F from torch import Tensor -def margin_loss(score: Tensor, labels: Tensor, - mask: Optional[Tensor] = None) -> Tensor: +def margin_loss(score: Tensor, target: Tensor) -> Tensor: r"""Margin loss between true score and highest non-target score: .. math:: m = - s_{y} + max_{y' \ne y} s_{y'} where :math:`m` is the margin :math:`s` the score and :math:`y` the - labels. - - Args: - score (Tensor): Some score (e.g. logits) of shape - :obj:`[n_elem, dim]`. - labels (LongTensor): The labels of shape :obj:`[n_elem]`. - mask (Tensor, optional): To select subset of `score` and - `labels` of shape :obj:`[n_select]`. Defaults to None. - - :rtype: (Tensor) + target. + + Parameters + ---------- + score : Tensor + some score (e.g. prediction) of shape :obj:`[n_elem, dim]`. + target : LongTensor + the target of shape :obj:`[n_elem]`. 
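+
+    For example, for ``score = [[0.2, 0.8]]`` and ``target = [0]``, the
+    margin is ``0.8 - 0.2 = 0.6``; a positive margin thus indicates a
+    misclassified element, a negative margin a correctly classified one.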
+ + Returns + ------- + Tensor + the calculated margins """ - if mask is not None: - score = score[mask] - labels = labels[mask] linear_idx = torch.arange(score.size(0), device=score.device) - true_score = score[linear_idx, labels] + true_score = score[linear_idx, target] score = score.clone() - score[linear_idx, labels] = float('-Inf') + score[linear_idx, target] = float('-Inf') best_non_target_score = score.amax(dim=-1) margin = best_non_target_score - true_score return margin -def tanh_margin_loss(prediction: Tensor, labels: Tensor, - mask: Optional[Tensor] = None) -> Tensor: - """Calculate tanh margin loss, a node-classification loss that focuses +def tanh_margin_loss(prediction: Tensor, target: Tensor) -> Tensor: + r"""Calculate tanh margin loss, a node-classification loss that focuses on nodes next to decision boundary. - Args: - prediction (Tensor): Prediction of shape :obj:`[n_elem, dim]`. - labels (LongTensor): The labels of shape :obj:`[n_elem]`. - mask (Tensor, optional): To select subset of `score` and - `labels` of shape :obj:`[n_select]`. Defaults to None. - - :rtype: (Tensor) + Parameters + ---------- + prediction : Tensor + prediction of shape :obj:`[n_elem, dim]`. + target : LongTensor + the target of shape :obj:`[n_elem]`. + + Returns + ------- + Tensor + the calculated loss """ - log_logits = F.log_softmax(prediction, dim=-1) - margin = margin_loss(log_logits, labels, mask) + prediction = F.log_softmax(prediction, dim=-1) + margin = margin_loss(prediction, target) loss = torch.tanh(margin).mean() return loss -def probability_margin_loss(prediction: Tensor, labels: Tensor, - mask: Optional[Tensor] = None) -> Tensor: - """Calculate probability margin loss, a node-classification loss that +def probability_margin_loss(prediction: Tensor, target: Tensor) -> Tensor: + r"""Calculate probability margin loss, a node-classification loss that focuses on nodes next to decision boundary. See `Are Defenses for Graph Neural Networks Robust? `_ for details. - Args: - prediction (Tensor): Prediction of shape :obj:`[n_elem, dim]`. - labels (LongTensor): The labels of shape :obj:`[n_elem]`. - mask (Tensor, optional): To select subset of `score` and - `labels` of shape :obj:`[n_select]`. Defaults to None. - - :rtype: (Tensor) + Parameters + ---------- + prediction : Tensor + prediction of shape :obj:`[n_elem, dim]`. + target : LongTensor + the target of shape :obj:`[n_elem]`. + + Returns + ------- + Tensor + the calculated loss """ - logits = F.softmax(prediction, dim=-1) - margin = margin_loss(logits, labels, mask) + prediction = F.softmax(prediction, dim=-1) + margin = margin_loss(prediction, target) return margin.mean() -def masked_cross_entropy(log_logits: Tensor, labels: Tensor, - mask: Optional[Tensor] = None) -> Tensor: - """Calculate masked cross entropy loss, a node-classification loss that +def masked_cross_entropy(prediction: Tensor, target: Tensor) -> Tensor: + r"""Calculate masked cross entropy loss, a node-classification loss that focuses on nodes next to decision boundary. - Args: - log_logits (Tensor): Log logits of shape :obj:`[n_elem, dim]`. - labels (LongTensor): The labels of shape :obj:`[n_elem]`. - mask (Tensor, optional): To select subset of `score` and - `labels` of shape :obj:`[n_select]`. Defaults to None. - - :rtype: (Tensor) + Parameters + ---------- + prediction : Tensor + prediction of shape :obj:`[n_elem, dim]`. + target : LongTensor + the target of shape :obj:`[n_elem]`. 
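+
+    Only elements that are still classified correctly contribute to the
+    loss; elements that are already misclassified are masked out, so the
+    loss concentrates on nodes that have not been flipped yet.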
+ + Returns + ------- + Tensor + the calculated loss """ - if mask is not None: - log_logits = log_logits[mask] - labels = labels[mask] - is_correct = log_logits.argmax(-1) == labels + is_correct = prediction.argmax(-1) == target if is_correct.any(): - log_logits = log_logits[is_correct] - labels = labels[is_correct] + prediction = prediction[is_correct] + target = target[is_correct] - loss = F.cross_entropy(log_logits, labels) + loss = F.cross_entropy(prediction, target) return loss From 7f4293b3c5d0c562f8cb4fb867c93f42a7a452bf Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 28 Nov 2022 18:26:10 +0800 Subject: [PATCH 06/15] update --- greatx/attack/untargeted/rbcd_attack.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py index 2b1359b..61c6c0f 100644 --- a/greatx/attack/untargeted/rbcd_attack.py +++ b/greatx/attack/untargeted/rbcd_attack.py @@ -141,6 +141,8 @@ def attack( self.coeffs.update(**kwargs) num_budgets = self.num_budgets + feat, victim_nodes, victim_labels = (self.feat, self.victim_nodes, + self.victim_labels) # Sample initial search space (Algorithm 1, line 3-4) self.sample_random_block(num_budgets) @@ -149,9 +151,8 @@ def attack( for step in tqdm(range(num_budgets), desc='Peturbing graph...', disable=disable): - loss, gradient = self.compute_gradients(self.feat, - self.victim_labels, - self.victim_nodes) + loss, gradient = self.compute_gradients(feat, victim_labels, + victim_nodes) scalars = self.update(step, gradient, num_budgets) From 9fda87179e2f799f271d1ef2fcddd20cec1b32e8 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 28 Nov 2022 18:34:16 +0800 Subject: [PATCH 07/15] update --- greatx/attack/untargeted/rbcd_attack.py | 48 ++++++++++++++----------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py index 61c6c0f..6647a55 100644 --- a/greatx/attack/untargeted/rbcd_attack.py +++ b/greatx/attack/untargeted/rbcd_attack.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import numpy as np import torch @@ -92,6 +92,8 @@ def reset(self) -> "PRBCDAttack": # NOTE: `edge_weight` denotes the edge weight of the original graph # it is None by default, so here we need to name it as `edge_weights` self.edge_weights = torch.ones(self.num_edges, device=self.device) + + # For early stopping (not explicitly covered by pseudo code) self.best_metric = float('-Inf') # For collecting attack statistics @@ -120,7 +122,6 @@ def attack( feature_attack=feature_attack) self.block_size = block_size - self.epochs = epochs assert loss in ['mce', 'prob_margin', 'tanh_margin'] if loss == 'mce': @@ -144,12 +145,9 @@ def attack( feat, victim_nodes, victim_labels = (self.feat, self.victim_nodes, self.victim_labels) - # Sample initial search space (Algorithm 1, line 3-4) - self.sample_random_block(num_budgets) - # Loop over the epochs (Algorithm 1, line 5) - for step in tqdm(range(num_budgets), desc='Peturbing graph...', - disable=disable): + for step in tqdm(self.prepare(num_budgets, epochs), + desc='Peturbing graph...', disable=disable): loss, gradient = self.compute_gradients(feat, victim_labels, victim_nodes) @@ -163,6 +161,15 @@ def attack( return self + @torch.no_grad() + def prepare(self, num_budgets: int, epochs: int) -> Iterable[int]: + """Prepare 
attack and return the iterable sequence steps.""" + + # Sample initial search space (Algorithm 1, line 3-4) + self.sample_random_block(num_budgets) + + return range(epochs) + @torch.no_grad() def update(self, epoch: int, gradient: Tensor, num_budgets: int) -> Dict[str, float]: @@ -172,7 +179,7 @@ def update(self, epoch: int, gradient: Tensor, # For monitoring pmass_update = torch.clamp(self.block_edge_weight, 0, 1) - # Projection to stay within relaxed `L_0` budget + # Projection to stay within relaxed `L_0` num_budgets # (Algorithm 1, line 8) self.block_edge_weight = self.project(num_budgets, self.block_edge_weight, @@ -249,7 +256,7 @@ def close(self): assert flipped_edges.size(1) <= self.num_budgets, ( f'# perturbed edges {flipped_edges.size(1)} ' - f'exceeds budget {self.num_budgets}') + f'exceeds num_budgets {self.num_budgets}') row, col = flipped_edges.tolist() for it, (u, v, w) in enumerate(zip(row, col, edge_weight.tolist())): @@ -312,7 +319,7 @@ def get_modified_graph( return modified_edge_index, modified_edge_weight @torch.no_grad() - def sample_random_block(self, budget: int = 0): + def sample_random_block(self, num_budgets: int = 0): for _ in range(self.coeffs['max_trials_sampling']): num_possible_edges = self._num_possible_edges( self.num_nodes, self.is_undirected_graph) @@ -333,13 +340,13 @@ def sample_random_block(self, budget: int = 0): self.block_edge_weight = torch.full(self.current_block.shape, self.coeffs['eps'], device=self.device) - if self.current_block.size(0) >= budget: + if self.current_block.size(0) >= num_budgets: return raise RuntimeError('Sampling random block was not successful. ' - 'Please decrease `budget`.') + 'Please decrease `num_budgets`.') - def resample_random_block(self, budget: int): + def resample_random_block(self, num_budgets: int): # Keep at most half of the block (i.e. resample low weights) sorted_idx = torch.argsort(self.block_edge_weight) keep_above = (self.block_edge_weight <= @@ -381,10 +388,11 @@ def resample_random_block(self, budget: int): if not self.is_undirected_graph: self._filter_self_loops_in_block(with_weight=True) - if self.current_block.size(0) > budget: + if self.current_block.size(0) > num_budgets: return + raise RuntimeError('Sampling random block was not successful.' 
-                           'Please decrease `budget`.')
+                           'Please decrease `num_budgets`.')
@@ -408,7 +416,7 @@ def sample_final_edges(
                sampled_edges = torch.bernoulli(block_edge_weight).float()

            if sampled_edges.sum() > num_budgets:
-                # Allowed budget is exceeded
+                # Allowed num_budgets is exceeded
                continue

            self.block_edge_weight = sampled_edges
@@ -440,13 +448,13 @@ def update_edge_weights(self, num_budgets: int, epoch: int,
        self.block_edge_weight.data.add_(lr * gradient)

    @staticmethod
-    def project(budget: int, values: Tensor, eps: float = 1e-7) -> Tensor:
+    def project(num_budgets: int, values: Tensor, eps: float = 1e-7) -> Tensor:
        r"""Project :obj:`values`:
-        :math:`budget \ge \sum \Pi_{[0, 1]}(\text{values})`."""
-        if torch.clamp(values, 0, 1).sum() > budget:
+        :math:`\text{num\_budgets} \ge \sum \Pi_{[0, 1]}(\text{values})`."""
+        if torch.clamp(values, 0, 1).sum() > num_budgets:
            left = (values - 1).min()
            right = values.max()
-            miu = PRBCDAttack.bisection(values, left, right, budget)
+            miu = PRBCDAttack.bisection(values, left, right, num_budgets)
            values = values - miu
        return torch.clamp(values, min=eps, max=1 - eps)

From 066adf29e0d40537519eb9bdcb8d20d930d09859 Mon Sep 17 00:00:00 2001
From: EdisonLeeeee
Date: Mon, 28 Nov 2022 21:47:25 +0800
Subject: [PATCH 08/15] update

---
 greatx/attack/untargeted/rbcd_attack.py | 270 ++++++++++++++----------
 greatx/attack/untargeted/utils.py       |  63 ++++++
 2 files changed, 223 insertions(+), 110 deletions(-)
 create mode 100644 greatx/attack/untargeted/utils.py

diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py
index 6647a55..37eca74 100644
--- a/greatx/attack/untargeted/rbcd_attack.py
+++ b/greatx/attack/untargeted/rbcd_attack.py
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -8,6 +8,12 @@ from tqdm.auto import tqdm

 from greatx.attack.untargeted.untargeted_attacker import UntargetedAttacker
+from greatx.attack.untargeted.utils import (
+    linear_to_full_idx,
+    linear_to_triu_idx,
+    num_possible_edges,
+    project,
+)
 from greatx.functional import (
     masked_cross_entropy,
     probability_margin_loss,
@@ -20,11 +26,38 @@

 class PRBCDAttack(UntargetedAttacker, Surrogate):
+    r"""Projected Randomized Block Coordinate Descent (PRBCD) adversarial
+    attack from the `Robustness of Graph Neural Networks at Scale
+    `_ paper.
+
+    This attack uses an efficient gradient based approach that (during the
+    attack) relaxes the discrete entries in the adjacency matrix
+    :math:`\{0, 1\}` to :math:`[0, 1]` and solely perturbs the adjacency matrix
+    (no feature perturbations). Thus, this attack supports all models that can
+    handle weighted graphs that are differentiable w.r.t. these edge weights.
+    For non-differentiable models you might be able to use, e.g., the Gumbel
+    softmax trick.
+
+    The memory overhead is driven by the additional edges (at most
+    :attr:`block_size`). For scalability reasons, the block is drawn with
+    replacement and then the index is made unique. Thus, the actual block size
+    is typically slightly smaller than specified.
+
+    This attack can be used for both global and local attacks as well as
+    test-time attacks (evasion) and training-time attacks (poisoning). Please
+    see the provided examples.
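+
+    A minimal usage sketch (hedged: ``trained_gcn``, ``data``, ``splits``
+    and ``device`` follow the provided examples and are assumptions, not
+    part of this patch):
+
+    .. code-block:: python
+
+        from greatx.attack.untargeted import PRBCDAttack
+
+        attacker = PRBCDAttack(data, device=device)
+        attacker.setup_surrogate(trained_gcn,
+                                 victim_nodes=splits.test_nodes)
+        attacker.reset()
+        attacker.attack(0.05)  # perturb roughly 5% of all edges
+        perturbed = attacker.data()  # perturbed copy of the graph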
+ + This attack is designed with a focus on node- or graph-classification, + however, to adapt to other tasks you most likely only need to provide an + appropriate loss and model. However, we currently do not support batching + out of the box (sampling needs to be adapted). + + """ # TODO: Although PRBCDAttack accepts directed graphs, # we currently don't explicitlyt support directed graphs. # This should be made available in the future. - is_undirected_graph: bool = True + is_undirected: bool = True coeffs: Dict[str, Any] = { 'max_final_samples': 20, @@ -89,9 +122,10 @@ def reset(self) -> "PRBCDAttack": self.loss = None self.metric = None - # NOTE: `edge_weight` denotes the edge weight of the original graph - # it is None by default, so here we need to name it as `edge_weights` - self.edge_weights = torch.ones(self.num_edges, device=self.device) + # NOTE: Since `edge_index` and `edge_weight` denote the original graph + # here we need to name them as `edge_index`and `_edge_weight` + self._edge_index = self.edge_index + self._edge_weight = torch.ones(self.num_edges, device=self.device) # For early stopping (not explicitly covered by pseudo code) self.best_metric = float('-Inf') @@ -157,11 +191,20 @@ def attack( scalars['loss'] = loss.item() self._append_statistics(scalars) - self.close() + flipped_edges = self.get_flipped_edges() + + assert flipped_edges.size(1) <= self.num_budgets, ( + f'# perturbed edges {flipped_edges.size(1)} ' + f'exceeds num_budgets {self.num_budgets}') + + for it, (u, v) in enumerate(zip(*flipped_edges.tolist())): + if self.adjacency_matrix[u, v] > 0: + self.remove_edge(u, v, it) + else: + self.add_edge(u, v, it) return self - @torch.no_grad() def prepare(self, num_budgets: int, epochs: int) -> Iterable[int]: """Prepare attack and return the iterable sequence steps.""" @@ -181,9 +224,8 @@ def update(self, epoch: int, gradient: Tensor, pmass_update = torch.clamp(self.block_edge_weight, 0, 1) # Projection to stay within relaxed `L_0` num_budgets # (Algorithm 1, line 8) - self.block_edge_weight = self.project(num_budgets, - self.block_edge_weight, - self.coeffs['eps']) + self.block_edge_weight = project(num_budgets, self.block_edge_weight, + self.coeffs['eps']) # For monitoring scalars = dict( @@ -206,7 +248,7 @@ def update(self, epoch: int, gradient: Tensor, num_budgets).indices] = 1 edge_index, edge_weight = self.get_modified_graph( - self.edge_index, self.edge_weights, self.block_edge_index, + self._edge_index, self._edge_weight, self.block_edge_index, topk_block_edge_weight) prediction = self.surrogate(self.feat, edge_index, @@ -235,9 +277,8 @@ def update(self, epoch: int, gradient: Tensor, scalars['metric'] = metric.item() return scalars - @torch.no_grad() - def close(self): - """Clean up and prepare return argument.""" + def get_flipped_edges(self) -> Tensor: + """Clean up and prepare return flipped edges.""" # Retrieve best epoch if early stopping is active # (not explicitly covered by pseudo code) @@ -247,24 +288,13 @@ def close(self): self.block_edge_weight = self.best_pert_edge_weight.to(self.device) # Sample final discrete graph (Algorithm 1, line 16) - flipped_edges, edge_weight = self.sample_final_edges( + return self.sample_final_edges( self.feat, self.num_budgets, self.victim_nodes, self.victim_labels, ) - assert flipped_edges.size(1) <= self.num_budgets, ( - f'# perturbed edges {flipped_edges.size(1)} ' - f'exceeds num_budgets {self.num_budgets}') - - row, col = flipped_edges.tolist() - for it, (u, v, w) in enumerate(zip(row, col, edge_weight.tolist())): 
- if w > 0: - self.remove_edge(u, v, it) - else: - self.add_edge(u, v, it) - def compute_gradients( self, feat: Tensor, @@ -277,7 +307,7 @@ def compute_gradients( # Retrieve sparse perturbed adjacency matrix `A \oplus p_{t-1}` # (Algorithm 1, line 6 / Algorithm 2, line 7) edge_index, edge_weight = self.get_modified_graph( - self.edge_index, self.edge_weights, self.block_edge_index, + self._edge_index, self._edge_weight, self.block_edge_index, self.block_edge_weight) # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7) @@ -300,7 +330,7 @@ def get_modified_graph( block_edge_weight: Tensor, ) -> Tuple[Tensor, Tensor]: """Merges adjacency matrix with current block (incl. weights)""" - if self.is_undirected_graph: + if self.is_undirected: block_edge_index, block_edge_weight = to_undirected( block_edge_index, block_edge_weight, num_nodes=self.num_nodes, reduce='mean') @@ -321,18 +351,18 @@ def get_modified_graph( @torch.no_grad() def sample_random_block(self, num_budgets: int = 0): for _ in range(self.coeffs['max_trials_sampling']): - num_possible_edges = self._num_possible_edges( - self.num_nodes, self.is_undirected_graph) - self.current_block = torch.randint(num_possible_edges, + num_possible = num_possible_edges(self.num_nodes, + self.is_undirected) + self.current_block = torch.randint(num_possible, (self.block_size, ), device=self.device) self.current_block = torch.unique(self.current_block, sorted=True) - if self.is_undirected_graph: - self.block_edge_index = self._linear_to_triu_idx( + if self.is_undirected: + self.block_edge_index = linear_to_triu_idx( self.num_nodes, self.current_block) else: - self.block_edge_index = self._linear_to_full_idx( + self.block_edge_index = linear_to_full_idx( self.num_nodes, self.current_block) self._filter_self_loops_in_block(with_weight=False) @@ -343,8 +373,8 @@ def sample_random_block(self, num_budgets: int = 0): if self.current_block.size(0) >= num_budgets: return - raise RuntimeError('Sampling random block was not successful. ' - 'Please decrease `num_budgets`.') + raise RuntimeError("Sampling random block was not successful. " + "Please decrease `num_budgets`.") def resample_random_block(self, num_budgets: int): # Keep at most half of the block (i.e. 
resample low weights) @@ -360,20 +390,20 @@ def resample_random_block(self, num_budgets: int): # Sample until enough edges were drawn for _ in range(self.coeffs['max_trials_sampling']): n_edges_resample = self.block_size - self.current_block.size(0) - num_possible_edges = self._num_possible_edges( - self.num_nodes, self.is_undirected_graph) - lin_index = torch.randint(num_possible_edges, (n_edges_resample, ), + num_possible = num_possible_edges(self.num_nodes, + self.is_undirected) + lin_index = torch.randint(num_possible, (n_edges_resample, ), device=self.device) current_block = torch.cat((self.current_block, lin_index)) self.current_block, unique_idx = torch.unique( current_block, sorted=True, return_inverse=True) - if self.is_undirected_graph: - self.block_edge_index = self._linear_to_triu_idx( + if self.is_undirected: + self.block_edge_index = linear_to_triu_idx( self.num_nodes, self.current_block) else: - self.block_edge_index = self._linear_to_full_idx( + self.block_edge_index = linear_to_full_idx( self.num_nodes, self.current_block) # Merge existing weights with new edge weights @@ -385,14 +415,14 @@ def resample_random_block(self, num_budgets: int): self.block_edge_weight[ unique_idx[:sorted_idx.size(0)]] = block_edge_weight_prev - if not self.is_undirected_graph: + if not self.is_undirected: self._filter_self_loops_in_block(with_weight=True) if self.current_block.size(0) > num_budgets: return - raise RuntimeError('Sampling random block was not successful.' - 'Please decrease `num_budgets`.') + raise RuntimeError("Sampling random block was not successful." + "Please decrease `num_budgets`.") @torch.no_grad() def sample_final_edges( @@ -422,7 +452,7 @@ def sample_final_edges( self.block_edge_weight = sampled_edges edge_index, edge_weight = self.get_modified_graph( - self.edge_index, self.edge_weights, self.block_edge_index, + self._edge_index, self._edge_weight, self.block_edge_index, self.block_edge_weight) prediction = self.surrogate(feat, edge_index, edge_weight)[victim_nodes] @@ -433,9 +463,8 @@ def sample_final_edges( best_metric = metric best_edge_weight = self.block_edge_weight.clone().cpu() - flipped_edges = self.block_edge_index[:, - torch.where(best_edge_weight)[0]] - return flipped_edges, best_edge_weight + flipped_edges = self.block_edge_index[:, best_edge_weight != 0] + return flipped_edges def update_edge_weights(self, num_budgets: int, epoch: int, gradient: Tensor): @@ -447,39 +476,6 @@ def update_edge_weights(self, num_budgets: int, epoch: int, np.sqrt(max(0, epoch - self.epochs_resampling) + 1)) self.block_edge_weight.data.add_(lr * gradient) - @staticmethod - def project(num_budgets: int, values: Tensor, eps: float = 1e-7) -> Tensor: - r"""Project :obj:`values`: - :math:`num_budgets \ge \sum \Pi_{[0, 1]}(\text{values})`.""" - if torch.clamp(values, 0, 1).sum() > num_budgets: - left = (values - 1).min() - right = values.max() - miu = PRBCDAttack.bisection(values, left, right, num_budgets) - values = values - miu - return torch.clamp(values, min=eps, max=1 - eps) - - @staticmethod - def bisection(edge_weights: Tensor, a: float, b: float, n_pert: int, - eps=1e-5, max_iter=1e3) -> Tensor: - """Bisection search for projection.""" - def shift(offset: float): - return (torch.clamp(edge_weights - offset, 0, 1).sum() - n_pert) - - miu = a - for _ in range(int(max_iter)): - miu = (a + b) / 2 - # Check if middle point is root - if (shift(miu) == 0.0): - break - # Decide the side to repeat the steps - if (shift(miu) * shift(a) < 0): - b = miu - else: - a = miu - if ((b - a) <= 
eps): - break - return miu - def _filter_self_loops_in_block(self, with_weight: bool): mask = self.block_edge_index[0] != self.block_edge_index[1] self.current_block = self.current_block[mask] @@ -487,34 +483,88 @@ def _filter_self_loops_in_block(self, with_weight: bool): if with_weight: self.block_edge_weight = self.block_edge_weight[mask] - @staticmethod - def _num_possible_edges(n: int, is_undirected_graph: bool) -> int: - """Determine number of possible edges for graph.""" - if is_undirected_graph: - return n * (n - 1) // 2 - else: - return int(n**2) # We filter self-loops later - - @staticmethod - def _linear_to_triu_idx(n: int, lin_idx: Tensor) -> Tensor: - """Linear index to upper triangular matrix without diagonal. - This is similar to - https://stackoverflow.com/questions/242711/algorithm-for-index-numbers-of-triangular-matrix-coefficients/28116498#28116498 - with number nodes decremented and col index incremented by one.""" - nn = n * (n - 1) - row_idx = n - 2 - torch.floor( - torch.sqrt(-8 * lin_idx.double() + 4 * nn - 7) / 2.0 - 0.5).long() - col_idx = 1 + lin_idx + row_idx - nn // 2 + torch.div( - (n - row_idx) * (n - row_idx - 1), 2, rounding_mode='floor') - return torch.stack((row_idx, col_idx)) - - @staticmethod - def _linear_to_full_idx(n: int, lin_idx: Tensor) -> Tensor: - """Linear index to dense matrix including diagonal.""" - row_idx = torch.div(lin_idx, n, rounding_mode='floor') - col_idx = lin_idx % n - return torch.stack((row_idx, col_idx)) - def _append_statistics(self, mapping: Dict[str, Any]): for key, value in mapping.items(): self.attack_statistics[key].append(value) + + +class GRBCDAttack(PRBCDAttack): + r"""Greedy Randomized Block Coordinate Descent (GRBCD) adversarial attack + from the `Robustness of Graph Neural Networks at Scale + `_ paper. + + GRBCD shares most of the properties and requirements with + :class:`PRBCDAttack`. It also uses an efficient gradient based approach. + However, it greedily flips edges based on the gradient towards the + adjacency matrix. 
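+
+    For intuition, an illustrative calculation (not part of the original
+    code): with ``num_budgets=10`` and ``epochs=4``, :meth:`prepare` below
+    produces the greedy schedule ``[3, 3, 2, 2]``, i.e.
+    ``num_budgets // epochs`` flips per epoch plus one extra flip in each
+    of the first ``num_budgets % epochs`` epochs.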
+ + """ + def prepare(self, num_budgets: int, epochs: int) -> List[int]: + """Prepare attack.""" + + # Determine the number of edges to be flipped in each attach step/epoch + step_size = num_budgets // epochs + if step_size > 0: + steps = epochs * [step_size] + for i in range(num_budgets % epochs): + steps[i] += 1 + else: + steps = [1] * num_budgets + + # Sample initial search space (Algorithm 2, line 3-4) + self.sample_random_block(step_size) + + return steps + + def reset(self) -> "GRBCDAttack": + super().reset() + self.flipped_edges = self._edge_index.new_empty(2, 0) + return self + + @torch.no_grad() + def update( + self, + step_size: int, + gradient: Tensor, + num_budgets: int, + ) -> Dict[str, Any]: + """Update edge weights given gradient.""" + _, topk_edge_index = torch.topk(gradient, step_size) + + flip_edge_index = self.block_edge_index[:, topk_edge_index].to( + self.device) + flip_edge_weight = torch.ones(flip_edge_index.size(1), + device=self.device) + + self.flipped_edges = torch.cat((self.flipped_edges, flip_edge_index), + axis=-1) + + if self.is_undirected: + flip_edge_index, flip_edge_weight = to_undirected( + flip_edge_index, flip_edge_weight, num_nodes=self.num_nodes, + reduce='mean') + + edge_index = torch.cat((self._edge_index, flip_edge_index), dim=-1) + edge_weight = torch.cat((self._edge_weight, flip_edge_weight)) + + edge_index, edge_weight = coalesce(edge_index, edge_weight, + num_nodes=self.num_nodes, + reduce='sum') + + is_one_mask = torch.isclose(edge_weight, torch.tensor(1.)) + + self._edge_index = edge_index[:, is_one_mask] + self._edge_weight = edge_weight[is_one_mask] + + # Sample initial search space (Algorithm 2, line 3-4) + self.sample_random_block(step_size) + + # Return debug information + scalars = { + 'number_positive_entries_in_gradient': (gradient > 0).sum().item() + } + return scalars + + def get_flipped_edges(self) -> Tensor: + """Clean up and prepare return flipped edges.""" + return self.flipped_edges diff --git a/greatx/attack/untargeted/utils.py b/greatx/attack/untargeted/utils.py new file mode 100644 index 0000000..d446d9f --- /dev/null +++ b/greatx/attack/untargeted/utils.py @@ -0,0 +1,63 @@ +import torch +from torch import Tensor + + +def project(num_budgets: int, values: Tensor, eps: float = 1e-7) -> Tensor: + r"""Project :obj:`values`: + :math:`num_budgets \ge \sum \Pi_{[0, 1]}(\text{values})`.""" + if torch.clamp(values, 0, 1).sum() > num_budgets: + left = (values - 1).min() + right = values.max() + miu = bisection(values, left, right, num_budgets) + values = values - miu + return torch.clamp(values, min=eps, max=1 - eps) + + +def bisection(edge_weight: Tensor, a: float, b: float, n_pert: int, eps=1e-5, + max_iter=1e3) -> Tensor: + """Bisection search for projection.""" + def shift(offset: float): + return (torch.clamp(edge_weight - offset, 0, 1).sum() - n_pert) + + miu = a + for _ in range(int(max_iter)): + miu = (a + b) / 2 + # Check if middle point is root + if (shift(miu) == 0.0): + break + # Decide the side to repeat the steps + if (shift(miu) * shift(a) < 0): + b = miu + else: + a = miu + if ((b - a) <= eps): + break + return miu + + +def num_possible_edges(n: int, is_undirected_graph: bool) -> int: + """Determine number of possible edges for graph.""" + if is_undirected_graph: + return n * (n - 1) // 2 + else: + return int(n**2) # We filter self-loops later + + +def linear_to_triu_idx(n: int, lin_idx: Tensor) -> Tensor: + """Linear index to upper triangular matrix without diagonal. 
+    This is similar to
+    https://stackoverflow.com/questions/242711/algorithm-for-index-numbers-of-triangular-matrix-coefficients/28116498#28116498
+    with the number of nodes decremented and the column index incremented
+    by one."""
+    nn = n * (n - 1)
+    row_idx = n - 2 - torch.floor(
+        torch.sqrt(-8 * lin_idx.double() + 4 * nn - 7) / 2.0 - 0.5).long()
+    col_idx = 1 + lin_idx + row_idx - nn // 2 + torch.div(
+        (n - row_idx) * (n - row_idx - 1), 2, rounding_mode='floor')
+    return torch.stack((row_idx, col_idx))
+
+
+def linear_to_full_idx(n: int, lin_idx: Tensor) -> Tensor:
+    """Linear index to dense matrix including diagonal."""
+    row_idx = torch.div(lin_idx, n, rounding_mode='floor')
+    col_idx = lin_idx % n
+    return torch.stack((row_idx, col_idx))

From b38e70ed98922648532fdb744774ab1ab795763f Mon Sep 17 00:00:00 2001
From: EdisonLeeeee
Date: Mon, 28 Nov 2022 23:20:22 +0800
Subject: [PATCH 09/15] update

---
 greatx/attack/untargeted/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/greatx/attack/untargeted/__init__.py b/greatx/attack/untargeted/__init__.py
index 7bc5c04..c74fb2e 100644
--- a/greatx/attack/untargeted/__init__.py
+++ b/greatx/attack/untargeted/__init__.py
@@ -5,7 +5,7 @@
 from .pgd_attack import PGDAttack
 from .random_attack import RandomAttack
 from .untargeted_attacker import UntargetedAttacker
-from .rbcd_attack import PRBCDAttack
+from .rbcd_attack import PRBCDAttack, GRBCDAttack

 classes = __all__ = [
     'UntargetedAttacker',
@@ -16,4 +16,5 @@
     'Metattack',
     'PGDAttack',
     'PRBCDAttack',
+    'GRBCDAttack',
 ]

From a3eb36ff2cf2b1f792723016a33b39eebf05899a Mon Sep 17 00:00:00 2001
From: EdisonLeeeee
Date: Mon, 28 Nov 2022 23:53:02 +0800
Subject: [PATCH 10/15] update

---
 greatx/attack/targeted/rbcd_attack.py   | 249 ++++++++++++++
 greatx/attack/untargeted/rbcd_attack.py | 426 ++++++++++++------------
 2 files changed, 465 insertions(+), 210 deletions(-)
 create mode 100644 greatx/attack/targeted/rbcd_attack.py

diff --git a/greatx/attack/targeted/rbcd_attack.py b/greatx/attack/targeted/rbcd_attack.py
new file mode 100644
index 0000000..0911655
--- /dev/null
+++ b/greatx/attack/targeted/rbcd_attack.py
@@ -0,0 +1,249 @@
+from collections import defaultdict
+from typing import Callable, Dict, Iterable, Optional, Union
+
+import torch
+from torch import Tensor
+from tqdm.auto import tqdm
+
+from greatx.attack.targeted.targeted_attacker import TargetedAttacker
+from greatx.attack.untargeted.rbcd_attack import RBCDAttack
+from greatx.attack.untargeted.utils import project
+from greatx.functional import (
+    masked_cross_entropy,
+    probability_margin_loss,
+    tanh_margin_loss,
+)
+from greatx.nn.models.surrogate import Surrogate
+
+# (predictions, labels, ids/mask) -> Tensor with one element
+METRIC = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]
+
+
+class PRBCDAttack(TargetedAttacker, RBCDAttack, Surrogate):
+    r"""Projected Randomized Block Coordinate Descent (PRBCD) adversarial
+    attack from the `Robustness of Graph Neural Networks at Scale
+    `_ paper.
+
+    This attack uses an efficient gradient based approach that (during the
+    attack) relaxes the discrete entries in the adjacency matrix
+    :math:`\{0, 1\}` to :math:`[0, 1]` and solely perturbs the adjacency matrix
+    (no feature perturbations). Thus, this attack supports all models that can
+    handle weighted graphs that are differentiable w.r.t. these edge weights.
+    For non-differentiable models you might be able to use, e.g., the Gumbel
+    softmax trick.
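+
+    A minimal sketch for attacking a single node (hedged: ``trained_gcn``,
+    ``data`` and ``device`` follow the provided examples and are
+    assumptions, not part of this patch):
+
+    .. code-block:: python
+
+        from greatx.attack.targeted import PRBCDAttack
+
+        attacker = PRBCDAttack(data, device=device)
+        attacker.setup_surrogate(trained_gcn)
+        attacker.reset()
+        attacker.attack(target=1)  # attack the node with index 1
+        perturbed = attacker.data()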
+ + The memory overhead is driven by the additional edges (at most + :attr:`block_size`). For scalability reasons, the block is drawn with + replacement and then the index is made unique. Thus, the actual block size + is typically slightly smaller than specified. + + This attack can be used for both global and local attacks as well as + test-time attacks (evasion) and training-time attacks (poisoning). Please + see the provided examples. + + This attack is designed with a focus on node- or graph-classification, + however, to adapt to other tasks you most likely only need to provide an + appropriate loss and model. However, we currently do not support batching + out of the box (sampling needs to be adapted). + + """ + def reset(self) -> "PRBCDAttack": + super().reset() + self.current_block = None + self.block_edge_index = None + self.block_edge_weight = None + self.loss = None + self.metric = None + + self.victim_nodes = None + self.victim_labels = None + + # NOTE: Since `edge_index` and `edge_weight` denote the original graph + # here we need to name them as `edge_index`and `_edge_weight` + self._edge_index = self.edge_index + self._edge_weight = torch.ones(self.num_edges, device=self.device) + + # For early stopping (not explicitly covered by pseudo code) + self.best_metric = float('-Inf') + + # For collecting attack statistics + self.attack_statistics = defaultdict(list) + + return self + + def attack( + self, + target, + *, + target_label=None, + num_budgets=None, + direct_attack=True, + block_size: int = 250_000, + epochs: int = 125, + epochs_resampling: int = 100, + loss: Optional[str] = 'tanh_margin', + metric: Optional[Union[str, METRIC]] = None, + lr: float = 2_000, + structure_attack: bool = True, + feature_attack: bool = False, + disable: bool = False, + **kwargs, + ) -> "PRBCDAttack": + + super().attack(target, target_label, num_budgets=num_budgets, + direct_attack=direct_attack, + structure_attack=structure_attack, + feature_attack=feature_attack) + + self.block_size = block_size + + assert loss in ['mce', 'prob_margin', 'tanh_margin'] + if loss == 'mce': + self.loss = masked_cross_entropy + elif loss == 'prob_margin': + self.loss = probability_margin_loss + else: + self.loss = tanh_margin_loss + + if metric is None: + self.metric = self.loss + else: + self.metric = metric + + self.epochs_resampling = epochs_resampling + self.lr = lr + + self.coeffs.update(**kwargs) + + num_budgets = self.num_budgets + feat = self.feat + self.victim_nodes = torch.as_tensor( + target, + dtype=torch.long, + device=self.device, + ).view(-1) + + self.victim_labels = torch.as_tensor( + self.target_label, + dtype=torch.long, + device=self.device, + ).view(-1) + + feat, victim_nodes, victim_labels = (self.feat, self.victim_nodes, + self.victim_labels) + + # Loop over the epochs (Algorithm 1, line 5) + for step in tqdm(self.prepare(num_budgets, epochs), + desc='Peturbing graph...', disable=disable): + + loss, gradient = self.compute_gradients(feat, victim_labels, + victim_nodes) + + scalars = self.update(step, gradient, num_budgets) + + scalars['loss'] = loss.item() + self._append_statistics(scalars) + + flipped_edges = self.get_flipped_edges() + + assert flipped_edges.size(1) <= self.num_budgets, ( + f'# perturbed edges {flipped_edges.size(1)} ' + f'exceeds num_budgets {self.num_budgets}') + + for it, (u, v) in enumerate(zip(*flipped_edges.tolist())): + if self.adjacency_matrix[u, v] > 0: + self.remove_edge(u, v, it) + else: + self.add_edge(u, v, it) + + return self + + def prepare(self, num_budgets: 
int, epochs: int) -> Iterable[int]: + """Prepare attack and return the iterable sequence steps.""" + + # Sample initial search space (Algorithm 1, line 3-4) + self.sample_random_block(num_budgets) + + return range(epochs) + + @torch.no_grad() + def update(self, epoch: int, gradient: Tensor, + num_budgets: int) -> Dict[str, float]: + """Update edge weights given gradient.""" + # Gradient update step (Algorithm 1, line 7) + self.update_edge_weights(num_budgets, epoch, gradient) + + # For monitoring + pmass_update = torch.clamp(self.block_edge_weight, 0, 1) + # Projection to stay within relaxed `L_0` num_budgets + # (Algorithm 1, line 8) + self.block_edge_weight = project(num_budgets, self.block_edge_weight, + self.coeffs['eps']) + + # For monitoring + scalars = dict( + prob_mass_after_update=pmass_update.sum().item(), + prob_mass_after_update_max=pmass_update.max().item(), + prob_mass_afterprojection=self.block_edge_weight.sum().item(), + prob_mass_afterprojection_nonzero_weights=( + self.block_edge_weight > self.coeffs['eps']).sum().item(), + prob_mass_afterprojection_max=self.block_edge_weight.max().item(), + ) + + if not self.coeffs['with_early_stopping']: + return scalars + + # Calculate metric after the current epoch (overhead + # for monitoring and early stopping) + + topk_block_edge_weight = torch.zeros_like(self.block_edge_weight) + topk_block_edge_weight[torch.topk(self.block_edge_weight, + num_budgets).indices] = 1 + + edge_index, edge_weight = self.get_modified_graph( + self._edge_index, self._edge_weight, self.block_edge_index, + topk_block_edge_weight) + + prediction = self.surrogate(self.feat, edge_index, + edge_weight)[self.victim_nodes] + metric = self.metric(prediction, self.victim_labels) + + # Save best epoch for early stopping + # (not explicitly covered by pseudo code) + if metric > self.best_metric: + self.best_metric = metric + self.best_block = self.current_block.cpu().clone() + self.best_edge_index = self.block_edge_index.cpu().clone() + self.best_pert_edge_weight = self.block_edge_weight.cpu().detach() + + # Resampling of search space (Algorithm 1, line 9-14) + if epoch < self.epochs_resampling - 1: + self.resample_random_block(num_budgets) + elif epoch == self.epochs_resampling - 1: + # Retrieve best epoch if early stopping is active + # (not explicitly covered by pseudo code) + self.current_block = self.best_block.to(self.device) + self.block_edge_index = self.best_edge_index.to(self.device) + block_edge_weight = self.best_pert_edge_weight.clone() + self.block_edge_weight = block_edge_weight.to(self.device) + + scalars['metric'] = metric.item() + return scalars + + def get_flipped_edges(self) -> Tensor: + """Clean up and prepare return flipped edges.""" + + # Retrieve best epoch if early stopping is active + # (not explicitly covered by pseudo code) + if self.coeffs['with_early_stopping']: + self.current_block = self.best_block.to(self.device) + self.block_edge_index = self.best_edge_index.to(self.device) + self.block_edge_weight = self.best_pert_edge_weight.to(self.device) + + # Sample final discrete graph (Algorithm 1, line 16) + return self.sample_final_edges( + self.feat, + self.num_budgets, + self.victim_nodes, + self.victim_labels, + ) diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py index 37eca74..2b0f954 100644 --- a/greatx/attack/untargeted/rbcd_attack.py +++ b/greatx/attack/untargeted/rbcd_attack.py @@ -25,7 +25,219 @@ METRIC = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor] -class 
PRBCDAttack(UntargetedAttacker, Surrogate):
+class RBCDAttack:
+    """Base class for :class:`PRBCDAttack` and
+    :class:`GRBCDAttack`."""
+
+    # RBCDAttack will not ensure there are no singleton nodes
+    _allow_singleton: bool = False
+
+    # TODO: Although RBCDAttack accepts directed graphs,
+    # we currently don't explicitly support directed graphs.
+    # This should be made available in the future.
+    is_undirected: bool = True
+
+    coeffs: Dict[str, Any] = {
+        'max_final_samples': 20,
+        'max_trials_sampling': 20,
+        'with_early_stopping': True,
+        'eps': 1e-7
+    }
+
+    def compute_gradients(
+        self,
+        feat: Tensor,
+        victim_labels: Tensor,
+        victim_nodes: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """Forward pass and compute gradients w.r.t. the block
+        edge weights."""
+        self.block_edge_weight.requires_grad_()
+
+        # Retrieve sparse perturbed adjacency matrix `A \oplus p_{t-1}`
+        # (Algorithm 1, line 6 / Algorithm 2, line 7)
+        edge_index, edge_weight = self.get_modified_graph(
+            self._edge_index, self._edge_weight, self.block_edge_index,
+            self.block_edge_weight)
+
+        # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7)
+        prediction = self.surrogate(feat, edge_index,
+                                    edge_weight)[victim_nodes]
+        # Calculate the loss over all victim nodes
+        # (Algorithm 1, line 7 / Algorithm 2, line 8)
+        loss = self.loss(prediction, victim_labels)
+        # Retrieve gradient towards the current block
+        # (Algorithm 1, line 7 / Algorithm 2, line 8)
+        gradient = torch.autograd.grad(loss, self.block_edge_weight)[0]
+
+        return loss, gradient
+
+    def get_modified_graph(
+        self,
+        edge_index: Tensor,
+        edge_weight: Tensor,
+        block_edge_index: Tensor,
+        block_edge_weight: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """Merges adjacency matrix with current block (incl. weights)"""
+        if self.is_undirected:
+            block_edge_index, block_edge_weight = to_undirected(
+                block_edge_index, block_edge_weight, num_nodes=self.num_nodes,
+                reduce='mean')
+
+        modified_edge_index = torch.cat((edge_index, block_edge_index), dim=-1)
+        modified_edge_weight = torch.cat((edge_weight, block_edge_weight))
+
+        modified_edge_index, modified_edge_weight = coalesce(
+            modified_edge_index, modified_edge_weight,
+            num_nodes=self.num_nodes, reduce='sum')
+
+        # Allow (soft) removal of edges
+        mask = modified_edge_weight > 1
+        modified_edge_weight[mask] = 2 - modified_edge_weight[mask]
+
+        return modified_edge_index, modified_edge_weight
+
+    @torch.no_grad()
+    def sample_random_block(self, num_budgets: int = 0):
+        for _ in range(self.coeffs['max_trials_sampling']):
+            num_possible = num_possible_edges(self.num_nodes,
+                                              self.is_undirected)
+            self.current_block = torch.randint(num_possible,
+                                               (self.block_size, ),
+                                               device=self.device)
+            self.current_block = torch.unique(self.current_block, sorted=True)
+
+            if self.is_undirected:
+                self.block_edge_index = linear_to_triu_idx(
+                    self.num_nodes, self.current_block)
+            else:
+                self.block_edge_index = linear_to_full_idx(
+                    self.num_nodes, self.current_block)
+
+            self._filter_self_loops_in_block(with_weight=False)
+
+            self.block_edge_weight = torch.full(self.current_block.shape,
+                                                self.coeffs['eps'],
+                                                device=self.device)
+            if self.current_block.size(0) >= num_budgets:
+                return
+
+        raise RuntimeError("Sampling random block was not successful. "
+                           "Please decrease `num_budgets`.")
+
+    def resample_random_block(self, num_budgets: int):
+        # Keep at most half of the block (i.e.
resample low weights) + sorted_idx = torch.argsort(self.block_edge_weight) + keep_above = (self.block_edge_weight <= + self.coeffs['eps']).sum().long() + if keep_above < sorted_idx.size(0) // 2: + keep_above = sorted_idx.size(0) // 2 + sorted_idx = sorted_idx[keep_above:] + + self.current_block = self.current_block[sorted_idx] + + # Sample until enough edges were drawn + for _ in range(self.coeffs['max_trials_sampling']): + n_edges_resample = self.block_size - self.current_block.size(0) + num_possible = num_possible_edges(self.num_nodes, + self.is_undirected) + lin_index = torch.randint(num_possible, (n_edges_resample, ), + device=self.device) + + current_block = torch.cat((self.current_block, lin_index)) + self.current_block, unique_idx = torch.unique( + current_block, sorted=True, return_inverse=True) + + if self.is_undirected: + self.block_edge_index = linear_to_triu_idx( + self.num_nodes, self.current_block) + else: + self.block_edge_index = linear_to_full_idx( + self.num_nodes, self.current_block) + + # Merge existing weights with new edge weights + block_edge_weight_prev = self.block_edge_weight[sorted_idx] + self.block_edge_weight = torch.full(self.current_block.shape, + self.coeffs['eps'], + device=self.device) + + self.block_edge_weight[ + unique_idx[:sorted_idx.size(0)]] = block_edge_weight_prev + + if not self.is_undirected: + self._filter_self_loops_in_block(with_weight=True) + + if self.current_block.size(0) > num_budgets: + return + + raise RuntimeError("Sampling random block was not successful." + "Please decrease `num_budgets`.") + + @torch.no_grad() + def sample_final_edges( + self, + feat: Tensor, + num_budgets: int, + victim_nodes: Tensor, + victim_labels: Tensor, + ) -> Tuple[Tensor, Tensor]: + best_metric = float('-Inf') + block_edge_weight = self.block_edge_weight + block_edge_weight[block_edge_weight <= self.coeffs['eps']] = 0 + + for i in range(self.coeffs['max_final_samples']): + if i == 0: + # In first iteration employ top k heuristic instead of sampling + sampled_edges = torch.zeros_like(block_edge_weight) + sampled_edges[torch.topk(block_edge_weight, + num_budgets).indices] = 1 + else: + sampled_edges = torch.bernoulli(block_edge_weight).float() + + if sampled_edges.sum() > num_budgets: + # Allowed num_budgets is exceeded + continue + + self.block_edge_weight = sampled_edges + + edge_index, edge_weight = self.get_modified_graph( + self._edge_index, self._edge_weight, self.block_edge_index, + self.block_edge_weight) + prediction = self.surrogate(feat, edge_index, + edge_weight)[victim_nodes] + metric = self.metric(prediction, victim_labels) + + # Save best sample + if metric > best_metric: + best_metric = metric + best_edge_weight = self.block_edge_weight.clone().cpu() + + flipped_edges = self.block_edge_index[:, best_edge_weight != 0] + return flipped_edges + + def update_edge_weights(self, num_budgets: int, epoch: int, + gradient: Tensor): + # The learning rate is refined heuristically, s.t. (1) it is + # independent of the number of perturbations (assuming an undirected + # adjacency matrix) and (2) to decay learning rate during fine-tuning + # (i.e. fixed search space). 
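+        # Illustrative magnitude (assumed numbers, not taken from this
+        # patch): with num_budgets=100, num_nodes=2500, lr=2000 and
+        # epoch < epochs_resampling, the step size below evaluates to
+        # 100 / 2500 * 2000 / sqrt(1) = 80.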
+ lr = (num_budgets / self.num_nodes * self.lr / + np.sqrt(max(0, epoch - self.epochs_resampling) + 1)) + self.block_edge_weight.data.add_(lr * gradient) + + def _filter_self_loops_in_block(self, with_weight: bool): + mask = self.block_edge_index[0] != self.block_edge_index[1] + self.current_block = self.current_block[mask] + self.block_edge_index = self.block_edge_index[:, mask] + if with_weight: + self.block_edge_weight = self.block_edge_weight[mask] + + def _append_statistics(self, mapping: Dict[str, Any]): + for key, value in mapping.items(): + self.attack_statistics[key].append(value) + + +class PRBCDAttack(UntargetedAttacker, RBCDAttack, Surrogate): r"""Projected Randomized Block Coordinate Descent (PRBCD) adversarial attack from the `Robustness of Graph Neural Networks at Scale `_ paper. @@ -53,19 +265,6 @@ class PRBCDAttack(UntargetedAttacker, Surrogate): out of the box (sampling needs to be adapted). """ - - # TODO: Although PRBCDAttack accepts directed graphs, - # we currently don't explicitlyt support directed graphs. - # This should be made available in the future. - is_undirected: bool = True - - coeffs: Dict[str, Any] = { - 'max_final_samples': 20, - 'max_trials_sampling': 20, - 'with_early_stopping': True, - 'eps': 1e-7 - } - def setup_surrogate( self, surrogate: torch.nn.Module, @@ -295,198 +494,6 @@ def get_flipped_edges(self) -> Tensor: self.victim_labels, ) - def compute_gradients( - self, - feat: Tensor, - victim_labels: Tensor, - victim_nodes: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Forward and update edge weights.""" - self.block_edge_weight.requires_grad_() - - # Retrieve sparse perturbed adjacency matrix `A \oplus p_{t-1}` - # (Algorithm 1, line 6 / Algorithm 2, line 7) - edge_index, edge_weight = self.get_modified_graph( - self._edge_index, self._edge_weight, self.block_edge_index, - self.block_edge_weight) - - # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7) - prediction = self.surrogate(feat, edge_index, - edge_weight)[victim_nodes] - # Calculate loss combining all each node - # (Algorithm 1, line 7 / Algorithm 2, line 8) - loss = self.loss(prediction, victim_labels) - # Retrieve gradient towards the current block - # (Algorithm 1, line 7 / Algorithm 2, line 8) - gradient = torch.autograd.grad(loss, self.block_edge_weight)[0] - - return loss, gradient - - def get_modified_graph( - self, - edge_index: Tensor, - edge_weight: Tensor, - block_edge_index: Tensor, - block_edge_weight: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Merges adjacency matrix with current block (incl. 
weights)""" - if self.is_undirected: - block_edge_index, block_edge_weight = to_undirected( - block_edge_index, block_edge_weight, num_nodes=self.num_nodes, - reduce='mean') - - modified_edge_index = torch.cat((edge_index, block_edge_index), dim=-1) - modified_edge_weight = torch.cat((edge_weight, block_edge_weight)) - - modified_edge_index, modified_edge_weight = coalesce( - modified_edge_index, modified_edge_weight, - num_nodes=self.num_nodes, reduce='sum') - - # Allow (soft) removal of edges - mask = modified_edge_weight > 1 - modified_edge_weight[mask] = 2 - modified_edge_weight[mask] - - return modified_edge_index, modified_edge_weight - - @torch.no_grad() - def sample_random_block(self, num_budgets: int = 0): - for _ in range(self.coeffs['max_trials_sampling']): - num_possible = num_possible_edges(self.num_nodes, - self.is_undirected) - self.current_block = torch.randint(num_possible, - (self.block_size, ), - device=self.device) - self.current_block = torch.unique(self.current_block, sorted=True) - - if self.is_undirected: - self.block_edge_index = linear_to_triu_idx( - self.num_nodes, self.current_block) - else: - self.block_edge_index = linear_to_full_idx( - self.num_nodes, self.current_block) - - self._filter_self_loops_in_block(with_weight=False) - - self.block_edge_weight = torch.full(self.current_block.shape, - self.coeffs['eps'], - device=self.device) - if self.current_block.size(0) >= num_budgets: - return - - raise RuntimeError("Sampling random block was not successful. " - "Please decrease `num_budgets`.") - - def resample_random_block(self, num_budgets: int): - # Keep at most half of the block (i.e. resample low weights) - sorted_idx = torch.argsort(self.block_edge_weight) - keep_above = (self.block_edge_weight <= - self.coeffs['eps']).sum().long() - if keep_above < sorted_idx.size(0) // 2: - keep_above = sorted_idx.size(0) // 2 - sorted_idx = sorted_idx[keep_above:] - - self.current_block = self.current_block[sorted_idx] - - # Sample until enough edges were drawn - for _ in range(self.coeffs['max_trials_sampling']): - n_edges_resample = self.block_size - self.current_block.size(0) - num_possible = num_possible_edges(self.num_nodes, - self.is_undirected) - lin_index = torch.randint(num_possible, (n_edges_resample, ), - device=self.device) - - current_block = torch.cat((self.current_block, lin_index)) - self.current_block, unique_idx = torch.unique( - current_block, sorted=True, return_inverse=True) - - if self.is_undirected: - self.block_edge_index = linear_to_triu_idx( - self.num_nodes, self.current_block) - else: - self.block_edge_index = linear_to_full_idx( - self.num_nodes, self.current_block) - - # Merge existing weights with new edge weights - block_edge_weight_prev = self.block_edge_weight[sorted_idx] - self.block_edge_weight = torch.full(self.current_block.shape, - self.coeffs['eps'], - device=self.device) - - self.block_edge_weight[ - unique_idx[:sorted_idx.size(0)]] = block_edge_weight_prev - - if not self.is_undirected: - self._filter_self_loops_in_block(with_weight=True) - - if self.current_block.size(0) > num_budgets: - return - - raise RuntimeError("Sampling random block was not successful." 
- "Please decrease `num_budgets`.") - - @torch.no_grad() - def sample_final_edges( - self, - feat: Tensor, - num_budgets: int, - victim_nodes: Tensor, - victim_labels: Tensor, - ) -> Tuple[Tensor, Tensor]: - best_metric = float('-Inf') - block_edge_weight = self.block_edge_weight - block_edge_weight[block_edge_weight <= self.coeffs['eps']] = 0 - - for i in range(self.coeffs['max_final_samples']): - if i == 0: - # In first iteration employ top k heuristic instead of sampling - sampled_edges = torch.zeros_like(block_edge_weight) - sampled_edges[torch.topk(block_edge_weight, - num_budgets).indices] = 1 - else: - sampled_edges = torch.bernoulli(block_edge_weight).float() - - if sampled_edges.sum() > num_budgets: - # Allowed num_budgets is exceeded - continue - - self.block_edge_weight = sampled_edges - - edge_index, edge_weight = self.get_modified_graph( - self._edge_index, self._edge_weight, self.block_edge_index, - self.block_edge_weight) - prediction = self.surrogate(feat, edge_index, - edge_weight)[victim_nodes] - metric = self.metric(prediction, victim_labels) - - # Save best sample - if metric > best_metric: - best_metric = metric - best_edge_weight = self.block_edge_weight.clone().cpu() - - flipped_edges = self.block_edge_index[:, best_edge_weight != 0] - return flipped_edges - - def update_edge_weights(self, num_budgets: int, epoch: int, - gradient: Tensor): - # The learning rate is refined heuristically, s.t. (1) it is - # independent of the number of perturbations (assuming an undirected - # adjacency matrix) and (2) to decay learning rate during fine-tuning - # (i.e. fixed search space). - lr = (num_budgets / self.num_nodes * self.lr / - np.sqrt(max(0, epoch - self.epochs_resampling) + 1)) - self.block_edge_weight.data.add_(lr * gradient) - - def _filter_self_loops_in_block(self, with_weight: bool): - mask = self.block_edge_index[0] != self.block_edge_index[1] - self.current_block = self.current_block[mask] - self.block_edge_index = self.block_edge_index[:, mask] - if with_weight: - self.block_edge_weight = self.block_edge_weight[mask] - - def _append_statistics(self, mapping: Dict[str, Any]): - for key, value in mapping.items(): - self.attack_statistics[key].append(value) - class GRBCDAttack(PRBCDAttack): r"""Greedy Randomized Block Coordinate Descent (GRBCD) adversarial attack @@ -497,7 +504,6 @@ class GRBCDAttack(PRBCDAttack): :class:`PRBCDAttack`. It also uses an efficient gradient based approach. However, it greedily flips edges based on the gradient towards the adjacency matrix. 
- """ def prepare(self, num_budgets: int, epochs: int) -> List[int]: """Prepare attack.""" @@ -551,10 +557,10 @@ def update( num_nodes=self.num_nodes, reduce='sum') - is_one_mask = torch.isclose(edge_weight, torch.tensor(1.)) + mask = torch.isclose(edge_weight, torch.tensor(1.)) - self._edge_index = edge_index[:, is_one_mask] - self._edge_weight = edge_weight[is_one_mask] + self._edge_index = edge_index[:, mask] + self._edge_weight = edge_weight[mask] # Sample initial search space (Algorithm 2, line 3-4) self.sample_random_block(step_size) From d89e2938cceafaeb2b7b5b680d2d71d0efe48497 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Tue, 29 Nov 2022 14:57:53 +0800 Subject: [PATCH 11/15] Update --- greatx/attack/targeted/__init__.py | 5 +- greatx/attack/targeted/rbcd_attack.py | 90 +++++++++++++++++++++++-- greatx/attack/untargeted/rbcd_attack.py | 5 ++ 3 files changed, 93 insertions(+), 7 deletions(-) diff --git a/greatx/attack/targeted/__init__.py b/greatx/attack/targeted/__init__.py index 4204717..31c6a78 100644 --- a/greatx/attack/targeted/__init__.py +++ b/greatx/attack/targeted/__init__.py @@ -1,3 +1,4 @@ +from .targeted_attacker import TargetedAttacker from .dice_attack import DICEAttack from .fg_attack import FGAttack from .gf_attack import GFAttack @@ -6,7 +7,7 @@ from .pgd_attack import PGDAttack from .random_attack import RandomAttack from .sg_attack import SGAttack -from .targeted_attacker import TargetedAttacker +from .rbcd_attack import PRBCDAttack, GRBCDAttack classes = __all__ = [ 'TargetedAttacker', @@ -18,4 +19,6 @@ 'Nettack', 'GFAttack', 'PGDAttack', + 'PRBCDAttack', + 'GRBCDAttack', ] diff --git a/greatx/attack/targeted/rbcd_attack.py b/greatx/attack/targeted/rbcd_attack.py index 0911655..87c0636 100644 --- a/greatx/attack/targeted/rbcd_attack.py +++ b/greatx/attack/targeted/rbcd_attack.py @@ -1,8 +1,9 @@ from collections import defaultdict -from typing import Callable, Dict, Iterable, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union import torch from torch import Tensor +from torch_geometric.utils import coalesce, to_undirected from tqdm.auto import tqdm from greatx.attack.targeted.targeted_attacker import TargetedAttacker @@ -123,11 +124,7 @@ def attack( device=self.device, ).view(-1) - self.victim_labels = torch.as_tensor( - self.target_label, - dtype=torch.long, - device=self.device, - ).view(-1) + self.victim_labels = self.target_label.view(-1) feat, victim_nodes, victim_labels = (self.feat, self.victim_nodes, self.victim_labels) @@ -247,3 +244,84 @@ def get_flipped_edges(self) -> Tensor: self.victim_nodes, self.victim_labels, ) + + +class GRBCDAttack(PRBCDAttack): + r"""Greedy Randomized Block Coordinate Descent (GRBCD) adversarial attack + from the `Robustness of Graph Neural Networks at Scale + `_ paper. + + GRBCD shares most of the properties and requirements with + :class:`PRBCDAttack`. It also uses an efficient gradient based approach. + However, it greedily flips edges based on the gradient towards the + adjacency matrix. 
+ """ + def prepare(self, num_budgets: int, epochs: int) -> List[int]: + """Prepare attack.""" + + # Determine the number of edges to be flipped in each attach step/epoch + step_size = num_budgets // epochs + if step_size > 0: + steps = epochs * [step_size] + for i in range(num_budgets % epochs): + steps[i] += 1 + else: + steps = [1] * num_budgets + + # Sample initial search space (Algorithm 2, line 3-4) + self.sample_random_block(step_size) + + return steps + + def reset(self) -> "GRBCDAttack": + super().reset() + self.flipped_edges = self._edge_index.new_empty(2, 0) + return self + + @torch.no_grad() + def update( + self, + step_size: int, + gradient: Tensor, + num_budgets: int, + ) -> Dict[str, Any]: + """Update edge weights given gradient.""" + _, topk_edge_index = torch.topk(gradient, step_size) + + flip_edge_index = self.block_edge_index[:, topk_edge_index].to( + self.device) + flip_edge_weight = torch.ones(flip_edge_index.size(1), + device=self.device) + + self.flipped_edges = torch.cat((self.flipped_edges, flip_edge_index), + axis=-1) + + if self.is_undirected: + flip_edge_index, flip_edge_weight = to_undirected( + flip_edge_index, flip_edge_weight, num_nodes=self.num_nodes, + reduce='mean') + + edge_index = torch.cat((self._edge_index, flip_edge_index), dim=-1) + edge_weight = torch.cat((self._edge_weight, flip_edge_weight)) + + edge_index, edge_weight = coalesce(edge_index, edge_weight, + num_nodes=self.num_nodes, + reduce='sum') + + mask = torch.isclose(edge_weight, torch.tensor(1.)) + + self._edge_index = edge_index[:, mask] + self._edge_weight = edge_weight[mask] + + # Sample initial search space (Algorithm 2, line 3-4) + self.sample_random_block(step_size) + + # Return debug information + scalars = { + 'number_positive_entries_in_gradient': (gradient > 0).sum().item() + } + return scalars + + def get_flipped_edges(self) -> Tensor: + """Clean up and prepare return flipped edges.""" + return self.flipped_edges diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py index 2b0f954..f5d5999 100644 --- a/greatx/attack/untargeted/rbcd_attack.py +++ b/greatx/attack/untargeted/rbcd_attack.py @@ -62,6 +62,11 @@ def compute_gradients( # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7) prediction = self.surrogate(feat, edge_index, edge_weight)[victim_nodes] + + # temperature scaling, work for cross-entropy loss + if self.tau != 1: + prediction /= self.tau + # Calculate loss combining all each node # (Algorithm 1, line 7 / Algorithm 2, line 8) loss = self.loss(prediction, victim_labels) From 733fd9ad53a82110fa2758a2a769bb9900f15deb Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Thu, 1 Dec 2022 11:44:09 +0800 Subject: [PATCH 12/15] Update --- greatx/attack/targeted/rbcd_attack.py | 90 ++++++--------- greatx/attack/untargeted/rbcd_attack.py | 140 +++++++++++++++--------- 2 files changed, 124 insertions(+), 106 deletions(-) diff --git a/greatx/attack/targeted/rbcd_attack.py b/greatx/attack/targeted/rbcd_attack.py index 87c0636..bc267b2 100644 --- a/greatx/attack/targeted/rbcd_attack.py +++ b/greatx/attack/targeted/rbcd_attack.py @@ -4,16 +4,10 @@ import torch from torch import Tensor from torch_geometric.utils import coalesce, to_undirected -from tqdm.auto import tqdm from greatx.attack.targeted.targeted_attacker import TargetedAttacker from greatx.attack.untargeted.rbcd_attack import RBCDAttack from greatx.attack.untargeted.utils import project -from greatx.functional import ( - masked_cross_entropy, - probability_margin_loss, - 
tanh_margin_loss, -) from greatx.nn.models.surrogate import Surrogate # (predictions, labels, ids/mask) -> Tensor with one element @@ -95,29 +89,6 @@ def attack( direct_attack=direct_attack, structure_attack=structure_attack, feature_attack=feature_attack) - - self.block_size = block_size - - assert loss in ['mce', 'prob_margin', 'tanh_margin'] - if loss == 'mce': - self.loss = masked_cross_entropy - elif loss == 'prob_margin': - self.loss = probability_margin_loss - else: - self.loss = tanh_margin_loss - - if metric is None: - self.metric = self.loss - else: - self.metric = metric - - self.epochs_resampling = epochs_resampling - self.lr = lr - - self.coeffs.update(**kwargs) - - num_budgets = self.num_budgets - feat = self.feat self.victim_nodes = torch.as_tensor( target, dtype=torch.long, @@ -126,34 +97,11 @@ def attack( self.victim_labels = self.target_label.view(-1) - feat, victim_nodes, victim_labels = (self.feat, self.victim_nodes, - self.victim_labels) - - # Loop over the epochs (Algorithm 1, line 5) - for step in tqdm(self.prepare(num_budgets, epochs), - desc='Peturbing graph...', disable=disable): - - loss, gradient = self.compute_gradients(feat, victim_labels, - victim_nodes) - - scalars = self.update(step, gradient, num_budgets) - - scalars['loss'] = loss.item() - self._append_statistics(scalars) - - flipped_edges = self.get_flipped_edges() - - assert flipped_edges.size(1) <= self.num_budgets, ( - f'# perturbed edges {flipped_edges.size(1)} ' - f'exceeds num_budgets {self.num_budgets}') - - for it, (u, v) in enumerate(zip(*flipped_edges.tolist())): - if self.adjacency_matrix[u, v] > 0: - self.remove_edge(u, v, it) - else: - self.add_edge(u, v, it) - - return self + return RBCDAttack.attack(self, self.num_budgets, block_size=block_size, + epochs=epochs, + epochs_resampling=epochs_resampling, + loss=loss, metric=metric, lr=lr, + disable=disable, **kwargs) def prepare(self, num_budgets: int, epochs: int) -> Iterable[int]: """Prepare attack and return the iterable sequence steps.""" @@ -256,6 +204,34 @@ class GRBCDAttack(PRBCDAttack): However, it greedily flips edges based on the gradient towards the adjacency matrix. 
""" + def attack( + self, + target, + *, + target_label=None, + num_budgets=None, + direct_attack=True, + block_size: int = 250_000, + epochs: int = 125, + epochs_resampling: int = 100, + loss: Optional[str] = 'mce', + metric: Optional[Union[str, METRIC]] = None, + lr: float = 1_000, + structure_attack: bool = True, + feature_attack: bool = False, + disable: bool = False, + **kwargs, + ) -> "GRBCDAttack": + + return super().attack(target=target, target_label=target_label, + direct_attack=direct_attack, + num_budgets=num_budgets, block_size=block_size, + epochs=epochs, + epochs_resampling=epochs_resampling, + metric=metric, loss=loss, lr=lr, disable=disable, + structure_attack=structure_attack, + feature_attack=feature_attack, **kwargs) + def prepare(self, num_budgets: int, epochs: int) -> List[int]: """Prepare attack.""" diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py index f5d5999..8b44735 100644 --- a/greatx/attack/untargeted/rbcd_attack.py +++ b/greatx/attack/untargeted/rbcd_attack.py @@ -44,6 +44,69 @@ class RBCDAttack: 'eps': 1e-7 } + def attack( + self, + num_budgets: int, + block_size: int = 250_000, + epochs: int = 125, + epochs_resampling: int = 100, + loss: Optional[str] = 'tanh_margin', + metric: Optional[Union[str, METRIC]] = None, + lr: float = 2_000, + disable: bool = False, + **kwargs, + ) -> "RBCDAttack": + + self.block_size = block_size + + assert loss in ['mce', 'prob_margin', 'tanh_margin'] + if loss == 'mce': + self.loss = masked_cross_entropy + elif loss == 'prob_margin': + self.loss = probability_margin_loss + else: + self.loss = tanh_margin_loss + + if metric is None: + self.metric = self.loss + else: + self.metric = metric + + self.epochs_resampling = epochs_resampling + self.lr = lr + + self.coeffs.update(**kwargs) + + num_budgets = self.num_budgets + feat, victim_nodes, victim_labels = (self.feat, self.victim_nodes, + self.victim_labels) + + # Loop over the epochs (Algorithm 1, line 5) + for step in tqdm(self.prepare(num_budgets, epochs), + desc='Peturbing graph...', disable=disable): + + loss, gradient = self.compute_gradients(feat, victim_labels, + victim_nodes) + + scalars = self.update(step, gradient, num_budgets) + + scalars['loss'] = loss.item() + self._append_statistics(scalars) + + flipped_edges = self.get_flipped_edges() + + assert flipped_edges.size(1) <= self.num_budgets, ( + f'# perturbed edges {flipped_edges.size(1)} ' + f'exceeds num_budgets {self.num_budgets}') + + for it, (u, v) in enumerate(zip(*flipped_edges.tolist())): + if self.adjacency_matrix[u, v] > 0: + self.remove_edge(u, v, it) + else: + self.add_edge(u, v, it) + + return self + def compute_gradients( self, feat: Tensor, @@ -359,55 +422,11 @@ def attack( structure_attack=structure_attack, feature_attack=feature_attack) - self.block_size = block_size - - assert loss in ['mce', 'prob_margin', 'tanh_margin'] - if loss == 'mce': - self.loss = masked_cross_entropy - elif loss == 'prob_margin': - self.loss = probability_margin_loss - else: - self.loss = tanh_margin_loss - - if metric is None: - self.metric = self.loss - else: - self.metric = metric - - self.epochs_resampling = epochs_resampling - self.lr = lr - - self.coeffs.update(**kwargs) - - num_budgets = self.num_budgets - feat, victim_nodes, victim_labels = (self.feat, self.victim_nodes, - self.victim_labels) - - # Loop over the epochs (Algorithm 1, line 5) - for step in tqdm(self.prepare(num_budgets, epochs), - desc='Peturbing graph...', disable=disable): - - loss, gradient = 
self.compute_gradients(feat, victim_labels, - victim_nodes) - - scalars = self.update(step, gradient, num_budgets) - - scalars['loss'] = loss.item() - self._append_statistics(scalars) - - flipped_edges = self.get_flipped_edges() - - assert flipped_edges.size(1) <= self.num_budgets, ( - f'# perturbed edges {flipped_edges.size(1)} ' - f'exceeds num_budgets {self.num_budgets}') - - for it, (u, v) in enumerate(zip(*flipped_edges.tolist())): - if self.adjacency_matrix[u, v] > 0: - self.remove_edge(u, v, it) - else: - self.add_edge(u, v, it) - - return self + return RBCDAttack.attack(self, self.num_budgets, block_size=block_size, + epochs=epochs, + epochs_resampling=epochs_resampling, + loss=loss, metric=metric, lr=lr, + disable=disable, **kwargs) def prepare(self, num_budgets: int, epochs: int) -> Iterable[int]: """Prepare attack and return the iterable sequence steps.""" @@ -510,6 +529,29 @@ class GRBCDAttack(PRBCDAttack): However, it greedily flips edges based on the gradient towards the adjacency matrix. """ + def attack( + self, + num_budgets: Union[int, float] = 0.05, + *, + block_size: int = 250_000, + epochs: int = 125, + epochs_resampling: int = 100, + loss: Optional[str] = 'mce', + metric: Optional[Union[str, METRIC]] = None, + lr: float = 1_000, + structure_attack: bool = True, + feature_attack: bool = False, + disable: bool = False, + **kwargs, + ) -> "GRBCDAttack": + + return super().attack(num_budgets=num_budgets, block_size=block_size, + epochs=epochs, + epochs_resampling=epochs_resampling, + metric=metric, loss=loss, lr=lr, disable=disable, + structure_attack=structure_attack, + feature_attack=feature_attack, **kwargs) + def prepare(self, num_budgets: int, epochs: int) -> List[int]: """Prepare attack.""" From 521ce3e6a5a09adaf4791830a68d4966a7d00076 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Sat, 3 Dec 2022 21:22:28 +0800 Subject: [PATCH 13/15] Update --- examples/attack/untargeted/rbcd_attack.py | 90 +++++++++++++++++++++++ greatx/attack/untargeted/rbcd_attack.py | 71 +++++++----------- 2 files changed, 115 insertions(+), 46 deletions(-) create mode 100644 examples/attack/untargeted/rbcd_attack.py diff --git a/examples/attack/untargeted/rbcd_attack.py b/examples/attack/untargeted/rbcd_attack.py new file mode 100644 index 0000000..f6388e5 --- /dev/null +++ b/examples/attack/untargeted/rbcd_attack.py @@ -0,0 +1,90 @@ +import os.path as osp + +import torch +import torch_geometric.transforms as T + +from greatx.attack.untargeted import GRBCDAttack, PRBCDAttack +from greatx.datasets import GraphDataset +from greatx.nn.models import GCN +from greatx.training import Trainer +from greatx.training.callbacks import ModelCheckpoint +from greatx.utils import split_nodes + +dataset = 'Cora' +root = osp.join(osp.dirname(osp.realpath(__file__)), '../../..', 'data') +dataset = GraphDataset(root=root, name=dataset, + transform=T.LargestConnectedComponents()) + +data = dataset[0] +splits = split_nodes(data.y, random_state=15) + +num_features = data.x.size(-1) +num_classes = data.y.max().item() + 1 + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# ================================================================== # +# Before Attack # +# ================================================================== # +trainer_before = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_before.pth', monitor='val_acc') +trainer_before.fit(data, mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +logs = trainer_before.evaluate(data, 
splits.test_nodes) +print(f"Before attack\n {logs}") + +# ================================================================== # +# Attacking (PRBCDAttack) # +# ================================================================== # +attacker = PRBCDAttack(data, device=device) +attacker.setup_surrogate( + trainer_before.model, + victim_nodes=splits.test_nodes, + # set True to use ground-truth labels + ground_truth=False, +) +attacker.reset() +attacker.attack(0.05) + +# ================================================================== # +# After evasion Attack # +# ================================================================== # +logs = trainer_before.evaluate(attacker.data(), splits.test_nodes) +print(f"After evasion attack\n {logs}") +# ================================================================== # +# After poisoning Attack # +# ================================================================== # +trainer_after = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_after.pth', monitor='val_acc') +trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +logs = trainer_after.evaluate(attacker.data(), splits.test_nodes) +print(f"After poisoning attack\n {logs}") + +# ================================================================== # +# Attacking (GRBCDAttack) # +# ================================================================== # +attacker = GRBCDAttack(data, device=device) +attacker.setup_surrogate( + trainer_before.model, + victim_nodes=splits.test_nodes, + # set True to use ground-truth labels + ground_truth=False, +) +attacker.reset() +attacker.attack(0.05) + +# ================================================================== # +# After evasion Attack # +# ================================================================== # +logs = trainer_before.evaluate(attacker.data(), splits.test_nodes) +print(f"After evasion attack\n {logs}") +# ================================================================== # +# After poisoning Attack # +# ================================================================== # +trainer_after = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_after.pth', monitor='val_acc') +trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +logs = trainer_after.evaluate(attacker.data(), splits.test_nodes) +print(f"After poisoning attack\n {logs}") diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py index 8b44735..1a82ab3 100644 --- a/greatx/attack/untargeted/rbcd_attack.py +++ b/greatx/attack/untargeted/rbcd_attack.py @@ -46,7 +46,6 @@ class RBCDAttack: def attack( self, - num_budgets: int, block_size: int = 250_000, epochs: int = 125, epochs_resampling: int = 100, @@ -77,18 +76,13 @@ def attack( self.coeffs.update(**kwargs) - num_budgets = self.num_budgets - feat, victim_nodes, victim_labels = (self.feat, self.victim_nodes, - self.victim_labels) - # Loop over the epochs (Algorithm 1, line 5) - for step in tqdm(self.prepare(num_budgets, epochs), + for step in tqdm(self.prepare(self.num_budgets, epochs), desc='Peturbing graph...', disable=disable): - loss, gradient = self.compute_gradients(feat, victim_labels, - victim_nodes) + loss, gradient = self.compute_gradients() - scalars = self.update(step, gradient, num_budgets) + scalars = self.update(step, gradient) scalars['loss'] = loss.item() self._append_statistics(scalars) @@ -107,12 +101,7 @@ def attack( return self - def 
compute_gradients( - self, - feat: Tensor, - victim_labels: Tensor, - victim_nodes: Tensor, - ) -> Tuple[Tensor, Tensor]: + def compute_gradients(self) -> Tuple[Tensor, Tensor]: """Forward and update edge weights.""" self.block_edge_weight.requires_grad_() @@ -123,8 +112,8 @@ def compute_gradients( self.block_edge_weight) # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7) - prediction = self.surrogate(feat, edge_index, - edge_weight)[victim_nodes] + prediction = self.surrogate(self.feat, edge_index, + edge_weight)[self.victim_nodes] # temperature scaling, work for cross-entropy loss if self.tau != 1: @@ -132,7 +121,7 @@ def compute_gradients( # Calculate loss combining all each node # (Algorithm 1, line 7 / Algorithm 2, line 8) - loss = self.loss(prediction, victim_labels) + loss = self.loss(prediction, self.victim_labels) # Retrieve gradient towards the current block # (Algorithm 1, line 7 / Algorithm 2, line 8) gradient = torch.autograd.grad(loss, self.block_edge_weight)[0] @@ -242,16 +231,14 @@ def resample_random_block(self, num_budgets: int): "Please decrease `num_budgets`.") @torch.no_grad() - def sample_final_edges( - self, - feat: Tensor, - num_budgets: int, - victim_nodes: Tensor, - victim_labels: Tensor, - ) -> Tuple[Tensor, Tensor]: + def sample_final_edges(self) -> Tuple[Tensor, Tensor]: best_metric = float('-Inf') block_edge_weight = self.block_edge_weight block_edge_weight[block_edge_weight <= self.coeffs['eps']] = 0 + num_budgets = self.num_budgets + feat = self.feat + victim_nodes = self.victim_nodes + victim_labels = self.victim_labels for i in range(self.coeffs['max_final_samples']): if i == 0: @@ -283,13 +270,12 @@ def sample_final_edges( flipped_edges = self.block_edge_index[:, best_edge_weight != 0] return flipped_edges - def update_edge_weights(self, num_budgets: int, epoch: int, - gradient: Tensor): + def update_edge_weights(self, epoch: int, gradient: Tensor): # The learning rate is refined heuristically, s.t. (1) it is # independent of the number of perturbations (assuming an undirected # adjacency matrix) and (2) to decay learning rate during fine-tuning # (i.e. fixed search space). 
- lr = (num_budgets / self.num_nodes * self.lr / + lr = (self.num_budgets / self.num_nodes * self.lr / np.sqrt(max(0, epoch - self.epochs_resampling) + 1)) self.block_edge_weight.data.add_(lr * gradient) @@ -337,7 +323,7 @@ def setup_surrogate( self, surrogate: torch.nn.Module, victim_nodes: Tensor, - ground_truth: bool = True, + ground_truth: bool = False, *, tau: float = 1.0, freeze: bool = True, @@ -353,7 +339,7 @@ def setup_surrogate( ground_truth : bool, optional whether to use ground-truth label for victim nodes, if False, the node labels are estimated by the surrogate model, - by default True + by default False tau : float, optional the temperature of softmax activation, by default 1.0 freeze : bool, optional @@ -422,8 +408,7 @@ def attack( structure_attack=structure_attack, feature_attack=feature_attack) - return RBCDAttack.attack(self, self.num_budgets, block_size=block_size, - epochs=epochs, + return RBCDAttack.attack(self, block_size=block_size, epochs=epochs, epochs_resampling=epochs_resampling, loss=loss, metric=metric, lr=lr, disable=disable, **kwargs) @@ -437,17 +422,17 @@ def prepare(self, num_budgets: int, epochs: int) -> Iterable[int]: return range(epochs) @torch.no_grad() - def update(self, epoch: int, gradient: Tensor, - num_budgets: int) -> Dict[str, float]: + def update(self, epoch: int, gradient: Tensor) -> Dict[str, float]: """Update edge weights given gradient.""" # Gradient update step (Algorithm 1, line 7) - self.update_edge_weights(num_budgets, epoch, gradient) + self.update_edge_weights(epoch, gradient) # For monitoring pmass_update = torch.clamp(self.block_edge_weight, 0, 1) # Projection to stay within relaxed `L_0` num_budgets # (Algorithm 1, line 8) - self.block_edge_weight = project(num_budgets, self.block_edge_weight, + self.block_edge_weight = project(self.num_budgets, + self.block_edge_weight, self.coeffs['eps']) # For monitoring @@ -468,7 +453,7 @@ def update(self, epoch: int, gradient: Tensor, topk_block_edge_weight = torch.zeros_like(self.block_edge_weight) topk_block_edge_weight[torch.topk(self.block_edge_weight, - num_budgets).indices] = 1 + self.num_budgets).indices] = 1 edge_index, edge_weight = self.get_modified_graph( self._edge_index, self._edge_weight, self.block_edge_index, @@ -488,7 +473,7 @@ def update(self, epoch: int, gradient: Tensor, # Resampling of search space (Algorithm 1, line 9-14) if epoch < self.epochs_resampling - 1: - self.resample_random_block(num_budgets) + self.resample_random_block(self.num_budgets) elif epoch == self.epochs_resampling - 1: # Retrieve best epoch if early stopping is active # (not explicitly covered by pseudo code) @@ -511,12 +496,7 @@ def get_flipped_edges(self) -> Tensor: self.block_edge_weight = self.best_pert_edge_weight.to(self.device) # Sample final discrete graph (Algorithm 1, line 16) - return self.sample_final_edges( - self.feat, - self.num_budgets, - self.victim_nodes, - self.victim_labels, - ) + return self.sample_final_edges() class GRBCDAttack(PRBCDAttack): @@ -538,7 +518,7 @@ def attack( epochs_resampling: int = 100, loss: Optional[str] = 'mce', metric: Optional[Union[str, METRIC]] = None, - lr: float = 1_000, + lr: float = 2_000, structure_attack: bool = True, feature_attack: bool = False, disable: bool = False, @@ -579,7 +559,6 @@ def update( self, step_size: int, gradient: Tensor, - num_budgets: int, ) -> Dict[str, Any]: """Update edge weights given gradient.""" _, topk_edge_index = torch.topk(gradient, step_size) From 76c3fb09dec02b7f9f171cde3923a18fe2b5f977 Mon Sep 17 00:00:00 2001 From: 
EdisonLeeeee Date: Sat, 3 Dec 2022 21:32:31 +0800 Subject: [PATCH 14/15] Update --- examples/attack/targeted/rbcd_attack.py | 93 +++++++++++++++++++++++++ greatx/attack/targeted/rbcd_attack.py | 23 +++--- 2 files changed, 101 insertions(+), 15 deletions(-) create mode 100644 examples/attack/targeted/rbcd_attack.py diff --git a/examples/attack/targeted/rbcd_attack.py b/examples/attack/targeted/rbcd_attack.py new file mode 100644 index 0000000..3b443ea --- /dev/null +++ b/examples/attack/targeted/rbcd_attack.py @@ -0,0 +1,93 @@ +import os.path as osp + +import torch +import torch_geometric.transforms as T + +from greatx.attack.targeted import GRBCDAttack, PRBCDAttack +from greatx.datasets import GraphDataset +from greatx.nn.models import GCN +from greatx.training import Trainer +from greatx.training.callbacks import ModelCheckpoint +from greatx.utils import mark, split_nodes + +dataset = 'Cora' +root = osp.join(osp.dirname(osp.realpath(__file__)), '../../..', 'data') +dataset = GraphDataset(root=root, name=dataset, + transform=T.LargestConnectedComponents()) + +data = dataset[0] +splits = split_nodes(data.y, random_state=15) + +num_features = data.x.size(-1) +num_classes = data.y.max().item() + 1 + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# ================================================================== # +# Attack Setting # +# ================================================================== # +target = 1 # target node to attack +target_label = data.y[target].item() + +# ================================================================== # +# Before Attack # +# ================================================================== # +trainer_before = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_before.pth', monitor='val_acc') +trainer_before.fit(data, mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +output = trainer_before.predict(data, mask=target) +print("Before attack:") +print(mark(output, target_label)) + +# ================================================================== # +# Attacking (PRBCDAttack) # +# ================================================================== # +attacker = PRBCDAttack(data, device=device) +attacker.setup_surrogate(trainer_before.model) +attacker.reset() +attacker.attack(target) + +# ================================================================== # +# After evasion Attack # +# ================================================================== # +output = trainer_before.predict(attacker.data(), mask=target) +print("After evasion attack:") +print(mark(output, target_label)) + +# ================================================================== # +# After poisoning Attack # +# ================================================================== # +trainer_after = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_after.pth', monitor='val_acc') +trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +output = trainer_after.predict(attacker.data(), mask=target) +print("After poisoning attack:") +print(mark(output, target_label)) + +# ================================================================== # +# Attacking (GRBCDAttack) # +# ================================================================== # +attacker = GRBCDAttack(data, device=device) +attacker.setup_surrogate(trainer_before.model) +attacker.reset() +attacker.attack(target) + +# 
================================================================== # +# After evasion Attack # +# ================================================================== # +output = trainer_before.predict(attacker.data(), mask=target) +print("After evasion attack:") +print(mark(output, target_label)) + +# ================================================================== # +# After poisoning Attack # +# ================================================================== # +trainer_after = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_after.pth', monitor='val_acc') +trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +output = trainer_after.predict(attacker.data(), mask=target) +print("After poisoning attack:") +print(mark(output, target_label)) diff --git a/greatx/attack/targeted/rbcd_attack.py b/greatx/attack/targeted/rbcd_attack.py index bc267b2..adb862e 100644 --- a/greatx/attack/targeted/rbcd_attack.py +++ b/greatx/attack/targeted/rbcd_attack.py @@ -97,8 +97,7 @@ def attack( self.victim_labels = self.target_label.view(-1) - return RBCDAttack.attack(self, self.num_budgets, block_size=block_size, - epochs=epochs, + return RBCDAttack.attack(self, block_size=block_size, epochs=epochs, epochs_resampling=epochs_resampling, loss=loss, metric=metric, lr=lr, disable=disable, **kwargs) @@ -112,17 +111,17 @@ def prepare(self, num_budgets: int, epochs: int) -> Iterable[int]: return range(epochs) @torch.no_grad() - def update(self, epoch: int, gradient: Tensor, - num_budgets: int) -> Dict[str, float]: + def update(self, epoch: int, gradient: Tensor) -> Dict[str, float]: """Update edge weights given gradient.""" # Gradient update step (Algorithm 1, line 7) - self.update_edge_weights(num_budgets, epoch, gradient) + self.update_edge_weights(epoch, gradient) # For monitoring pmass_update = torch.clamp(self.block_edge_weight, 0, 1) # Projection to stay within relaxed `L_0` num_budgets # (Algorithm 1, line 8) - self.block_edge_weight = project(num_budgets, self.block_edge_weight, + self.block_edge_weight = project(self.num_budgets, + self.block_edge_weight, self.coeffs['eps']) # For monitoring @@ -143,7 +142,7 @@ def update(self, epoch: int, gradient: Tensor, topk_block_edge_weight = torch.zeros_like(self.block_edge_weight) topk_block_edge_weight[torch.topk(self.block_edge_weight, - num_budgets).indices] = 1 + self.num_budgets).indices] = 1 edge_index, edge_weight = self.get_modified_graph( self._edge_index, self._edge_weight, self.block_edge_index, @@ -163,7 +162,7 @@ def update(self, epoch: int, gradient: Tensor, # Resampling of search space (Algorithm 1, line 9-14) if epoch < self.epochs_resampling - 1: - self.resample_random_block(num_budgets) + self.resample_random_block(self.num_budgets) elif epoch == self.epochs_resampling - 1: # Retrieve best epoch if early stopping is active # (not explicitly covered by pseudo code) @@ -186,12 +185,7 @@ def get_flipped_edges(self) -> Tensor: self.block_edge_weight = self.best_pert_edge_weight.to(self.device) # Sample final discrete graph (Algorithm 1, line 16) - return self.sample_final_edges( - self.feat, - self.num_budgets, - self.victim_nodes, - self.victim_labels, - ) + return self.sample_final_edges() class GRBCDAttack(PRBCDAttack): @@ -259,7 +253,6 @@ def update( self, step_size: int, gradient: Tensor, - num_budgets: int, ) -> Dict[str, Any]: """Update edge weights given gradient.""" _, topk_edge_index = torch.topk(gradient, step_size) From 
a486b21e49a2f74d83a9393bb719b0305c8bd7c2 Mon Sep 17 00:00:00 2001
From: EdisonLeeeee
Date: Fri, 3 Mar 2023 09:56:02 +0800
Subject: [PATCH 15/15] Update

---
 greatx/attack/targeted/rbcd_attack.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/greatx/attack/targeted/rbcd_attack.py b/greatx/attack/targeted/rbcd_attack.py
index adb862e..224aa41 100644
--- a/greatx/attack/targeted/rbcd_attack.py
+++ b/greatx/attack/targeted/rbcd_attack.py
@@ -54,7 +54,7 @@ def reset(self) -> "PRBCDAttack":
 self.victim_labels = None

 # NOTE: Since `edge_index` and `edge_weight` denote the original graph
- # here we need to name them as `edge_index`and `_edge_weight`
+ # here we need to name them as `_edge_index` and `_edge_weight`
 self._edge_index = self.edge_index
 self._edge_weight = torch.ones(self.num_edges, device=self.device)
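
Two pieces of arithmetic recur throughout this series and are easy to misread in diff form: the per-epoch budget split in `GRBCDAttack.prepare` and the heuristic learning-rate schedule in `PRBCDAttack.update_edge_weights`. The standalone sketch below reproduces both so they can be sanity-checked outside the library. The helper names `grbcd_steps` and `prbcd_lr` are illustrative only and do not exist in GreatX, and the worked numbers assume a Cora-sized graph of 2708 nodes.

import numpy as np


def grbcd_steps(num_budgets: int, epochs: int) -> list:
    # Mirrors the budget split in GRBCDAttack.prepare: when the budget
    # exceeds the number of epochs, every step flips
    # num_budgets // epochs edges and the remainder is spread over the
    # first steps; otherwise a single edge is flipped per step.
    step_size = num_budgets // epochs
    if step_size > 0:
        steps = epochs * [step_size]
        for i in range(num_budgets % epochs):
            steps[i] += 1
    else:
        steps = [1] * num_budgets
    return steps


def prbcd_lr(num_budgets: int, num_nodes: int, base_lr: float,
             epoch: int, epochs_resampling: int) -> float:
    # Mirrors the formula in PRBCDAttack.update_edge_weights: the rate
    # scales with the budget relative to the graph size and decays as
    # 1/sqrt(t) once resampling stops and the search space is fixed.
    decay = np.sqrt(max(0, epoch - epochs_resampling) + 1)
    return num_budgets / num_nodes * base_lr / decay


# A budget of 10 flips over 4 epochs yields [3, 3, 2, 2]; a budget
# smaller than the number of epochs yields one flip per step.
assert grbcd_steps(10, 4) == [3, 3, 2, 2]
assert grbcd_steps(3, 4) == [1, 1, 1]

# Constant rate while resampling (epoch < epochs_resampling), decay after.
print(prbcd_lr(500, 2708, 2_000, epoch=50, epochs_resampling=100))   # ~369.3
print(prbcd_lr(500, 2708, 2_000, epoch=110, epochs_resampling=100))  # ~111.3

Note how the schedule stays flat during the resampling phase and only begins to decay once the search space is frozen, which matches the comment in the diff that the learning rate is decayed "during fine-tuning (i.e. fixed search space)".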