模型权重初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def weight_init(m):  # 初始化权重
    # print(m)
    if isinstance(m, torch.nn.Conv3d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
        m.weight.data = torch.randint_like(m.weight.data, low=-128, high=127)
        # m.bias.data.zero_()
        if m.bias!=None:
            m.bias.data = torch.randint_like(m.bias.data, low=-128, high=127)
    elif isinstance(m, torch.nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        # data = np.load("weight.npy")
        # m.weight.data = torch.tensor(data)
        m.weight.data = torch.randint_like(m.weight.data, low=-128, high=127)
        # print("weight",m.weight.data.shape)
        # print(m.weight.data)
        # print(m.weight.data)
        # m=torch.nn.Conv2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, bias=True, stride=m.stride, padding=m.padding)
        if m.bias!=None:
            m.bias.data = torch.randint_like(m.bias.data, low=-128, high=127)
    elif isinstance(m, torch.nn.BatchNorm3d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, torch.nn.Linear):
        m.weight.data=torch.randint_like(m.weight.data, low=-128, high=127)
        if m.bias is not None:
            m.bias.data.zero_()
             
             
 # 将模型权重初始化为int8
 model.apply(weight_init)

  

posted @   Truman001  阅读(36)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
历史上的今天:
2018-03-23 C++11-元组
2018-03-23 文件读写
点击右上角即可分享
微信分享提示