# Explainable AI for Perses

In this notebook we'll use [Captum](https://captum.ai/), a model interpretability library for PyTorch from the Facebook AI team, to inspect our Perses model.

Before we get going, make sure you have Captum installed. See the quickstart for more information: https://captum.ai/#quickstart

## Basic setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
%matplotlib inline

In [None]:
import torch
import torch.nn as nn

import torchvision

In [None]:
from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import Occlusion
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz

In [None]:
from dnet_dataset.src.dnet_dataloader import DamageNetDataset
from dnet import Net as Perses

Matplotlib color scheme:

In [None]:
default_cmap = LinearSegmentedColormap.from_list('custom blue', 
                                                 [(0, '#ffffff'),
                                                  (0.5, '#000000'),
                                                  (1, '#000000')], N=256)

## Dataset

In [None]:
IMAGES_DIR = 'data/images/'
LABELS_DIR = 'data/labels/'
PERSES_MODEL = 'model/perses.pt'

In [None]:
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((75, 75)),
    torchvision.transforms.ToTensor()
])

In [None]:
dataset = DamageNetDataset(images_dir=IMAGES_DIR, labels_dir=LABELS_DIR, transform=transforms)

In [None]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

## Perses

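Load the trained weights and switch Perses to evaluation mode.

In [None]:
perses = Perses()
# map_location='cpu' lets the checkpoint load on machines without a GPU
perses.load_state_dict(torch.load(PERSES_MODEL, map_location='cpu'))
# eval() disables training-only behaviour such as dropout
perses.eval()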

## Explainable AI

In this section we'll define our explanation functions. These help us interpret the model by visualizing the image features that influenced its prediction most.

The following functions are adapted from the [Captum introductory tutorial](https://captum.ai/tutorials/Resnet_TorchVision_Interpret).

### Gradient-based approach

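Integrated Gradients attributes the prediction to individual input pixels by accumulating the model's gradients along a straight path from a baseline image to the actual input. `NoiseTunnel` then smooths the result by averaging squared attributions over several noisy copies of the input.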
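
To make the gradient-based idea concrete, here is a minimal pure-Python sketch of integrated gradients on a toy function. It is an illustration only and not Captum's implementation: the names `integrated_gradients`, `toy_model` and `toy_model_grad` are hypothetical, the gradient is written by hand instead of coming from autograd, the baseline is assumed to be all zeros, and the path integral is approximated with a simple midpoint Riemann sum.

```python
def integrated_gradients(grad_fn, x, baseline, n_steps=200):
    """Approximate integrated gradients with a midpoint Riemann sum."""
    avg_grad = [0.0] * len(x)
    for step in range(n_steps):
        # Point on the straight line from the baseline to the input
        alpha = (step + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            avg_grad[i] += g / n_steps
    # Scale the averaged gradient by the distance travelled per feature
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


# Hypothetical toy "model": f(x) = x0^2 + 3 * x1
def toy_model(x):
    return x[0] ** 2 + 3.0 * x[1]


# Hand-written gradient of the toy model (autograd would supply this in PyTorch)
def toy_model_grad(x):
    return [2.0 * x[0], 3.0]


x = [2.0, 1.0]
baseline = [0.0, 0.0]
attr = integrated_gradients(toy_model_grad, x, baseline)

# Completeness check: attributions should sum to f(x) - f(baseline)
completeness_gap = abs(sum(attr) - (toy_model(x) - toy_model(baseline)))
print(attr, completeness_gap)
```

The attributions sum (up to floating-point error) to the difference between the model output at the input and at the baseline, which is a useful sanity check when reading the heat maps produced further down.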

In [None]:
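def gradient_explain(img, largest_idx):
    """Plot a smoothed integrated-gradients heat map for class `largest_idx`."""
    integrated_gradients = IntegratedGradients(perses)
    # Plain integrated gradients (computed for reference; only the smoothed version is plotted)
    attributions_ig = integrated_gradients.attribute(img, target=largest_idx, n_steps=200)

    # Average squared attributions over noisy copies of the input
    noise_tunnel = NoiseTunnel(integrated_gradients)
    attributions_ig_nt = noise_tunnel.attribute(img, n_samples=10, nt_type='smoothgrad_sq', target=largest_idx)

    # The visualizer expects (H, W, C) arrays, so move the channel axis last
    _ = viz.visualize_image_attr_multiple(np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                          np.transpose(img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                          ["original_image", "heat_map"],
                                          ["all", "positive"],
                                          cmap=default_cmap,
                                          show_colorbar=True)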

### Occlusion-based approach

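Cover parts of the image with a sliding window and measure how much the model's output for the target class changes. Regions whose occlusion changes the output most are the ones the model relies on.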
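
The sliding-window idea can be sketched in a few lines of plain Python. This is a simplified illustration and not Captum's `Occlusion` implementation: `occlusion_attribution` and `score_fn` are hypothetical names, the image is a single-channel list of lists, and windows that would extend past the image edge are simply skipped. Each pixel's attribution is the score drop averaged over all windows that covered it.

```python
def occlusion_attribution(score_fn, image, window, stride, baseline=0.0):
    """Slide a square window over a 2D image and average the score drop per pixel."""
    height, width = len(image), len(image[0])
    base_score = score_fn(image)
    totals = [[0.0] * width for _ in range(height)]
    counts = [[0] * width for _ in range(height)]

    for top in range(0, height - window + 1, stride):
        for left in range(0, width - window + 1, stride):
            # Copy the image and overwrite the window with the baseline value
            occluded = [row[:] for row in image]
            for r in range(top, top + window):
                for c in range(left, left + window):
                    occluded[r][c] = baseline
            drop = base_score - score_fn(occluded)
            # Credit the drop to every pixel the window covered
            for r in range(top, top + window):
                for c in range(left, left + window):
                    totals[r][c] += drop
                    counts[r][c] += 1

    return [[t / n if n else 0.0 for t, n in zip(t_row, n_row)]
            for t_row, n_row in zip(totals, counts)]


# Hypothetical toy "model" that only looks at pixel (1, 1)
def score_fn(image):
    return image[1][1]


image = [[1.0] * 4 for _ in range(4)]
image[1][1] = 5.0
heatmap = occlusion_attribution(score_fn, image, window=2, stride=1)
for row in heatmap:
    print(row)
```

Pixels near the one the toy model depends on receive high attribution, while pixels that are never occluded together with it receive zero. A smaller stride gives a finer heat map at the cost of more forward passes.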

In [None]:
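def occlusion_explain(img, largest_idx):
    """Plot an occlusion heat map for class `largest_idx`."""
    occlusion = Occlusion(perses)

    # Slide a 15x15 window covering all 3 channels in steps of 8 pixels,
    # replacing the covered region with the baseline value 0
    attributions_occ = occlusion.attribute(img,
                                           strides=(3, 8, 8),
                                           target=largest_idx,
                                           sliding_window_shapes=(3, 15, 15),
                                           baselines=0)

    # The visualizer expects (H, W, C) arrays, so move the channel axis last
    _ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                          np.transpose(img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                          ["original_image", "heat_map"],
                                          ["all", "positive"],
                                          show_colorbar=True,
                                          outlier_perc=2)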

## Evaluation

The loop below pauses before each explanation. Press Enter at the prompt to continue.

You may quit by interrupting the kernel under `Kernel` > `Interrupt kernel`.

In [None]:
dataiter = iter(dataloader)

In [None]:
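for img, label in dataiter:
    # Forward pass, then round the sigmoid outputs to per-class 0/1 predictions
    output = perses(img)
    output = torch.sigmoid(output).round()

    # Use the last class predicted as active as the attribution target
    # (falls back to class 0 if no class is active)
    largest_idx = 0
    for i, val in enumerate(output.to('cpu').tolist()[0]):
        if val == 1:
            largest_idx = i

    print('Target: ', label)
    print('Output: ', output)
    print(largest_idx)

    _ = input('Explain?')

    gradient_explain(img, largest_idx)

    _ = input('Next?')

    occlusion_explain(img, largest_idx)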