模型权重初始化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | def weight_init(m): # 初始化权重 # print(m) if isinstance (m, torch.nn.Conv3d): n = m.kernel_size[ 0 ] * m.kernel_size[ 1 ] * m.kernel_size[ 2 ] * m.out_channels m.weight.data = torch.randint_like(m.weight.data, low = - 128 , high = 127 ) # m.bias.data.zero_() if m.bias! = None : m.bias.data = torch.randint_like(m.bias.data, low = - 128 , high = 127 ) elif isinstance (m, torch.nn.Conv2d): n = m.kernel_size[ 0 ] * m.kernel_size[ 1 ] * m.out_channels # data = np.load("weight.npy") # m.weight.data = torch.tensor(data) m.weight.data = torch.randint_like(m.weight.data, low = - 128 , high = 127 ) # print("weight",m.weight.data.shape) # print(m.weight.data) # print(m.weight.data) # m=torch.nn.Conv2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, bias=True, stride=m.stride, padding=m.padding) if m.bias! = None : m.bias.data = torch.randint_like(m.bias.data, low = - 128 , high = 127 ) elif isinstance (m, torch.nn.BatchNorm3d): m.weight.data.fill_( 1 ) m.bias.data.zero_() elif isinstance (m, torch.nn.Linear): m.weight.data = torch.randint_like(m.weight.data, low = - 128 , high = 127 ) if m.bias is not None : m.bias.data.zero_() # 将模型权重初始化为int8 model. apply (weight_init) |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
2018-03-23 C++11-元组
2018-03-23 文件读写