Generating class discriminative heat-maps using Grad-CAM
I’m quite interested in understanding and interpreting how convnets “see” and process the input images that we feed them. I first got a taste of this kind of work after reading Visualizing and Understanding Convolutional Networks by Matthew D Zeiler and Rob Fergus, which is 5 years old as of today. I imagine a lot of work has been done, and continues to be done, by the deep learning research community to make convnets more intuitive and understandable; I’m trying to take strides towards understanding that work.
This post/notebook is an exercise in generating localization heat maps to help visualise areas of an image which contribute the most when making a prediction. Inspiration for this comes from a fastai Deep Learning MOOC (2018) lecture which is itself inspired by Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization by Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, Dhruv Batra.
Setup
I'm using The Oxford-IIIT Pet Dataset for this experiment, which is hosted for fastai by AWS here.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai import *
from fastai.vision import *
bs = 32 #batch size
path = untar_data(URLs.PETS)/'images'
tfms = get_transforms(max_rotate=20, max_zoom=1.3, max_lighting=0.4, max_warp=0.4,
p_affine=.2, p_lighting=.2)
src = ImageItemList.from_folder(path).random_split_by_pct(0.2, seed=2)
def get_data(size, bs, padding_mode='reflection'):
    return (src.label_from_re(r'([^/]+)_\d+.jpg$')
            .transform(tfms, size=size, padding_mode=padding_mode)
            .databunch(bs=bs).normalize(imagenet_stats))
data = get_data(224, bs)
data.show_batch(rows=3, figsize=(6,6))
Training
I'll be using the ResNet-34 architecture for this network. Fastai's create_cnn handles altering the final fully connected layer to produce a 37-dimensional output behind the scenes, based on the number of classes in the dataset, which is stored as data.c. The classes themselves are stored in data.classes.
print(data.c)
print(data.classes)
gc.collect()
learn = create_cnn(data, models.resnet34, metrics=error_rate, bn_final=True)
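As mentioned above, create_cnn sizes the head based on data.c. If you want to confirm this, you can look at the last Linear layer in the head; this is a quick check of my own (not part of the original notebook), assuming fastai v1's default head layout:
linears = [l for l in learn.model[1].modules() if isinstance(l, nn.Linear)]
print(linears[-1].out_features, data.c)  # both should be 37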
Taking a look at the network's architecture.
learn.summary()
learn.fit_one_cycle(3, slice(1e-2), pct_start=0.8)
learn.unfreeze()
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-3), pct_start=0.8)
data = get_data(352,bs)
learn.data = data
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))
learn.save('352')
The model classifies the 37 breeds of cats and dogs with a 5.0744% error rate on the validation set. That's good enough for this experiment.
Non-class-discriminative heat-maps
One of the reasons why convolutional layers are used for deep learning on image data is that they naturally retain the spatial information present in the input. That information is progressively manipulated to represent higher-level semantics as we move deeper into the network, and is finally handed over to the fully connected layers, which produce the relevant outputs based on their weights and biases.
Selvaraju et al. state in their paper:
we can expect the last convolutional layers to have the best compromise between high-level semantics and detailed spatial information. The neurons in these layers look for semantic class-specific information in the image (say object parts).
Activations of these feature maps can be directly used to visualise which parts of an image the network “focuses” on the most. This might not be instinctively apparent at first (it surely wasn't for me), but it works. Let’s see that in action.
data = get_data(352,16)
learn = create_cnn(data, models.resnet34, metrics=error_rate, bn_final=True).load('352')
Let's work with this image of a miniature_pinscher.
idx=4
x,y = data.valid_ds[idx]
x.show()
print(f'class name: {y}\nclass_index: {y.data}')
Fastai's learn.model for a CNN is of type torch.nn.modules.container.Sequential and has length 2. The first element of this model is another torch.nn.modules.container.Sequential which contains all of the convolutional layers, while the second element contains the fully connected layers.
type(learn.model)
len(learn.model)
Let's see the model itself.
learn.model
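We'll be hooking into the first of these two elements shortly, so it's worth noting that each part can be grabbed directly by indexing into the model (body and head are just names I'm choosing here):
body, head = learn.model[0], learn.model[1]
print(type(body))   # Sequential containing the convolutional layers (this is what we'll hook)
print(type(head))   # Sequential containing the pooling and fully connected head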
In order to get the activations of the feature maps of the last convolutional layer, we need to place a Hook on the output of that layer.
from fastai.callbacks.hooks import *
m = learn.model.eval();
xb,_ = data.one_item(x) #get tensor from Image x
xb_im = Image(data.denorm(xb)[0])
xb = xb.cuda()
xb_im
def non_class_discriminative_activations(xb):
    with hook_output(m[0]) as hook_a:   # hook the output of the convolutional body
        preds = m(xb)                   # forward pass; hook_a stores the activations
    return hook_a
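Under the hood, fastai's hook_output simply registers a PyTorch forward hook on m[0] and stores whatever that module outputs during the forward pass. A rough plain-PyTorch equivalent (my own sketch, with names of my choosing, not fastai's actual implementation) would look like this:
stored = {}

def save_output(module, inputs, output):
    # PyTorch calls this every time `module` finishes a forward pass
    stored['acts'] = output.detach()

handle = m[0].register_forward_hook(save_output)   # attach to the convolutional body
with torch.no_grad():
    m(xb)                                          # the forward pass triggers the hook
handle.remove()                                    # detach so it doesn't fire on later passes
stored['acts'].shape                               # same activations that hook_output captures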
hook_a = non_class_discriminative_activations(xb)
acts = hook_a.stored[0].cpu()
acts.shape
As expected, the shape of the activations of the final convolutional layer is (512,11,11), where 512 is the number of channels and 11 is both the height and width of the feature maps.
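A quick sanity check on where the 11 comes from: ResNet-34's convolutional body downsamples its input by a total factor of 32 (a stride-2 stem convolution, a stride-2 max pool and three stride-2 residual stages), and 352 / 32 = 11. This is just arithmetic on values we already have:
print(xb.shape)    # torch.Size([1, 3, 352, 352]), the 352x352 input batch of one
print(352 // 32)   # 11, the side length of the final feature maps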
Now let's do something that felt totally unintuitive to me the first time I did it: average the values of these activations over the channel axis to get an (11,11) tensor.
avg_acts = acts.mean(0)
avg_acts.shape
We now have an (11,11) tensor that represents the spatial information captured by the convnet up to the last convolutional layer, averaged over the channel axis. Let's plot it in 2 dimensions.
plt.imshow(avg_acts, cmap='magma');
This definitely looks to be concentrated at the face of the miniature_pinscher in the image above. Let's plot the input image with the heatmap over it. The (11,11) tensor can be stretched to the size of the input, i.e. (352,352), by using the extent argument of imshow; passing extent=(0,352,352,0) maps it onto the image's coordinate system (origin at the top-left), and bilinear interpolation smooths it out.
def show_non_class_discriminative_heatmap(x):
    xb,_ = data.one_item(x)              # minibatch of one built from the Image
    xb_im = Image(data.denorm(xb)[0])    # denormalised copy for plotting
    xb = xb.cuda()
    hook_a = non_class_discriminative_activations(xb)
    acts = hook_a.stored[0].cpu()
    avg_acts = acts.mean(0)              # average over the 512 channels
    _,ax = plt.subplots()
    xb_im.show(ax)
    ax.imshow(avg_acts, alpha=0.6, extent=(0,352,352,0),
              interpolation='bilinear', cmap='magma');
show_non_class_discriminative_heatmap(x)
It actually works!! Let's do this for a bunch of images.
import random
def plot_non_class_disc_multi():
    random.seed(25)
    val_size = len(data.valid_ds)
    fig,ax = plt.subplots(2,4)
    fig.set_size_inches(12,6)
    for i in range(2):
        for j in range(0,4,2):
            # randint is inclusive at both ends, so cap it at val_size - 1
            idx = random.randint(0, val_size - 1)
            x,y = data.valid_ds[idx]
            xb,_ = data.one_item(x)
            xb_im = Image(data.denorm(xb)[0])
            xb = xb.cuda()
            hook_a = non_class_discriminative_activations(xb)
            acts = hook_a.stored[0].cpu()
            avg_acts = acts.mean(0)
            # original image on the left, image with the heatmap overlaid on the right
            xb_im.show(ax[i,j])
            xb_im.show(ax[i,j+1])
            ax[i,j+1].imshow(avg_acts, alpha=0.6, extent=(0,352,352,0),
                             interpolation='bilinear', cmap='magma');
    plt.show()
plot_non_class_disc_multi()