Pytorch框架学习---(6)hook函数和CAM类激活图
本节简单总结Pytorch中hook函数,CAM算法生成注意力图【文中思维导图采用MindMaster软件】 |
1.hook函数
(1)定义
不改变主体(前向、后向传播等)情况下,实现额外的功能,如在backward之后,仍然可以得到特征图和非叶子节点的梯度,即便它们被释放。
(2)方法
节省精力, 由于网上已经有人对这4和hook函数总结的很好,故在此引用,不再复写。
这里我们直接来举一个例子,使用hook函数可视化所有层的特征图,即调用上面的register_forward_hook获取网络层的输出:
# 注册hook
fmap_dict = dict()
for name, sub_module in alexnet.named_modules(): # 如果是named_children()则是返回Sequential本身features
# print(sub_module) # sub_module Sequential本身features以及内部所有的网络层features.0
if isinstance(sub_module, nn.Conv2d):
key_name = str(sub_module.weight.shape)
fmap_dict.setdefault(key_name, list()) # 构建字典中key value对
n1, n2 = name.split(".") # features.0, 为nn.Sequential
def hook_func(module, i, o):
key_name = str(module.weight.shape)
fmap_dict[key_name].append(o) # 索引名字,添加特征图
# print("famp_dict:{}".format(fmap_dict))
alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)
# forward
output = alexnet(img_tensor)
# add image
for layer_name, fmap_list in fmap_dict.items(): # 返回一个可迭代的列表
fmap = fmap_list[0] # 把list中元素取出
fmap.transpose_(0, 1)
nrow = int(np.sqrt(fmap.shape[0]))
fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=0)
对每一个卷积层得到的特征图,作tensorboard可视化:
注意:这里可视化卷积层,但是由于卷积层后面接的是激活函数relu,其中relu(inplace=True)原位操作,会对卷积层的输出做一定的改变。
2.CAM(Class Activation Map)类激活图
啥话先不说,直接上图!!!原来这个就是CAM算法出来的,当判别网络将图片归为“猫”这个类别时,红色代表网络注意的地方,蓝色则是没有注意的地方:
(1)原始CAM
最后一层卷积得到的特征图,经过全局平均池化GAP,得到对应神经元向量,全连接层的权重,即是CAM对特征图加权的权重,经过加权之后的特征图即是最终类似注意力的激活图。
局限性:最后必须是GAP,需要改动原始网络并重新训练,因而改进版Grad-CAM上线。
(2)Grad-CAM(利用特征图的梯度,作为加权权重)
对特征图梯度做平均,得到n个特征图对应的n个平均梯度,将其作为CAM权重。
实战代码如下参考:github,后续用到CAM时,再放入自己项目的激活图展示代码。
吾志所向,一往无前;愈挫愈勇,再接再厉。