Captum
Throughout this notebook, we will use the following data:
from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'
fnames = get_image_files(path)

def is_cat(x): return x[0].isupper()

dls = ImageDataLoaders.from_name_func(
    path, fnames, valid_pct=0.2, seed=42,
    label_func=is_cat, item_tfms=Resize(128))
from random import randint

learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
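Before interpreting the model, we can sanity-check it with a quick prediction (a minimal check; `learn.predict` returns the decoded label, its index, and the class probabilities):
img = PILImage.create(fnames[0])
is_cat_pred, pred_idx, probs = learn.predict(img)
print(f"Is this a cat?: {is_cat_pred}. Probability: {probs[pred_idx]:.4f}")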
Captum Interpretation
The Distill article on attribution baselines provides a good overview of which baseline image to choose. We can try them one by one.
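As a rough illustration of what these baseline choices correspond to (my own sketch, not fastai code; the noise scale is an assumption), a baseline is simply an image-shaped tensor that attribution is measured against:
import torch

x = torch.rand(3, 128, 128)                   # stand-in for an input image tensor
zeros_baseline   = torch.zeros_like(x)        # default: an all-black image
uniform_baseline = torch.rand_like(x)         # 'uniform': uniform random noise
gauss_baseline   = torch.randn_like(x) * 0.3  # 'gauss': Gaussian noise (scale assumed)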
CaptumInterpretation
CaptumInterpretation (learn, cmap_name='custom blue', colors=None, N=256, methods=('original_image', 'heat_map'), signs=('all', 'positive'), outlier_perc=1)
Captum Interpretation for Resnet
Interpretation
captum = CaptumInterpretation(learn)
idx = randint(0, len(fnames)-1)  # randint is inclusive on both ends
captum.visualize(fnames[idx])
captum.visualize(fnames[idx], baseline_type='uniform')
captum.visualize(fnames[idx], baseline_type='gauss')
captum.visualize(fnames[idx], metric='NT', baseline_type='uniform')   # noise tunnel
captum.visualize(fnames[idx], metric='Occl', baseline_type='gauss')   # occlusion
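For reference, roughly the same attributions can be computed with Captum's public API directly. This is a sketch under assumptions: `target=0` stands in for the class index of interest, and the hyperparameters (sample count, window and stride shapes) are illustrative, not fastai's defaults:
from captum.attr import IntegratedGradients, NoiseTunnel, Occlusion

net = learn.model.eval()
input_img = dls.one_batch()[0][:1]   # one normalized image from the DataLoaders

# Integrated Gradients smoothed with a noise tunnel (roughly metric='NT')
nt = NoiseTunnel(IntegratedGradients(net))
nt_attr = nt.attribute(input_img, nt_type='smoothgrad', nt_samples=10, target=0)

# Occlusion-based attribution (roughly metric='Occl')
occ = Occlusion(net)
occ_attr = occ.attribute(input_img, sliding_window_shapes=(3, 15, 15),
                         strides=(3, 8, 8), target=0)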
Captum Insights Callback
from captum.insights import Batch  # Batch pairs inputs with labels for Insights

@patch
def _formatted_data_iter(x: CaptumInterpretation, dl, normalize_func):
    # Yield batches of decoded, device-placed images and labels for Captum Insights
    dl_iter = iter(dl)
    while True:
        images, labels = next(dl_iter)
        images = normalize_func.decode(images).to(dl.device)
        yield Batch(inputs=images, labels=labels)
CaptumInterpretation.insights
CaptumInterpretation.insights (x:__main__.CaptumInterpretation, inp_data, debug=True)
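Under the hood, `insights` wires the learner and data into Captum Insights' AttributionVisualizer. Here is a minimal manual sketch of that wiring; the transform and vocab handling are simplified assumptions, not fastai's exact internals:
from captum.insights import AttributionVisualizer
from captum.insights.attr_vis.features import ImageFeature

# Build a labeled DataLoader and locate its Normalize transform, if any
dl = learn.dls.test_dl(fnames, with_labels=True, bs=4)
normalize_func = next((f for f in dl.after_batch if type(f) == Normalize), noop)

visualizer = AttributionVisualizer(
    models=[learn.model],
    score_func=lambda o: torch.nn.functional.softmax(o, 1),
    classes=list(map(str, dl.vocab)),
    features=[ImageFeature("Image",
                           baseline_transforms=[lambda o: o*0],  # all-zeros baseline
                           input_transforms=[normalize_func])],
    dataset=captum._formatted_data_iter(dl, normalize_func))
visualizer.render()
In practice, the method does all of this for you: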
captum = CaptumInterpretation(learn)
captum.insights(fnames)