Source code for greatx.defense.gnnguard

import warnings

import torch
import torch.nn.functional as F
from torch_geometric.utils import scatter

EPS = 1e-10


[docs]class GNNGUARD(torch.nn.Module): r"""Implementation of GNNGUARD from the `"GNNGUARD: Defending Graph Neural Networks against Adversarial Attacks" <https://arxiv.org/abs/2006.08149>`_ paper (NeurIPS'20) Parameters ---------- threshold : float, optional threshold for removing edges based on attention scores, by default 0.1 add_self_loops : bool, optional whether to add self-loops to the input graph, by default False """ def __init__(self, threshold: float = 0.1, add_self_loops: bool = False): super().__init__() self.threshold = threshold self.add_self_loops = add_self_loops
[docs] def forward(self, x, edge_index, edge_weight=None): """""" if edge_weight is not None: warnings.warn("`edge_weight` is supported in GNNGUARD " "and will be ignored for computation.") row, col = edge_index A, B = x[row], x[col] att_score = F.cosine_similarity(A, B) mask = att_score >= self.threshold edge_index = edge_index[:, mask] att_score = att_score[mask] row, col = edge_index row_sum = scatter(att_score, col, dim_size=x.size(0)) att_score_norm = att_score / (row_sum[row] + EPS) if self.add_self_loops: degree = scatter(torch.ones_like(att_score_norm), col, dim_size=x.size(0)) self_weight = 1.0 / (degree + 1) att_score_norm = torch.cat([att_score_norm, self_weight]) loop_index = torch.arange(0, x.size(0), dtype=torch.long, device=edge_index.device) loop_index = loop_index.unsqueeze(0).repeat(2, 1) edge_index = torch.cat([edge_index, loop_index], dim=1) att_score_norm = att_score_norm.exp() return edge_index, att_score_norm
[docs] def extra_repr(self) -> str: return f"threshold={self.threshold}"