18-神经网络-自定义带参数的层

1、nn.Parameter函数


2、torch.mm 和torch.matmul区别
都是 PyTorch 中用于矩阵乘法的函数,但它们在使用上有细微的差别

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyLinear(nn.Module):
    def __init__(self, in_units, out_units):
        super(MyLinear, self).__init__()
        self.weight = nn.Parameter(torch.randn((in_units, out_units)))
        self.bias = nn.Parameter(torch.randn(out_units))

    def forward(self, x):
        linear = torch.matmul(x, self.weight) + self.bias
        return F.relu(linear)

linear = MyLinear(5, 3)
print(linear.weight)

y = linear(torch.rand((2, 5)))
print(y)
posted @ 2024-08-25 11:41  不是孩子了  阅读(0)  评论(0编辑  收藏  举报