torch.nn.Linear解释

 

 

torch.nn.Linear的作用是对输入向量进行矩阵的乘积和加法。y=x(A)转置+b。这点类似于全连接神经网络的的隐藏层。in_feature代表输入神经元的个数。out_feature代表输出神经元的个数。bias为False不参与训练。如果为True则参与训练。

x = torch.randn(20)  # 输入的维度是(20)
m = torch.nn.Linear(20, 1)  # 20,1是指输入维度、输出维度 神经网络又20个输入神经元,1个输出神经元。
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
输出结果:

m.weight.shape:
torch.Size([1, 20])//运算时,权值需要转置。
m.bias.shape:
torch.Size([1])//只有一个神经元故bias只有一个。
output.shape:
torch.Size([1])//一个神经元只有一个输出值。

posted @ 2021-06-21 16:22  祥瑞哈哈哈  阅读(3361)  评论(0编辑  收藏  举报