PyTorch:可视化

随着深度神经网络做的的发展,网络的结构越来越复杂,我们也很难确定每一层的输入结构,输出结构以及参数等信息,这样导致我们很难在短时间内完成debug。因此掌握一个可以用来可视化网络结构的工具是十分有必要的。

1. 可视化网络结构

1.1 使用 print 函数打印模型基础信息

import torchvision.models as models
net = models.resnet50()
print(net)
# 输出结果
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
  ......   
  (layer4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  ......
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

print函数只能得出基础构件的信息,既不能显示出每一层的 output shape (与输入数据有关),也不能显示对应参数量的大小。

1.2 使用 torchinfo 可视化网络结构

import torchvision.models as models
from torchinfo import summary
resnet18 = models.resnet18() # 实例化模型
summary(resnet18, (1, 3, 224, 224)) # 1:batch_size 3:图片的通道数 224: 图片的高宽
# 输出结果
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   --                        --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-9                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-12                   [1, 64, 56, 56]           --
├─Sequential: 1-6                        [1, 128, 28, 28]          --
│    └─BasicBlock: 2-3                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-13                 [1, 128, 28, 28]          73,728
│    │    └─BatchNorm2d: 3-14            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-15                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-16                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-17            [1, 128, 28, 28]          256
│    │    └─Sequential: 3-18             [1, 128, 28, 28]          8,448
│    │    └─ReLU: 3-19                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-20                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-21            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-22                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-23                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-24            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-25                   [1, 128, 28, 28]          --
├─Sequential: 1-7                        [1, 256, 14, 14]          --
│    └─BasicBlock: 2-5                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-26                 [1, 256, 14, 14]          294,912
│    │    └─BatchNorm2d: 3-27            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-28                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-29                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-30            [1, 256, 14, 14]          512
│    │    └─Sequential: 3-31             [1, 256, 14, 14]          33,280
│    │    └─ReLU: 3-32                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-6                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-33                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-34            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-35                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-36                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-37            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-38                   [1, 256, 14, 14]          --
├─Sequential: 1-8                        [1, 512, 7, 7]            --
│    └─BasicBlock: 2-7                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-39                 [1, 512, 7, 7]            1,179,648
│    │    └─BatchNorm2d: 3-40            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-41                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-42                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-43            [1, 512, 7, 7]            1,024
│    │    └─Sequential: 3-44             [1, 512, 7, 7]            132,096
│    │    └─ReLU: 3-45                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-8                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-46                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-47            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-48                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-49                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-50            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-51                   [1, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 1000]                 513,000
==========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 1.81
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 39.75
Params size (MB): 46.76
Estimated Total Size (MB): 87.11

2. CNN 可视化

理解 CNN 的重要一步是可视化,包括可视化特征是如何提取的、提取到的特征的形式以及模型在输入数据上的关注点等。

2.1 卷积核可视化

卷积核在CNN中负责提取特征,可视化卷积核能够帮助人们理解CNN各个层在提取什么样的特征,进而理解模型的工作原理。一般来说,靠近输入的层提取的特征是相对简单的结构,而靠近输出的层提取的特征和图中的实体形状相近。

在PyTorch中,可视化卷积核就等价于可视化对应的权重矩阵。基本步骤是找到已经训练好的模型中卷积核的位置,然后把权重矩阵调出来,最后可视化。

import torch
from torchvision.models import vgg11

model = vgg11(pretrained=True)
print(dict(model.features.named_children()))
# 输出结果
{'0': Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '1': ReLU(inplace=True),
 '2': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 '3': Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '4': ReLU(inplace=True),
 '5': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 '6': Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '7': ReLU(inplace=True),
 '8': Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '9': ReLU(inplace=True),
 '10': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 '11': Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '12': ReLU(inplace=True),
 '13': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '14': ReLU(inplace=True),
 '15': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 '16': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '17': ReLU(inplace=True),
 '18': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '19': ReLU(inplace=True),
 '20': MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)}

可以发现模型中有很多个卷积层,靠近输入层的是命名是 “0”,靠近输出层的是 “18”。我们以第 “0” 层为例,可视化对应的参数:

conv0 = dict(model.features.named_children())['0']
# Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

kernel_set = conv0.weight.detach()  # Tensor.detach()  从计算图中脱离出来,返回一个新的tensor,新的tensor和原tensor共享数据内存
# tensor([[[[ 0.2882,  0.0358, -0.3850],
#           [ 0.1795,  0.3668, -0.5012],
#           [-0.0974,  0.3648, -0.2296]],
# ...


num = len(conv0.weight.detach())  
# 64

print(kernel_set.shape)
# torch.Size([64, 3, 3, 3])

import matplotlib.pyplot as plt

for i in range(0,num):
    i_kernel = kernel_set[i] # torch.Size([3, 3, 3])
    plt.figure(figsize=(20, 17))
    if (len(i_kernel)) > 1:  # 3
        for idx, filter in enumerate(i_kernel):
            plt.subplot(9, 9, idx+1) 
            plt.axis('off')
            plt.imshow(filter[ :, :].detach(),cmap='bwr')  # filter, torch.Size([3, 3])

由于第“0”层的特征图由 3 维变为 64 维,因此共有 3*64 个卷积核,其中部分卷积核可视化效果如下图所示:

2.2 特征图可视化

与卷积核相对应,输入的原始图像经过每次卷积层得到的数据称为特征图。这个特征图与卷积核(参数)不同,是会随着输入的不同而不同(数值)。

获取特征图的方法有很多种,可以从输入开始,逐层做前向传播,直到想要的特征图处将其返回,但这种方法有些麻烦。在PyTorch中,提供了一个专用的接口 hook 使得网络在前向传播过程中能够获取到特征图,具体实现如下:

class Hook(object):
    def __init__(self):
        self.module_name = []
        self.features_in_hook = []
        self.features_out_hook = []

    def __call__(self,module, fea_in, fea_out):
        print("hooker working", self)
        self.module_name.append(module.__class__)
        self.features_in_hook.append(fea_in)
        self.features_out_hook.append(fea_out)
        return None
    

def plot_feature(model, idx, inputs):
    hh = Hook()
    model.features[idx].register_forward_hook(hh)   #  将该 hook 类的对象注册到要进行可视化的网络的某层中。model在进行前向传播的时候会调用 hook 的__call__函数,存储当前层的输入和输出。
    
    model.eval()  # 不更新参数
    _ = model(inputs)  # 模型前向传播,此时 hook 会拦截目标层的输入输出
    print(hh.module_name)   # [<class 'torch.nn.modules.conv.Conv2d'>]
    print((hh.features_in_hook[0][0].shape))  # torch.Size([1, 3, 244, 244])
    print((hh.features_out_hook[0].shape))  # torch.Size([1, 64, 244, 244])
    
    out1 = hh.features_out_hook[0]

    total_ft  = out1.shape[1]  # 64
    first_item = out1[0].cpu().clone()  # torch.Size([64, 244, 244])

    plt.figure(figsize=(20, 17))  # figsize: 指定figure的宽和高,单位为英寸;
    
    for ftidx in range(total_ft):
        if ftidx > 99:   # 可能是为了防止 ftidx 输入错误
            break
        ft = first_item[ftidx]  # torch.Size([244, 244])
        plt.subplot(10, 10, ftidx+1)  # 把 figure 分成 nrows*ncols 的子图表示, plot_number 索引值,表示把图画在第 plot_number 个位置
        
        plt.axis('off')
        # plt.imshow(ft[ :, :].detach(),cmap='gray')
        plt.imshow(ft[ :, :].detach())  

# 测试用例
inputs = torch.ones([1, 3, 244, 244])
idx = 0
plot_feature(model, idx, inputs)   # '0': Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# hooker working <__main__.Hook object at 0x000001E41DC84B50>
# hooker working <__main__.Hook object at 0x000001E41F97F250>
# [<class 'torch.nn.modules.conv.Conv2d'>]
# torch.Size([1, 3, 244, 244])
# torch.Size([1, 64, 244, 244])

这里的features_out_hook 是一个 list,每次前向传播一次,都是调用一次,也就是 features_out_hook 长度会增加1。也就是说,会记录每一个输入图像在这一层的特征图。

2.3 Class activation map 可视化

class activation map(CAM)的作用是判断哪些变量对模型来说是重要的。CAM 系列操作的实现可以通过开源工具包 pytorch-grad-cam 来实现。

import torch
from torchvision.models import vgg11
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

model = vgg11(pretrained=True)
img_path = './dog.png'  # 需要事先把图片放置到与代码文件相同的文件目录下

# resize操作是为了和传入神经网络训练图片大小一致
img = Image.open(img_path).resize((224,224))
plt.imshow(img)
from pytorch_grad_cam import GradCAM,ScoreCAM,GradCAMPlusPlus,AblationCAM,XGradCAM,EigenCAM,FullGrad  # 不同的类激活图,选择一个即可 
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import torchvision.transforms as transforms

target_layers = [model.features[-1]] # feature组的最后一层 [MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)]

# 选取合适的类激活图,但是ScoreCAM和AblationCAM需要batch_size
cam = GradCAM(model=model,target_layers=target_layers)
targets = [ClassifierOutputTarget(preds)]   
# 上方preds需要设定,比如ImageNet有1000类,这里可以设为200

img_tensor = transforms.ToTensor()(img)  # tensor数据格式是torch(C,H,W)
print(img_tensor.size())  # torch.Size([3, 224, 224])

grayscale_cam = cam(input_tensor=img_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]

# 需要将原始图片转为np.float32格式并且在0-1之间 
rgb_img = np.float32(img)/255

cam_img = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
print(type(cam_img))
Image.fromarray(cam_img)

3. 使用TensorBoard可视化训练过程

参考资料

posted @   Junwei_Kuang  阅读(246)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)
点击右上角即可分享
微信分享提示