from functools import partial
from typing import Dict, List, Optional
from warnings import warn

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Data
class CKA:
    """Centered Kernel Alignment (CKA) metric,
    where the features of the networks are compared.

    Forward hooks are registered on the selected layers of both models.
    Every hooked layer whose output is a 2-dimensional tensor contributes
    one row (model 1) or one column (model 2) to the CKA matrix computed
    by :meth:`compare`.

    Parameters
    ----------
    model1 : nn.Module
        model 1
    model2 : nn.Module
        model 2
    model1_name : str, optional
        name of model 1, by default None (the text of ``repr(model1)``
        before its first ``(`` is used)
    model2_name : str, optional
        name of model 2, by default None (derived like ``model1_name``)
    model1_layers : List[str], optional
        Names (as yielded by ``named_modules()``) of the layers to extract
        features from, by default None (every module is hooked)
    model2_layers : List[str], optional
        Same as ``model1_layers`` but for model 2, by default None
    training : bool, optional
        whether to set training mode (True) or evaluation
        mode (False) for models. by default False.
    device : str, optional
        device to run the models, by default 'cpu'

    Example
    -------
    .. code-block:: python

        data = ...  # get your graph
        m1 = ...  # get your model1
        m2 = ...  # get your model2
        cka = CKA(m1, m2)
        cka.compare(data)
        cka.plot_results()

    Reference:
        * Paper: https://arxiv.org/abs/2010.15327
        * Code: https://github.com/AntixK/PyTorch-Model-Compare
    """

    def __init__(self, model1: nn.Module, model2: nn.Module,
                 model1_name: Optional[str] = None,
                 model2_name: Optional[str] = None,
                 model1_layers: Optional[List[str]] = None,
                 model2_layers: Optional[List[str]] = None,
                 training: bool = False, device: str = 'cpu'):
        self.model1 = model1
        self.model2 = model2
        self.device = torch.device(device)

        self.model1_info = {}
        self.model2_info = {}

        if model1_name is None:
            self.model1_info['Name'] = repr(model1).split('(')[0]
        else:
            self.model1_info['Name'] = model1_name

        if model2_name is None:
            self.model2_info['Name'] = repr(model2).split('(')[0]
        else:
            self.model2_info['Name'] = model2_name

        if self.model1_info['Name'] == self.model2_info['Name']:
            warn("Both model have identical names - "
                 f"{self.model2_info['Name']}. "
                 "It may cause confusion when interpreting the results. "
                 "Consider giving unique names to the models :)")

        self.model1_info['Layers'] = []
        self.model2_info['Layers'] = []

        self.model1_features = {}
        self.model2_features = {}

        self.model1_layers = model1_layers
        self.model2_layers = model2_layers

        # Handles returned by register_forward_hook, kept so that the hooks
        # can be detached again with remove_hooks().
        self._hook_handles = []
        self._insert_hooks()

        self.model1 = self.model1.to(self.device)
        self.model2 = self.model2.to(self.device)

        self.model1.train(training)
        self.model2.train(training)

    def _log_layer(self, model: str, name: str, layer: nn.Module, inp: Tensor,
                   out: Tensor):
        """Forward hook: store the output of layer ``name`` for ``model``.

        Only 2-dimensional tensor outputs are recorded. Anything else
        (tensors of another rank, tuples, ``None``, ...) is skipped, so a
        layer with a non-tensor output no longer breaks the forward pass.

        Raises
        ------
        RuntimeError
            If ``model`` is neither ``"model1"`` nor ``"model2"``.
        """
        if not isinstance(out, Tensor) or out.ndim != 2:
            # ignore those features that dimensions not equal to 2
            return

        if model == "model1":
            self.model1_features[name] = out
        elif model == "model2":
            self.model2_features[name] = out
        else:
            raise RuntimeError(f"Unknown model name `{model}`.")

    def _register_hooks(self, tag: str, model: nn.Module, wanted, info: dict):
        """Hook every module of ``model``, or only those named in ``wanted``."""
        for name, layer in model.named_modules():
            if wanted is not None and name not in wanted:
                continue
            info['Layers'].append(name)
            handle = layer.register_forward_hook(
                partial(self._log_layer, tag, name))
            self._hook_handles.append(handle)

    def _insert_hooks(self):
        """Register the feature-logging forward hooks on both models."""
        self._register_hooks("model1", self.model1, self.model1_layers,
                             self.model1_info)
        self._register_hooks("model2", self.model2, self.model2_layers,
                             self.model2_info)

    def remove_hooks(self) -> None:
        """Detach every forward hook this instance registered."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    @staticmethod
    def _gram(feat: Tensor) -> Tensor:
        """Return the linear-kernel Gram matrix of ``feat``, diagonal zeroed."""
        flat = feat.flatten(1)
        gram = flat @ flat.t()
        gram.fill_diagonal_(0.0)
        return gram

    def _HSIC(self, K, L):
        """Computes the unbiased estimate of HSIC metric.

        Reference: https://arxiv.org/pdf/2010.15327.pdf Eq (3)

        The three terms of the estimator are evaluated with element-wise
        reductions instead of explicit ``N x N`` matrix products; the
        identities used are noted inline.

        Parameters
        ----------
        K, L : Tensor
            Square kernel matrices of the same size ``N x N`` (the caller
            zeroes their diagonals beforehand).

        Returns
        -------
        float
            The HSIC estimate as a Python float.

        Raises
        ------
        ValueError
            If ``N < 4``; the estimator divides by ``N - 2`` and ``N - 3``.
        """
        N = K.shape[0]
        if N < 4:
            raise ValueError(
                "The unbiased HSIC estimator needs at least 4 samples, "
                f"got {N}.")

        # trace(K @ L) == sum_ij K[i, j] * L[j, i]
        trace_kl = (K * L.t()).sum()
        # (1^T K 1) * (1^T L 1) == sum(K) * sum(L)
        sums_product = K.sum() * L.sum()
        # 1^T K L 1 == (column sums of K) . (row sums of L)
        cross = K.sum(dim=0) @ L.sum(dim=1)

        result = (trace_kl + sums_product / ((N - 1) * (N - 2)) -
                  cross * 2 / (N - 2))
        return (result / (N * (N - 3))).item()

    @torch.no_grad()
    def compare(self, data1: "Data",
                data2: Optional["Data"] = None) -> "CKA":
        """
        Computes the feature similarity between the models on the
        given datasets.

        The result is stored in ``self.hsic_matrix`` with one row per
        recorded model-1 layer and one column per recorded model-2 layer.
        ``model*_info['Layers']`` is refreshed to the recorded layer names
        in that same order, so :meth:`export` labels match the matrix.

        Parameters
        ----------
        data1 : Data
            the dataset where model 1 run on.
        data2 : Data, optional
            If given, model 2 will run on this dataset. by default None

        Returns
        -------
        CKA
            ``self``, to allow call chaining.

        Raises
        ------
        RuntimeError
            If a model-1 and a model-2 feature yield kernels of different
            shapes.
        """
        data1 = data1.to(self.device)
        if data2 is None:
            warn("Data for Model 2 is not given. "
                 "Using the same data for both models.")
            data2 = data1
        else:
            data2 = data2.to(self.device)

        # Start from a clean slate; the hooks refill these dicts during the
        # two forward passes below.
        self.model1_features = {}
        self.model2_features = {}
        self.model1(data1.x, data1.edge_index, data1.edge_weight)
        self.model2(data2.x, data2.edge_index, data2.edge_weight)

        # Size the result by what was actually recorded: requested layers
        # without a 2-D output produce no feature and must not leave
        # all-zero (-> NaN) rows or columns behind.
        names1 = list(self.model1_features)
        names2 = list(self.model2_features)
        self.model1_info['Layers'] = names1
        self.model2_info['Layers'] = names2

        hsic_kk = torch.zeros(len(names1))
        hsic_ll = torch.zeros(len(names2))
        hsic_kl = torch.zeros(len(names1), len(names2))

        for i, feat1 in enumerate(self.model1_features.values()):
            K = self._gram(feat1)
            hsic_kk[i] = self._HSIC(K, K)

            for j, feat2 in enumerate(self.model2_features.values()):
                L = self._gram(feat2)
                if K.shape != L.shape:
                    raise RuntimeError(
                        f"Feature shape mistach! {K.shape} and {L.shape}")

                hsic_kl[i, j] = self._HSIC(K, L)
                if i == 0:
                    # HSIC(L, L) depends on the model-2 layer only, so it
                    # is computed once rather than for every model-1 layer.
                    hsic_ll[j] = self._HSIC(L, L)

        self.hsic_matrix = hsic_kl / (hsic_kk.sqrt().unsqueeze(1) *
                                      hsic_ll.sqrt().unsqueeze(0))

        assert not torch.isnan(
            self.hsic_matrix).any(), "HSIC computation resulted in NANs"
        return self

    def export(self) -> Dict:
        """
        Exports the CKA data along with the respective model layer names.

        Returns
        -------
        Dict
            Keys ``model1_name``, ``model2_name``, ``CKA`` (the matrix
            stored by :meth:`compare`), ``model1_layers`` and
            ``model2_layers``.
        """
        return {
            "model1_name": self.model1_info['Name'],
            "model2_name": self.model2_info['Name'],
            "CKA": self.hsic_matrix,
            "model1_layers": self.model1_info['Layers'],
            "model2_layers": self.model2_info['Layers'],
        }

    def plot_results(self, save_path: Optional[str] = None,
                     title: Optional[str] = None):
        """Show the CKA matrix as a heat map.

        Parameters
        ----------
        save_path : str, optional
            If given, the figure is also saved there (dpi=300).
        title : str, optional
            Figure title, by default "<model1 name> vs <model2 name>".
        """
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        im = ax.imshow(self.hsic_matrix, origin='lower', cmap='magma')
        ax.set_xlabel(f"Layers of {self.model2_info['Name']}", fontsize=15)
        ax.set_ylabel(f"Layers of {self.model1_info['Name']}", fontsize=15)

        if title is not None:
            ax.set_title(f"{title}", fontsize=18)
        else:
            ax.set_title(
                f"{self.model1_info['Name']} vs {self.model2_info['Name']}",
                fontsize=18)

        add_colorbar(im)
        plt.tight_layout()

        if save_path is not None:
            plt.savefig(save_path, dpi=300)

        plt.show()
def add_colorbar(im, aspect=10, pad_fraction=0.5, **kwargs):
    """Attach a vertical color bar to the right of an image plot.

    Parameters
    ----------
    im :
        The image whose axes (``im.axes``) receive the color bar.
    aspect : float, optional
        The bar width is sized from the image axes' y extent scaled by
        ``1 / aspect``, by default 10.
    pad_fraction : float, optional
        Gap between image and bar, as a fraction of the bar width,
        by default 0.5.
    **kwargs
        Forwarded to ``figure.colorbar``.

    Returns
    -------
    The object returned by ``figure.colorbar``.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits import axes_grid1

    image_axes = im.axes
    locator = axes_grid1.make_axes_locatable(image_axes)
    bar_width = axes_grid1.axes_size.AxesY(image_axes, aspect=1. / aspect)
    bar_pad = axes_grid1.axes_size.Fraction(pad_fraction, bar_width)

    # Remember the current axes and make them current again once the
    # color-bar axes have been appended.
    previous_axes = plt.gca()
    bar_axes = locator.append_axes("right", size=bar_width, pad=bar_pad)
    plt.sca(previous_axes)

    return image_axes.figure.colorbar(im, cax=bar_axes, **kwargs)