深度学习(模型参数直方图)

模型参数直方图可以展示模型参数在训练过程中的分布情况。

通过直方图,可以了解模型的学习状态,识别过拟合或欠拟合问题,从而进行模型调优。

下面以ResNet18为例,显示了不同层的参数直方图。

import torchvision
from matplotlib import pyplot as plt
import torch

model = torchvision.models.resnet18(pretrained=True)

num = 1
# 遍历模型的每一层
for name, module in model.named_modules():
    # 判断是否为卷积层
    if isinstance(module, torch.nn.Conv2d):
        # 输出卷积层名称和权重
        print(f"layer {name} : {module.weight.data.shape}")
        Oc,Ic,H,W = module.weight.data.shape
        data = module.weight.data.view(Oc*Ic*H*W).numpy()            
        plt.subplot(5,4,num)
        plt.hist(data,bins=50)
        num +=1

plt.show()           

结果如下:

posted @ 2024-10-03 11:57  Dsp Tian  阅读(24)  评论(0编辑  收藏  举报