from typing import Optional
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.transforms import BaseTransform
from greatx.functional import spmm
class FeaturePropagation(BaseTransform):
    r"""Implementation of FeaturePropagation
    from the `"On the Unreasonable Effectiveness
    of Feature propagation in Learning
    on Graphs with Missing Node Features"
    <https://openreview.net/forum?id=qe_qOarxjg>`_
    paper (LoG'22)

    Parameters
    ----------
    missing_mask : Optional[Tensor], optional
        boolean mask on missing features with shape
        ``[num_nodes, num_features]``, where ``True`` marks an
        entry whose feature is missing, by default None
    num_iterations : int, optional
        number of propagation iterations to run, by default 40
    normalize : bool, optional
        whether to compute symmetric normalization
        coefficients on the fly, by default True
    add_self_loops : bool, optional
        whether to add self-loops to the input graph
        before normalization, by default False

    Example
    -------
    .. code-block:: python

        data = ... # PyG-like data
        data = FeaturePropagation(num_iterations=40)(data)

        # missing_mask is a mask `[num_nodes, num_features]`
        # indicating where the feature is missing
        data = FeaturePropagation(missing_mask=missing_mask)(data)

    See also
    --------
    :class:`torch_geometric.transforms.FeaturePropagation`

    Reference:
    https://github.com/twitter-research/feature-propagation
    """
    def __init__(self, missing_mask: Optional[Tensor] = None,
                 num_iterations: int = 40, normalize: bool = True,
                 add_self_loops: bool = False):
        super().__init__()
        self.missing_mask = missing_mask
        self.num_iterations = num_iterations
        self.normalize = normalize
        # Defaults to False to preserve the historical behavior, where
        # `gcn_norm` was always called with `add_self_loops=False`.
        self.add_self_loops = add_self_loops

    def __call__(self, data: Data) -> Data:
        """Diffuse known features over the graph to reconstruct the
        missing ones, returning ``data`` with ``data.x`` replaced by
        the propagated features."""
        # `out` is initialized to 0 for missing values.
        # However, its initialization does not
        # matter for the final value at convergence.
        x = data.x
        known_feature_mask = None
        missing_mask = self.missing_mask
        if missing_mask is not None:
            out = torch.zeros_like(x)
            # `True` in `missing_mask` means missing, so its negation
            # selects the entries whose features are known.
            known_feature_mask = ~missing_mask
            out[known_feature_mask] = x[known_feature_mask]
        else:
            out = x.clone()

        edge_index, edge_weight = data.edge_index, data.edge_weight
        if self.normalize:
            # Symmetric GCN normalization of the adjacency matrix.
            edge_index, edge_weight = gcn_norm(
                edge_index, edge_weight, x.size(0), improved=False,
                add_self_loops=self.add_self_loops, dtype=x.dtype)

        for _ in range(self.num_iterations):
            # Diffuse current features
            out = spmm(out, edge_index, edge_weight)
            if known_feature_mask is not None:
                # Reset original known features after each diffusion step
                out[known_feature_mask] = x[known_feature_mask]

        data.x = out
        return data