
Generating class discriminative heat-maps using Grad-CAM

I'm quite interested in understanding and interpreting how convnets “see” and process the input images we feed them. I first got a taste of this kind of work from Visualizing and Understanding Convolutional Networks by Matthew D. Zeiler and Rob Fergus, which is now five years old. A lot of work has been, and is being, done by the deep learning research community to make convnets more intuitive and understandable, and I'm trying to take strides towards understanding it.

This post/notebook is an exercise in generating localization heat-maps to help visualise the areas of an image that contribute the most to a prediction. Inspiration for this comes from a fastai Deep Learning MOOC (2018) lecture, which is itself inspired by Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization by Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, and Dhruv Batra.

Setup

I'm using The Oxford-IIIT Pet Dataset for this experiment, which is hosted on AWS for fastai.

In [0]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai import *
from fastai.vision import *
In [0]:
bs = 32  #batch size
In [0]:
path = untar_data(URLs.PETS)/'images'

Data Augmentation

Setting up a data generator which also augments the data based on various transforms. More on fastai's transforms here.

In [0]:
tfms = get_transforms(max_rotate=20, max_zoom=1.3, max_lighting=0.4, max_warp=0.4,
                      p_affine=.2, p_lighting=.2)
src = ImageItemList.from_folder(path).random_split_by_pct(0.2, seed=2)
In [0]:
def get_data(size, bs, padding_mode='reflection'):
    return (src.label_from_re(r'([^/]+)_\d+\.jpg$')
           .transform(tfms, size=size, padding_mode=padding_mode)
           .databunch(bs=bs).normalize(imagenet_stats))
In [16]:
data = get_data(224, bs)
data.show_batch(rows=3, figsize=(6,6))

Training

I'll be using the ResNet-34 architecture for this network. Fastai's create_cnn handles replacing the final fully connected layer behind the scenes so that the model produces a 37-dimensional output. It does this based on the number of classes in the dataset, which is stored as data.c; the classes themselves are stored in data.classes.
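
For context, here is roughly what that head replacement amounts to in plain PyTorch; a minimal sketch using torchvision's resnet34 (fastai actually builds a richer custom head, visible in the model summary further down):

import torch.nn as nn
import torchvision.models as tvm

body = tvm.resnet34(pretrained=True)           # ImageNet backbone with a 1000-way classifier
body.fc = nn.Linear(body.fc.in_features, 37)   # swap in a 37-way layer, i.e. data.c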

In [26]:
print(data.c)
37
In [24]:
print(data.classes)
['Abyssinian', 'Bengal', 'Birman', 'Bombay', 'British_Shorthair', 'Egyptian_Mau', 'Maine_Coon', 'Persian', 'Ragdoll', 'Russian_Blue', 'Siamese', 'Sphynx', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'boxer', 'chihuahua', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'miniature_pinscher', 'newfoundland', 'pomeranian', 'pug', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier']
In [0]:
gc.collect()
learn = create_cnn(data, models.resnet34, metrics=error_rate, bn_final=True)
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.torch/models/resnet34-333f7ec4.pth
100%|██████████| 87306240/87306240 [00:01<00:00, 83351997.05it/s]

Taking a look at the network's architecture.

In [0]:
learn.summary()
Input Size override by Learner.data.train_dl
Input Size passed in: 16 

================================================================================
Layer (type)               Output Shape         Param #   
================================================================================
Conv2d                    [16, 64, 176, 176]   9408                
________________________________________________________________________________
BatchNorm2d               [16, 64, 176, 176]   128                 
________________________________________________________________________________
ReLU                      [16, 64, 176, 176]   0                   
________________________________________________________________________________
MaxPool2d                 [16, 64, 88, 88]     0                   
________________________________________________________________________________
Conv2d                    [16, 64, 88, 88]     36864               
________________________________________________________________________________
BatchNorm2d               [16, 64, 88, 88]     128                 
________________________________________________________________________________
ReLU                      [16, 64, 88, 88]     0                   
________________________________________________________________________________
Conv2d                    [16, 64, 88, 88]     36864               
________________________________________________________________________________
BatchNorm2d               [16, 64, 88, 88]     128                 
________________________________________________________________________________
Conv2d                    [16, 64, 88, 88]     36864               
________________________________________________________________________________
BatchNorm2d               [16, 64, 88, 88]     128                 
________________________________________________________________________________
ReLU                      [16, 64, 88, 88]     0                   
________________________________________________________________________________
Conv2d                    [16, 64, 88, 88]     36864               
________________________________________________________________________________
BatchNorm2d               [16, 64, 88, 88]     128                 
________________________________________________________________________________
Conv2d                    [16, 64, 88, 88]     36864               
________________________________________________________________________________
BatchNorm2d               [16, 64, 88, 88]     128                 
________________________________________________________________________________
ReLU                      [16, 64, 88, 88]     0                   
________________________________________________________________________________
Conv2d                    [16, 64, 88, 88]     36864               
________________________________________________________________________________
BatchNorm2d               [16, 64, 88, 88]     128                 
________________________________________________________________________________
Conv2d                    [16, 128, 44, 44]    73728               
________________________________________________________________________________
BatchNorm2d               [16, 128, 44, 44]    256                 
________________________________________________________________________________
ReLU                      [16, 128, 44, 44]    0                   
________________________________________________________________________________
Conv2d                    [16, 128, 44, 44]    147456              
________________________________________________________________________________
BatchNorm2d               [16, 128, 44, 44]    256                 
________________________________________________________________________________
Conv2d                    [16, 128, 44, 44]    8192                
________________________________________________________________________________
BatchNorm2d               [16, 128, 44, 44]    256                 
________________________________________________________________________________
Conv2d                    [16, 128, 44, 44]    147456              
________________________________________________________________________________
BatchNorm2d               [16, 128, 44, 44]    256                 
________________________________________________________________________________
ReLU                      [16, 128, 44, 44]    0                   
________________________________________________________________________________
Conv2d                    [16, 128, 44, 44]    147456              
________________________________________________________________________________
BatchNorm2d               [16, 128, 44, 44]    256                 
________________________________________________________________________________
Conv2d                    [16, 128, 44, 44]    147456              
________________________________________________________________________________
BatchNorm2d               [16, 128, 44, 44]    256                 
________________________________________________________________________________
ReLU                      [16, 128, 44, 44]    0                   
________________________________________________________________________________
Conv2d                    [16, 128, 44, 44]    147456              
________________________________________________________________________________
BatchNorm2d               [16, 128, 44, 44]    256                 
________________________________________________________________________________
Conv2d                    [16, 128, 44, 44]    147456              
________________________________________________________________________________
BatchNorm2d               [16, 128, 44, 44]    256                 
________________________________________________________________________________
ReLU                      [16, 128, 44, 44]    0                   
________________________________________________________________________________
Conv2d                    [16, 128, 44, 44]    147456              
________________________________________________________________________________
BatchNorm2d               [16, 128, 44, 44]    256                 
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    294912              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
ReLU                      [16, 256, 22, 22]    0                   
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    32768               
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
ReLU                      [16, 256, 22, 22]    0                   
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
ReLU                      [16, 256, 22, 22]    0                   
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
ReLU                      [16, 256, 22, 22]    0                   
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
ReLU                      [16, 256, 22, 22]    0                   
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
ReLU                      [16, 256, 22, 22]    0                   
________________________________________________________________________________
Conv2d                    [16, 256, 22, 22]    589824              
________________________________________________________________________________
BatchNorm2d               [16, 256, 22, 22]    512                 
________________________________________________________________________________
Conv2d                    [16, 512, 11, 11]    1179648             
________________________________________________________________________________
BatchNorm2d               [16, 512, 11, 11]    1024                
________________________________________________________________________________
ReLU                      [16, 512, 11, 11]    0                   
________________________________________________________________________________
Conv2d                    [16, 512, 11, 11]    2359296             
________________________________________________________________________________
BatchNorm2d               [16, 512, 11, 11]    1024                
________________________________________________________________________________
Conv2d                    [16, 512, 11, 11]    131072              
________________________________________________________________________________
BatchNorm2d               [16, 512, 11, 11]    1024                
________________________________________________________________________________
Conv2d                    [16, 512, 11, 11]    2359296             
________________________________________________________________________________
BatchNorm2d               [16, 512, 11, 11]    1024                
________________________________________________________________________________
ReLU                      [16, 512, 11, 11]    0                   
________________________________________________________________________________
Conv2d                    [16, 512, 11, 11]    2359296             
________________________________________________________________________________
BatchNorm2d               [16, 512, 11, 11]    1024                
________________________________________________________________________________
Conv2d                    [16, 512, 11, 11]    2359296             
________________________________________________________________________________
BatchNorm2d               [16, 512, 11, 11]    1024                
________________________________________________________________________________
ReLU                      [16, 512, 11, 11]    0                   
________________________________________________________________________________
Conv2d                    [16, 512, 11, 11]    2359296             
________________________________________________________________________________
BatchNorm2d               [16, 512, 11, 11]    1024                
________________________________________________________________________________
AdaptiveAvgPool2d         [16, 512, 1, 1]      0                   
________________________________________________________________________________
AdaptiveMaxPool2d         [16, 512, 1, 1]      0                   
________________________________________________________________________________
Lambda                    [16, 1024]           0                   
________________________________________________________________________________
BatchNorm1d               [16, 1024]           2048                
________________________________________________________________________________
Dropout                   [16, 1024]           0                   
________________________________________________________________________________
Linear                    [16, 512]            524800              
________________________________________________________________________________
ReLU                      [16, 512]            0                   
________________________________________________________________________________
BatchNorm1d               [16, 512]            1024                
________________________________________________________________________________
Dropout                   [16, 512]            0                   
________________________________________________________________________________
Linear                    [16, 37]             18981               
________________________________________________________________________________
BatchNorm1d               [16, 37]             74                  
________________________________________________________________________________
Total params:  21831599
In [0]:
learn.fit_one_cycle(3, slice(1e-2), pct_start=0.8)
Total time: 04:55

epoch train_loss valid_loss error_rate
1 2.148532 1.041757 0.173207
2 1.111558 0.338378 0.102165
3 0.740522 0.288959 0.077131
In [0]:
learn.unfreeze()
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-3), pct_start=0.8)
Total time: 03:30

epoch train_loss valid_loss error_rate
1 0.705978 0.266104 0.067659
2 0.620636 0.266208 0.071719
In [0]:
data = get_data(352,bs)
learn.data = data
In [0]:
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))
Total time: 06:05

epoch train_loss valid_loss error_rate
1 0.576229 0.227271 0.052774
2 0.511623 0.220940 0.050744
In [0]:
learn.save('352')

The model predicts the 37 classes of cats and dogs with a 5.07% error rate on the validation set. That's good enough for this experiment.

Non class-discriminative heat-maps

One of the reasons convolutional layers are used for deep learning on image data is that they naturally retain the spatial information present in their inputs. That information is progressively transformed to represent higher-level semantics as we move deeper into the network, and is finally handed over to the fully connected layers, which produce the relevant outputs based on their weights and biases.
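
A quick shape check makes the contrast concrete. This is an illustrative sketch (not part of the notebook) with arbitrary layer sizes:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 352, 352)                    # one RGB image, as used later in this notebook
conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
print(conv(x).shape)                               # torch.Size([1, 64, 352, 352]): HxW preserved

fc = nn.Linear(3 * 352 * 352, 37)
print(fc(x.flatten(1)).shape)                      # torch.Size([1, 37]): spatial axes collapsed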

Selvaraju et al. state in their paper:

we can expect the last convolutional layers to have the best compromise between high-level semantics and detailed spatial information. The neurons in these layers look for semantic class-specific information in the image (say object parts).

Activations of these feature maps can be used directly to visualise which parts of an image the network “focuses” on the most. This might not be intuitively apparent at first (it certainly wasn't for me), but it works. Let's see it in action.

In [0]:
data = get_data(352,16)
In [0]:
learn = create_cnn(data, models.resnet34, metrics=error_rate, bn_final=True).load('352')

Let's work with this image of a miniature_pinscher.

In [138]:
idx=4
x,y = data.valid_ds[idx]
x.show()
print(f'class name: {y}\nclass_index: {y.data}')
class name: miniature_pinscher
class_index: 26

Fastai's learn.model for a CNN is a torch.nn.modules.container.Sequential of length 2. The first element is another Sequential containing all of the convolutional layers, while the second contains the fully connected layers.

In [41]:
type(learn.model)
Out[41]:
torch.nn.modules.container.Sequential
In [42]:
len(learn.model)
Out[42]:
2

Let's see the model itself.

In [0]:
learn.model
Out[0]:
Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (6): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (4): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (5): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (7): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (1): Sequential(
    (0): AdaptiveConcatPool2d(
      (ap): AdaptiveAvgPool2d(output_size=1)
      (mp): AdaptiveMaxPool2d(output_size=1)
    )
    (1): Lambda()
    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.25)
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): ReLU(inplace)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.5)
    (8): Linear(in_features=512, out_features=37, bias=True)
    (9): BatchNorm1d(37, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
)

In order to get the activations of the feature maps of the last convolutional layer, we need to place a Hook on the output of this layer.
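
Fastai's hook_output (used below) essentially wraps PyTorch's register_forward_hook. A minimal sketch of the underlying mechanism, assuming m and xb are defined as in the cells that follow:

stored = {}

def save_activations(module, inp, out):
    stored['acts'] = out.detach()    # feature maps produced by the hooked module

handle = m[0].register_forward_hook(save_activations)
preds = m(xb)                        # the forward pass fills stored['acts']
handle.remove()                      # always remove hooks when done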

In [0]:
from fastai.callbacks.hooks import *
In [0]:
m = learn.model.eval();
In [0]:
xb,_ = data.one_item(x)    #get tensor from Image x
xb_im = Image(data.denorm(xb)[0])
xb = xb.cuda()
In [96]:
xb_im
Out[96]:
In [0]:
def non_class_discriminative_activations(xb):
    with hook_output(m[0]) as hook_a: 
        preds = m(xb)
    return hook_a
In [0]:
hook_a = non_class_discriminative_activations(xb)
In [84]:
acts  = hook_a.stored[0].cpu()
acts.shape
Out[84]:
torch.Size([512, 11, 11])

As expected, the activations of the final convolutional layer have shape (512, 11, 11), where 512 is the number of channels and 11 is both the height and width of each feature map. (The 11 comes from the 352×352 input divided by ResNet-34's overall stride of 32.)

Now let's do something that felt totally unintuitive to me the first time I did it: average these activations over the channel axis to get an (11, 11) tensor.

In [55]:
avg_acts = acts.mean(0)
avg_acts.shape
Out[55]:
torch.Size([11, 11])

We now have an (11, 11) tensor representing the spatial information captured by the convnet up to the last convolutional layer, averaged over the channel axis. Let's plot it in two dimensions.

In [59]:
plt.imshow(avg_acts, cmap='magma');

This definitely looks concentrated at the face of the miniature_pinscher image above. Let's plot the input image with the heat-map over it. The (11, 11) tensor can be stretched to the size of the input, i.e. (352, 352), using the extent argument of imshow.

In [0]:
def show_non_class_discriminative_heatmap(x):
    xb,_ = data.one_item(x)                    # one-item batch (tensor) from Image x
    xb_im = Image(data.denorm(xb)[0])          # denormalized image for display
    xb = xb.cuda()
    hook_a = non_class_discriminative_activations(xb)
    acts  = hook_a.stored[0].cpu()             # (512, 11, 11) feature maps
    avg_acts = acts.mean(0)                    # (11, 11) channel-wise mean

    _,ax = plt.subplots()
    xb_im.show(ax)
    ax.imshow(avg_acts, alpha=0.6, extent=(0,352,352,0),
              interpolation='bilinear', cmap='magma');
In [115]:
show_non_class_discriminative_heatmap(x)

It actually works!! Let's do this for a bunch of images.

In [0]:
import random
In [0]:
def plot_non_class_disc_multi():
    random.seed(25)

    val_size = len(data.valid_ds)

    fig,ax = plt.subplots(2,4)
    fig.set_size_inches(12,6)

    for i in range(2):
        for j in range(0,4,2):
            idx = random.randint(0, val_size - 1)   # randint is inclusive at both ends
            x,y = data.valid_ds[idx]
            xb,_ = data.one_item(x)
            xb_im = Image(data.denorm(xb)[0])
            xb = xb.cuda()
            hook_a = non_class_discriminative_activations(xb)
            acts  = hook_a.stored[0].cpu()
            avg_acts = acts.mean(0)
            xb_im.show(ax[i,j])                     # original image
            xb_im.show(ax[i,j+1])                   # same image, with the heat-map overlaid
            ax[i,j+1].imshow(avg_acts, alpha=0.6, extent=(0,352,352,0),
                      interpolation='bilinear', cmap='magma')
    plt.show()
In [248]:
plot_non_class_disc_multi()