torch.nn.Linear解释
torch.nn.Linear的作用是对输入向量进行矩阵的乘积和加法。y=x(A)转置+b。这点类似于全连接神经网络的的隐藏层。in_feature代表输入神经元的个数。out_feature代表输出神经元的个数。bias为False不参与训练。如果为True则参与训练。
x = torch.randn(20) # 输入的维度是(20)
m = torch.nn.Linear(20, 1) # 20,1是指输入维度、输出维度 神经网络又20个输入神经元,1个输出神经元。
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
输出结果:
m.weight.shape:
torch.Size([1, 20])//运算时,权值需要转置。
m.bias.shape:
torch.Size([1])//只有一个神经元故bias只有一个。
output.shape:
torch.Size([1])//一个神经元只有一个输出值。