18-神经网络-自定义带参数的层

1、nn.Parameter函数


2、torch.mm 和torch.matmul区别
都是 PyTorch 中用于矩阵乘法的函数,但它们在使用上有细微的差别

import torch
import torch.nn as nn
import torch.nn.functional as F
class MyLinear(nn.Module):
def __init__(self, in_units, out_units):
super(MyLinear, self).__init__()
self.weight = nn.Parameter(torch.randn((in_units, out_units)))
self.bias = nn.Parameter(torch.randn(out_units))
def forward(self, x):
linear = torch.matmul(x, self.weight) + self.bias
return F.relu(linear)
linear = MyLinear(5, 3)
print(linear.weight)
y = linear(torch.rand((2, 5)))
print(y)
posted @   不是孩子了  阅读(16)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示