Source code for greatx.nn.models.unsupervised.grace

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from greatx.nn.layers import GCNConv, Sequential, activations
from greatx.nn.models import MLP
from greatx.utils import wrapper

try:
    from torch_geometric.utils import dropout_edge, mask_feature
except ImportError:
    dropout_edge = mask_feature = None


class GRACE(torch.nn.Module):
    r"""GRAph Contrastive rEpresentation learning (GRACE) from the
    `"Deep Graph Contrastive Representation Learning"
    <https://arxiv.org/abs/2006.04131>`_ paper (ICML'20)

    Parameters
    ----------
    in_channels : int,
        the input dimensions of model
    hids : List[int], optional
        the number of hidden units for each hidden layer, by default [128]
    acts : List[str], optional
        the activation function for each hidden layer, by default ['prelu']
    project_hids : List[int], optional
        the projection dimensions of model, by default [128]
    tau : float, optional
        the temperature coefficient of softmax, by default 0.5
    dropout : float, optional
        the dropout ratio of model, by default 0.0
    bias : bool, optional
        whether to use bias in the layers, by default True
    bn : bool, optional
        whether to use :class:`BatchNorm1d` after the convolution layer,
        by default False
    drop_edge1 : float, optional
        the dropout ratio of edges for the first view, by default 0.8
    drop_edge2 : float, optional
        the dropout ratio of edges for the second view, by default 0.7
    drop_feat1 : float, optional
        the dropout ratio of features for the first view, by default 0.4
    drop_feat2 : float, optional
        the dropout ratio of features for the second view, by default 0.3

    Examples
    --------
    >>> # GRACE with one hidden layer
    >>> model = GRACE(100)

    >>> # GRACE with two hidden layers
    >>> model = GRACE(100, hids=[32, 16], acts=['relu', 'elu'])

    >>> # GRACE with two hidden layers, without first activation
    >>> model = GRACE(100, hids=[32, 16], acts=[None, 'relu'])

    >>> # GRACE with deep architectures, each layer has elu activation
    >>> model = GRACE(100, hids=[16]*8, acts=['elu'])

    Reference:

    * Author's code: https://github.com/CRIPAC-DIG/GRACE
    """
    @wrapper
    def __init__(
        self,
        in_channels: int,
        hids: List[int] = [128],
        acts: List[str] = ['prelu'],
        project_hids: List[int] = [128],
        dropout: float = 0.,
        tau: float = 0.5,
        bias: bool = True,
        bn: bool = False,
        drop_edge1: float = 0.8,
        drop_edge2: float = 0.7,
        drop_feat1: float = 0.4,
        drop_feat2: float = 0.3,
    ):
        super().__init__()

        if dropout_edge is None:
            # TODO: support them
            raise ImportError(
                "Please install the latest version of `torch_geometric`.")

        encoder = []
        for hid, act in zip(hids, acts):
            encoder.append(GCNConv(in_channels, hid, bias=bias))
            if bn:
                encoder.append(nn.BatchNorm1d(hid))
            encoder.append(activations.get(act))
            encoder.append(nn.Dropout(dropout))
            in_channels = hid
        self.encoder = Sequential(*encoder)
        self.decoder = MLP(in_channels, in_channels, project_hids, dropout=0.)

        self.tau = tau
        self.drop_edge1 = drop_edge1
        self.drop_edge2 = drop_edge2
        self.drop_feat1 = drop_feat1
        self.drop_feat2 = drop_feat2
    def encode(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_weight: Optional[Tensor] = None,
    ) -> Tensor:
        z = self.encoder(x, edge_index, edge_weight)
        return z
    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_weight: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """"""
        # Two graph views with edges randomly dropped at different rates
        edge_index1, mask1 = dropout_edge(edge_index, p=self.drop_edge1)
        edge_index2, mask2 = dropout_edge(edge_index, p=self.drop_edge2)

        if edge_weight is not None:
            edge_weight1 = edge_weight[mask1]
            edge_weight2 = edge_weight[mask2]
        else:
            edge_weight1 = edge_weight2 = None

        # Randomly mask feature channels in each view
        x1 = mask_feature(x, self.drop_feat1)[0]
        x2 = mask_feature(x, self.drop_feat2)[0]

        h1 = self.encoder(x1, edge_index1, edge_weight1)
        h2 = self.encoder(x2, edge_index2, edge_weight2)

        # Project embeddings into the space where the contrastive loss applies
        z1 = self.decoder(h1)
        z2 = self.decoder(h2)

        return z1, z2
    def sim(self, z1: Tensor, z2: Tensor) -> Tensor:
        # Cosine similarity between all pairs of rows in z1 and z2
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())
    def loss(self, z1: Tensor, z2: Tensor) -> Tensor:
        # Pairwise similarities within the first view and across views,
        # scaled by the temperature `tau`
        refl_sim = torch.exp(self.sim(z1, z1) / self.tau)
        between_sim = torch.exp(self.sim(z1, z2) / self.tau)

        # Positive pairs: the same node in the other view
        pos = between_sim.diag()
        # Denominator: all inter-view pairs plus all intra-view pairs,
        # excluding each node's trivial similarity with itself
        denom = refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag()
        losses = -torch.log(pos / denom)
        return losses.mean()
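
A minimal training sketch follows. It is not part of the library source: it assumes
GRACE is importable from greatx.nn.models (as MLP is above), uses a random toy graph
in place of a real dataset, and picks an illustrative learning rate and epoch count.
The symmetrized objective (loss(z1, z2) + loss(z2, z1)) / 2 follows the GRACE paper.

import torch
from torch_geometric.data import Data

from greatx.nn.models import GRACE

# Toy random graph standing in for a real dataset (hypothetical data)
num_nodes, num_feats = 100, 64
data = Data(x=torch.randn(num_nodes, num_feats),
            edge_index=torch.randint(num_nodes, (2, 500)))

model = GRACE(num_feats)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # illustrative

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    # Two projected views produced by the random augmentations in `forward`
    z1, z2 = model(data.x, data.edge_index)
    # Symmetrized node-level contrastive objective, as in the GRACE paper
    loss = 0.5 * (model.loss(z1, z2) + model.loss(z2, z1))
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():
    # Embeddings for downstream tasks come from a clean (un-augmented) pass
    embedding = model.encode(data.x, data.edge_index)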