-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_explainer.py
More file actions
107 lines (95 loc) · 3.84 KB
/
create_explainer.py
File metadata and controls
107 lines (95 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from explainer import backprop as bp
from explainer import deeplift as df
from explainer import gradcam as gc
from explainer import patterns as pt
from explainer import ebp
from explainer import real_time as rt
def get_explainer(model, name, layer):
    """Construct the saliency explainer registered under ``name`` for ``model``.

    Args:
        model: the network to explain. For ``gradcam``,
            ``excitation_backprop`` and ``contrastive_excitation_backprop``
            its class name (``model.__class__.__name__``) selects which
            layer keys are passed to the explainer.
        name: one of the keys of ``methods`` below, or ``'smooth_grad'``.
        layer: target layer key; only used for ``gradcam`` on a ``ResNeXt``
            model and ignored otherwise.

    Returns:
        The constructed explainer object.

    Raises:
        KeyError: if ``name`` is not a registered method.
        ValueError: if ``name`` requires per-architecture layer keys and the
            model's class name has no configuration for that method.
    """
    methods = {
        'vanilla_grad': bp.VanillaGradExplainer,
        'grad_x_input': bp.GradxInputExplainer,
        'saliency': bp.SaliencyExplainer,
        'integrate_grad': bp.IntegrateGradExplainer,
        'deconv': bp.DeconvExplainer,
        'guided_backprop': bp.GuidedBackpropExplainer,
        'deeplift_rescale': df.DeepLIFTRescaleExplainer,
        'gradcam': gc.GradCAMExplainer,
        '3dgradcam': gc.GradCAMExplainer,
        'pattern_net': pt.PatternNetExplainer,
        'pattern_lrp': pt.PatternLRPExplainer,
        'excitation_backprop': ebp.ExcitationBackpropExplainer,
        'contrastive_excitation_backprop': ebp.ContrastiveExcitationBackpropExplainer,
        'real_time_saliency': rt.RealTimeSaliencyExplainer,
    }

    if name == 'smooth_grad':
        # SmoothGrad wraps a vanilla-gradient explainer rather than the model.
        return bp.SmoothGradExplainer(methods['vanilla_grad'](model))

    if 'pattern' in name:
        # The pattern-based explainers load parameters and patterns from disk.
        # The file names suggest a 224px ImageNet VGG-16 -- TODO confirm that
        # callers only use these methods with that model.
        return methods[name](
            model,
            params_file='./weights/imagenet_224_vgg_16.npz',
            pattern_file='./weights/imagenet_224_vgg_16.patterns.A_only.npz',
        )

    if name == 'real_time_saliency':
        # Built from a checkpoint path alone; ``model`` is not passed in.
        return methods[name]('./weights/model-1.ckpt')

    arch = model.__class__.__name__
    # Constructor kwargs for methods whose layer keys depend on the
    # architecture, keyed first by method name, then by model class name.
    per_arch_kwargs = {
        'gradcam': {
            'VGG': dict(target_layer_name_keys=['features', '30']),  # pool5
            'GoogleNet': dict(target_layer_name_keys=['pool5'], use_inp=True),
            'ResNeXt': dict(target_layer_name_keys=[layer], use_inp=True),
        },
        'excitation_backprop': {
            'VGG': dict(output_layer_keys=['features', '23']),  # vgg16 pool4
            'ResNet': dict(output_layer_keys=['layer4', '1', 'conv1']),  # resnet50 res4a
            'GoogleNet': dict(output_layer_keys=['pool2']),
        },
        'contrastive_excitation_backprop': {
            'VGG': dict(  # vgg16
                intermediate_layer_keys=['features', '30'],  # pool5
                output_layer_keys=['features', '23'],  # pool4
                final_linear_keys=['classifier', '6'],  # fc8
            ),
            'ResNet': dict(  # resnet50
                intermediate_layer_keys=['avgpool'],
                output_layer_keys=['layer4', '1', 'conv1'],  # res4a
                final_linear_keys=['fc'],
            ),
            'GoogleNet': dict(
                intermediate_layer_keys=['pool5'],
                output_layer_keys=['pool2'],
                final_linear_keys=['loss3.classifier'],
            ),
        },
    }
    if name in per_arch_kwargs:
        kwargs = per_arch_kwargs[name].get(arch)
        if kwargs is None:
            # Fail loudly instead of returning an unbound local.
            raise ValueError(
                '{!r} has no layer configuration for model class {!r}; '
                'supported: {}'.format(name, arch, sorted(per_arch_kwargs[name]))
            )
        return methods[name](model, **kwargs)

    # Remaining methods only need the model; unknown names raise KeyError.
    return methods[name](model)
def get_heatmap(saliency):
    """Turn a saliency tensor into a NumPy array of magnitudes.

    Singleton dimensions are squeezed away and absolute values are taken.
    If the squeezed tensor is not two-dimensional, it is further reduced by
    taking the maximum over dimension 0 (presumably the channel axis of a
    ``(C, H, W)`` map -- TODO confirm the expected input layout). The result
    is moved to the CPU before conversion.
    """
    magnitude = saliency.squeeze().abs()
    if len(magnitude.size()) != 2:
        # ``Tensor.max(dim)`` returns (values, indices); keep the values.
        magnitude = magnitude.max(0)[0]
    return magnitude.cpu().numpy()