
Python Exercise: Weight Initialization

Author: 凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/

This exercise calls the torch.nn.init.xxx functions in PyTorch to initialize a model's weights and biases.
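
Before the full script, here is a minimal standalone sketch of how these functions behave (the tensor names w and b are illustrative, not from the script below): every torch.nn.init initializer ends in a trailing underscore and fills an existing tensor in place, which is why they can be applied directly to a layer's weight and bias.

import torch
import torch.nn as nn

w = torch.empty(3, 5)       # an uninitialized 3 x 5 weight tensor
nn.init.xavier_normal_(w)   # fill w in place from the Xavier normal distribution
b = torch.empty(3)
nn.init.constant_(b, 0.0)   # fill b in place with the constant 0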

1. weight_init_test.py

# -*- coding: utf-8 -*-
# Author: 凯鲁嘎吉 Coral Gajic
# https://www.cnblogs.com/kailugaji/
# Python Exercise: Weight Initialization
# Custom weight init for Conv2D and Linear layers.
import torch
import torch.nn.functional as F
import torch.nn as nn
# Choose an initialization scheme according to the layer type.
# Two different schemes are defined below.
# Scheme 1: normal distribution + constant
def weight_init(m):
    if isinstance(m, nn.Linear):
        # If the module is an nn.Linear layer:
        nn.init.xavier_normal_(m.weight) # initialize the weights from the Xavier normal distribution
        nn.init.constant_(m.bias, 0) # initialize the bias to the constant 0
    elif isinstance(m, nn.Conv2d):
        # If the module is an nn.Conv2d layer:
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') # Kaiming (He) normal initialization
    elif isinstance(m, nn.BatchNorm2d):
        # If the module is an nn.BatchNorm2d layer:
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

# Scheme 2: orthogonal + constant
def weight_init2(m):
    if isinstance(m, nn.Linear):
        # If the module is an nn.Linear layer:
        nn.init.orthogonal_(m.weight.data) # make the rows of the weight matrix orthonormal
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0) # if the layer has a bias term, fill it with zeros
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # If the module is an nn.Conv2d or nn.ConvTranspose2d layer:
        gain = nn.init.calculate_gain('relu') # scaling gain recommended for the ReLU activation
        nn.init.orthogonal_(m.weight.data, gain) # orthogonal initialization scaled by the gain
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0) # if the layer has a bias term, fill it with zeros

class Net(nn.Module):
    def __init__(self, input_size=1):
        super(Net, self).__init__()
        self.input_size = input_size
        self.fc1 = nn.Linear(self.input_size, 2)
        self.fc2 = nn.Linear(2, 4)
        self.fc3 = nn.Linear(4, 2)

    def forward(self, x):
        x = x.view(-1, self.input_size)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

torch.manual_seed(1)
num = 4 # input dimension
x = torch.randn(1, num)
# Scheme 1:
model = Net(input_size = num)
print('Network structure:\n', model)
print('Input:\n', x)
model.apply(weight_init) # apply() calls weight_init recursively on every submodule
y = model(x)
print('Output 1:\n', y.data)
print('Weights 1:\n', model.fc1.weight.data)
# Scheme 2:
model = Net(input_size = num)
model.apply(weight_init2)
y = model(x)
print('Output 2:\n', y.data)
print('Weights 2:\n', model.fc1.weight.data)
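
Since Scheme 2 uses nn.init.orthogonal_, the rows of fc1's 2 x 4 weight matrix should come out orthonormal, i.e. W @ W.t() should be close to the 2 x 2 identity. A minimal check that could be appended to the end of the script (the name W is illustrative):

W = model.fc1.weight.data # the 2 x 4 fc1 weight produced by Scheme 2
print(W @ W.t())          # rows are orthonormal, so this should be ~ the 2 x 2 identity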

2. Results

D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/Neural Network/weight_init_test.py"
Network structure:
 Net(
  (fc1): Linear(in_features=4, out_features=2, bias=True)
  (fc2): Linear(in_features=2, out_features=4, bias=True)
  (fc3): Linear(in_features=4, out_features=2, bias=True)
)
Input:
 tensor([[0.6614, 0.2669, 0.0617, 0.6213]])
Output 1:
 tensor([[-0.7233, -0.6639]])
Weights 1:
 tensor([[ 2.0709, -1.0573,  0.9230, -0.7373],
        [ 0.1879, -0.2766,  0.7962,  1.4599]])
Output 2:
 tensor([[-0.6951, -0.6912]])
Weights 2:
 tensor([[-0.8471, -0.4721,  0.1653,  0.1795],
        [-0.4072,  0.5991, -0.6437,  0.2467]])

Process finished with exit code 0
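
As a sanity check on Weights 2 above: the squared norm of the first row is 0.8471^2 + 0.4721^2 + 0.1653^2 + 0.1795^2 ≈ 1.00, and the dot product of the two rows is (-0.8471)(-0.4072) + (-0.4721)(0.5991) + (0.1653)(-0.6437) + (0.1795)(0.2467) ≈ 0.00, so the rows are indeed orthonormal, exactly what the orthogonal initialization promises.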

Done.
