# Source code for greatx.utils.cka

from functools import partial
from typing import Dict, List
from warnings import warn

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Data

class CKA:
    """Centered Kernel Alignment (CKA) metric, where the features of the
    networks are compared.

    Forward hooks are registered on the selected submodules of both models.
    Every hooked module whose output is a 2-D tensor contributes one row
    (model 1) or one column (model 2) to the CKA matrix built by
    :meth:`compare`.

    Parameters
    ----------
    model1 : nn.Module
        model 1
    model2 : nn.Module
        model 2
    model1_name : str, optional
        name of model 1, by default None (derived from ``repr(model1)``)
    model2_name : str, optional
        name of model 2, by default None (derived from ``repr(model2)``)
    model1_layers : List[str], optional
        List of layers to extract features from, by default None
        (all submodules)
    model2_layers : List[str], optional
        List of layers to extract features from, by default None
        (all submodules)
    training : bool, optional
        whether to set training mode (True) or evaluation mode (False)
        for models. by default False.
    device : str, optional
        device to run the models, by default 'cpu'

    Example
    -------
    .. code-block:: python

        data = ... # get your graph
        m1 = ... # get your model1
        m2 = ... # get your model2
        cka = CKA(m1, m2)
        cka.compare(data)
        cka.plot_results()

    Reference:

    * Paper: Kornblith et al., "Similarity of Neural Network
      Representations Revisited", ICML 2019
      (https://arxiv.org/abs/1905.00414)
    * Code: https://github.com/AntixK/PyTorch-Model-Compare
    """

    def __init__(self, model1: nn.Module, model2: nn.Module,
                 model1_name: str = None, model2_name: str = None,
                 model1_layers: List[str] = None,
                 model2_layers: List[str] = None, training: bool = False,
                 device: str = 'cpu'):
        self.device = torch.device(device)
        # ``Module.to`` works in place and returns the module itself, so
        # these remain the caller's model objects (each model keeps its own
        # identity; the original code aliased model2 to model1 here).
        self.model1 = model1.to(self.device)
        self.model2 = model2.to(self.device)

        self.model1_info = {}
        self.model2_info = {}
        if model1_name is None:
            self.model1_info['Name'] = model1.__repr__().split('(')[0]
        else:
            self.model1_info['Name'] = model1_name
        if model2_name is None:
            self.model2_info['Name'] = model2.__repr__().split('(')[0]
        else:
            self.model2_info['Name'] = model2_name

        if self.model1_info['Name'] == self.model2_info['Name']:
            warn("Both model have identical names - "
                 f"{self.model2_info['Name']}. "
                 "It may cause confusion when interpreting the results. "
                 "Consider giving unique names to the models :)")

        self.model1_info['Layers'] = []
        self.model2_info['Layers'] = []
        self.model1_features = {}
        self.model2_features = {}
        self.model1_layers = model1_layers
        self.model2_layers = model2_layers

        self._insert_hooks()

        self.model1.train(training)
        self.model2.train(training)

    def _log_layer(self, model: str, name: str, layer: nn.Module, inp,
                   out):
        """Forward-hook body: store the output of layer ``name``.

        ``model`` selects the target dict (``"model1"`` or ``"model2"``).
        Outputs that are not 2-D tensors (including tuples) are ignored.
        """
        if not isinstance(out, Tensor) or out.ndim != 2:
            return
        if model == "model1":
            self.model1_features[name] = out
        elif model == "model2":
            self.model2_features[name] = out
        else:
            raise RuntimeError(f"Unknown model name `{model}`.")

    def _hook_model(self, model: nn.Module, key: str, layers, info: Dict):
        """Register ``_log_layer`` on every selected submodule of ``model``.

        A submodule is selected when ``layers`` is None or contains its
        qualified name; selected names are appended to ``info['Layers']``.
        """
        for name, layer in model.named_modules():
            if layers is None or name in layers:
                info['Layers'].append(name)
                layer.register_forward_hook(
                    partial(self._log_layer, key, name))

    def _insert_hooks(self):
        """Attach feature-logging hooks to both models."""
        self._hook_model(self.model1, "model1", self.model1_layers,
                         self.model1_info)
        self._hook_model(self.model2, "model2", self.model2_layers,
                         self.model2_info)

    @staticmethod
    def _gram(feat: Tensor) -> Tensor:
        """Linear-kernel Gram matrix of ``feat`` with a zeroed diagonal."""
        X = feat.flatten(1)
        K = X @ X.t()
        K.fill_diagonal_(0.0)
        return K

    def _HSIC(self, K: Tensor, L: Tensor) -> float:
        """Computes the unbiased estimate of HSIC metric.

        Reference: https://arxiv.org/abs/2010.15327 Eq (3)

        ``K`` and ``L`` are expected to be N x N Gram matrices with zeroed
        diagonals (as produced by ``_gram``).

        Raises
        ------
        ValueError
            if N < 4, where the estimator's ``N - 3`` denominator is
            undefined.
        """
        N = K.shape[0]
        if N < 4:
            raise ValueError(
                f"Unbiased HSIC requires at least 4 samples, got {N}.")
        KL = K @ L
        # 1^T K 1 == K.sum() and 1^T K L 1 == (K @ L).sum(); summing directly
        # avoids allocating a ones-vector whose dtype/device may not match.
        result = torch.trace(KL)
        result = result + K.sum() * L.sum() / ((N - 1) * (N - 2))
        result = result - 2 * KL.sum() / (N - 2)
        return (result / (N * (N - 3))).item()

    @torch.no_grad()
    def compare(self, data1: 'Data', data2: 'Data' = None) -> 'CKA':
        """Computes the feature similarity between the models on the given
        datasets.

        Parameters
        ----------
        data1 : Data
            the dataset where model 1 run on.
        data2 : Data, optional
            If given, model 2 will run on this dataset.
            by default None

        Returns
        -------
        CKA
            ``self``; ``self.hsic_matrix`` holds the CKA matrix whose rows
            are the captured model-1 layers and whose columns are the
            captured model-2 layers, in the order the hooks fired.

        Raises
        ------
        RuntimeError
            if two compared feature matrices have a different number of
            rows.
        ValueError
            if fewer than 4 rows (samples) are available.
        """
        data1 = data1.to(self.device)
        if data2 is None:
            warn("Data for Model 2 is not given. "
                 "Using the same data for both models.")
            data2 = data1
        else:
            data2 = data2.to(self.device)

        # The hooks registered in __init__ refill these during the forwards.
        self.model1_features = {}
        self.model2_features = {}
        self.model1(data1.x, data1.edge_index, data1.edge_weight)
        self.model2(data2.x, data2.edge_index, data2.edge_weight)

        # Size the matrix by the features actually captured: requested
        # layers that are missing or produce non-2-D outputs would otherwise
        # leave all-zero rows and turn into NaNs.
        grams1 = [self._gram(f) for f in self.model1_features.values()]
        grams2 = [self._gram(f) for f in self.model2_features.values()]
        hsic = torch.zeros(len(grams1), len(grams2), 3)

        # HSIC(L, L) depends only on model 2's layer; compute it once.
        self_hsic2 = [self._HSIC(L, L) for L in grams2]
        for i, K in enumerate(grams1):
            hsic[i, :, 0] = self._HSIC(K, K)
            for j, L in enumerate(grams2):
                if K.shape != L.shape:
                    raise RuntimeError(
                        f"Feature shape mistach! {K.shape} and {L.shape}")
                hsic[i, j, 1] = self._HSIC(K, L)
                hsic[i, j, 2] = self_hsic2[j]

        self.hsic_matrix = hsic[:, :, 1] / (hsic[:, :, 0].sqrt() *
                                            hsic[:, :, 2].sqrt())

        assert not torch.isnan(
            self.hsic_matrix).any(), "HSIC computation resulted in NANs"
        return self

    def export(self) -> Dict:
        """Exports the CKA data along with the respective model layer names.

        Note that the layer lists hold every hooked module name. Modules
        whose output was not a 2-D tensor have no row/column in ``CKA``, so
        the lists need not align one-to-one with the matrix.

        Returns
        -------
        Dict
            keys ``model1_name``, ``model2_name``, ``CKA`` (the matrix
            from :meth:`compare`), ``model1_layers`` and ``model2_layers``.
        """
        return {
            "model1_name": self.model1_info['Name'],
            "model2_name": self.model2_info['Name'],
            "CKA": self.hsic_matrix,
            "model1_layers": self.model1_info['Layers'],
            "model2_layers": self.model2_info['Layers'],
        }

    def plot_results(self, save_path: str = None, title: str = None):
        """Plot the CKA matrix from :meth:`compare` as a heatmap.

        Parameters
        ----------
        save_path : str, optional
            if given, the figure is saved there at 300 dpi, by default None
        title : str, optional
            plot title, by default "<model1 name> vs <model2 name>"
        """
        import matplotlib.pyplot as plt
        _, ax = plt.subplots()
        im = ax.imshow(self.hsic_matrix, origin='lower', cmap='magma')
        ax.set_xlabel(f"Layers of {self.model2_info['Name']}", fontsize=15)
        ax.set_ylabel(f"Layers of {self.model1_info['Name']}", fontsize=15)

        if title is not None:
            ax.set_title(f"{title}", fontsize=18)
        else:
            ax.set_title(
                f"{self.model1_info['Name']} vs {self.model2_info['Name']}",
                fontsize=18)

        add_colorbar(im)
        plt.tight_layout()

        if save_path is not None:
            plt.savefig(save_path, dpi=300)
def add_colorbar(im, aspect=10, pad_fraction=0.5, **kwargs):
    """Add a vertical color bar to an image plot.

    The colorbar axes is appended to the right of ``im.axes``. Its width is
    the image axes' height divided by ``aspect``, and it is separated from
    the image by ``pad_fraction`` times that width. The previously current
    axes is restored afterwards.

    Parameters
    ----------
    im : image artist
        the mappable to attach the colorbar to (e.g. the object returned by
        ``ax.imshow``); it must expose ``.axes``
    aspect : float, optional
        ratio between the image axes' height and the colorbar width,
        by default 10
    pad_fraction : float, optional
        gap between image and colorbar as a fraction of the colorbar width,
        by default 0.5
    **kwargs
        forwarded to ``Figure.colorbar``

    Returns
    -------
    the colorbar created by ``im.axes.figure.colorbar``
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits import axes_grid1

    divider = axes_grid1.make_axes_locatable(im.axes)
    width = axes_grid1.axes_size.AxesY(im.axes, aspect=1. / aspect)
    pad = axes_grid1.axes_size.Fraction(pad_fraction, width)
    current_ax = plt.gca()
    cax = divider.append_axes("right", size=width, pad=pad)
    # append_axes adds the new axes to the figure, which makes it the current
    # axes; restore the caller's axes so later pyplot calls don't land on the
    # colorbar (this is why current_ax was captured above).
    plt.sca(current_ax)
    return im.axes.figure.colorbar(im, cax=cax, **kwargs)