Evolution of Grad-CAM heat-maps along a ResNet-34
This exercise is a continuation of my last post, which was an exploration in generating class discriminative localization maps for a convnet. In particular, I used feature map activations of the last convolutional layer (after BatchNorm), along with gradients of a specific class score wrt these activations, to create heat-maps that help visualize the parts of the input image that contribute most to coming up with a prediction.
I wanted to extend that approach to see how these heat-maps shape up as we move deeper into the network, starting with the very first convolutional layer. Similar to the last post, inspiration for this comes from a fastai Deep Learning MOOC lecture which is itself inspired by Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization by Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, Dhruv Batra.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML
from fastai import *
from fastai.vision import *
from fastai import version as fastai_version
print(f'fastai version -> {fastai_version.__version__}')
bs = 32 #batch size
# Path to the 'images' sub-folder of whatever untar_data returns for URLs.PETS.
path = untar_data(URLs.PETS)/'images'
# Augmentation settings; p_affine / p_lighting are presumably the probabilities
# of applying each transform group — confirm against the fastai docs.
tfms = get_transforms(max_rotate=20, max_zoom=1.3, max_lighting=0.4, max_warp=0.4,
p_affine=.2, p_lighting=.2)
# Item list built from the image folder, randomly split with pct 0.2 and a fixed
# seed (2) so the split is repeatable; get_data() labels and transforms it.
src = ImageItemList.from_folder(path).random_split_by_pct(0.2, seed=2)
def get_data(size, bs, padding_mode='reflection'):
    """Build a normalized data bunch from the module-level `src` item list.

    Labels are parsed from each file name with a regex that captures everything
    after the last '/' and before the trailing ``_<digits>.jpg``.

    Args:
        size: Image size forwarded to the `transform` step.
        bs: Batch size forwarded to `databunch`.
        padding_mode: Padding mode forwarded to the `transform` step.

    Returns:
        The result of `.databunch(bs=bs).normalize(imagenet_stats)`.
    """
    # The dot before "jpg" is escaped so that it matches a literal '.' only;
    # unescaped, it would accept any character in that position.
    labelled = src.label_from_re(r'([^/]+)_\d+\.jpg$')
    transformed = labelled.transform(tfms, size=size, padding_mode=padding_mode)
    return transformed.databunch(bs=bs).normalize(imagenet_stats)
# Preview a batch at size 224 using the batch size defined above.
data = get_data(224, bs)
data.show_batch(rows=3, figsize=(6,6))
# Rebuild the data at size 352 with batch size 16; this is the `data` the learner uses.
data = get_data(352,16)
# ResNet-34 learner with weights loaded from the saved checkpoint named '352'.
learn = create_cnn(data, models.resnet34, metrics=error_rate, bn_final=True).load('352')
Using pre-trained weights from the last post to keep things simple. Let's check out the network architecture.
# Display the learner's model so its structure can be inspected.
learn.model
I'll be indexing into `learn.model` to get the building blocks of the model.
# First element of the model's first top-level group.
learn.model[0][0]
# First element of learn.model[0][4] (the group this notebook later calls "layer 1").
learn.model[0][4][0]
Forward pass heat-maps along the network
Let's work with this image of a miniature_pinscher.
# Take item 4 of the validation set and unpack it into the image x and its label y.
idx=4
x,y = data.valid_ds[idx]
x.show()
# y is printed as the class name and y.data as the class index.
print(f'class name: {y}\nclass_index: {y.data}')
from fastai.callbacks.hooks import *
# Put the model in evaluation mode and keep a short alias for it.
m = learn.model.eval();
xb,_ = data.one_item(x) #get tensor from Image x
# De-normalized first item of the batch, wrapped as an Image for display.
xb_im = Image(data.denorm(xb)[0])
# Move the batch to the GPU before feeding it to the model.
xb = xb.cuda()
`non_class_discriminative_activations_multi` returns a list of hooks holding the forward pass activations at different points in the network, along with `layer_names`.
def non_class_discriminative_activations_multi(xb):
    """Run one forward pass of `m` on `xb` and capture intermediate outputs.

    An output hook is registered on `m[0][3]` and on every child of
    `m[0][4]` .. `m[0][7]`. All hooks are removed before returning, even if
    the forward pass raises, so none are left attached to the model. The
    returned hook objects still expose what they captured via `.stored`.

    Args:
        xb: Input batch fed to the module-level model `m`.

    Returns:
        A `(hooks, layer_names)` tuple of two equally long lists: the hook
        objects and a display name for each of them.
    """
    hooks = []
    layer_names = []
    try:
        # NOTE(review): labelled 'first conv', but index 3 of a torchvision
        # ResNet stem is likely the max-pool output — confirm.
        hooks.append(hook_output(m[0][3]))
        layer_names.append('first conv')
        # m[0][4] is "layer 1" (a group of ResNet blocks), so ind - 3 is the
        # 1-based layer number used in the display name.
        for ind in [4, 5, 6, 7]:
            for i, el in enumerate(m[0][ind]):
                hooks.append(hook_output(el))
                layer_names.append(f'layer-{ind-3} - conv-{i+1}')
        # Only the hooks' side effects matter; the prediction is discarded.
        m(xb)
    finally:
        for hook in hooks:
            hook.remove()
    return hooks, layer_names
# Capture the activations for the single-image batch xb prepared earlier.
hooks,layer_names = non_class_discriminative_activations_multi(xb)
Printing out the shapes of these activation tensors. I'm using the terminology of calling a group of ResNet blocks a "layer", similar to PyTorch's implementation. There are 4 such "layers" in ResNet-34, having `[3,4,6,3]` ResNet blocks respectively.
# Print each layer name, left-aligned in an 18-character column, next to the
# shape of the activation its hook stored.
for layer_name, hook in zip(layer_names, hooks):
    print(f'{layer_name:<18} --> {hook.stored[0].shape}')
# `hook_1` was never defined in this notebook, so the first captured hook is
# used instead. NOTE(review): confirm which layer was meant to be shown here.
acts = hooks[0].stored[0].cpu()
acts.shape
Averaging the values of these activations over the channel axis to get a 2 dimensional tensor.
# Mean over dim 0 of acts (the channel axis, per the accompanying text).
avg_acts = acts.mean(0)
avg_acts.shape
Plotting these averaged activations.
# Render the averaged activation map with the 'magma' colormap.
plt.imshow(avg_acts, cmap='magma');
from math import ceil
Plotting all of the stored activations.
def plot_forward_activations_multi(hooks):
    """Plot the averaged stored activation of every hook in a 4-column grid.

    For each hook, `hook.stored[0]` is moved to the CPU, averaged over its
    first dimension and drawn with the 'magma' colormap. Axes left over in the
    last row stay empty.

    Args:
        hooks: Sequence of hook objects exposing a `stored` attribute.
    """
    if len(hooks) == 0:
        # Nothing to draw; plt.subplots rejects a grid with zero rows.
        return
    num_cols = 4
    num_rows = ceil(len(hooks) / num_cols)
    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # row, which the default squeezing would collapse to a 1-D array.
    fig, axes = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols * 3, num_rows * 3)
    # axes.flat walks the grid row by row; zip stops once the hooks run out.
    for axis, hook in zip(axes.flat, hooks):
        acts = hook.stored[0].cpu()
        axis.imshow(acts.mean(0), cmap='magma')
    plt.show()
# Draw the activation grid for the hooks collected earlier.
plot_forward_activations_multi(hooks)
Plotting heat-maps based on these activations by interpolating (upsampling) them to the size of the input image.
def plot_non_class_discriminative_heatmaps_multi(x):
    """Overlay averaged forward-pass activations on the input image.

    The image is turned into a batch, run through
    `non_class_discriminative_activations_multi`, and each stored activation
    (averaged over its first dimension) is drawn semi-transparently on top of
    the de-normalized image in a 4-column grid. Bilinear interpolation
    stretches every map over the full image extent.

    Args:
        x: Image item accepted by `data.one_item`.
    """
    xb, _ = data.one_item(x)
    xb_im = Image(data.denorm(xb)[0])
    xb = xb.cuda()
    hooks, _ = non_class_discriminative_activations_multi(xb)
    if len(hooks) == 0:
        # Nothing to draw; plt.subplots rejects a grid with zero rows.
        return

    # Take the overlay extent from the batch itself instead of hard-coding 352.
    # Assumes the last two dims of xb are (height, width) — TODO confirm.
    height, width = xb.shape[-2:]

    num_cols = 4
    num_rows = ceil(len(hooks) / num_cols)
    # squeeze=False keeps `axes` two-dimensional even when there is a single row.
    fig, axes = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols * 3, num_rows * 3)
    # axes.flat walks the grid row by row; zip stops once the hooks run out.
    for axis, hook in zip(axes.flat, hooks):
        avg_acts = hook.stored[0].cpu().mean(0)
        xb_im.show(axis)
        axis.imshow(avg_acts, alpha=0.6, extent=(0, width, height, 0),
                    interpolation='bilinear', cmap='magma')
    plt.show()
# Draw the heat-map overlays for the sample image x selected earlier.
plot_non_class_discriminative_heatmaps_multi(x)