1603自定义层

点击查看代码
import torch
import torch.nn.functional as F
from torch import nn


# 构造一个没有任何参数的自定义层
print("构造一个没有任何参数的自定义层")
class CenteredLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        # 均值变为0
        return X - X.mean()

layer = CenteredLayer()
print(layer(torch.FloatTensor([1, 2, 3, 4, 5])))
# 将层作为组件合并到更复杂的模型中
print("将层作为组件合并到更复杂的模型中")
net = nn.Sequential(
    nn.Linear(8, 128),
    CenteredLayer()
)

Y = net(torch.rand(4, 8))
print(Y.mean())
# 带参数的图层
print("带参数的图层")
class MyLinear(nn.Module):
    def __init__(self, in_units, units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units))

    def forward(self, X):
        linear = torch.matmul(X, self.weight.data) + self.bias.data
        return F.relu(linear)

dense = MyLinear(5, 3)
print(dense)
print(dense.weight)
# 使用自定义层直接执行正向传播
print("使用自定义层直接执行正向传播")
print(dense(torch.rand(2, 5)))
# 使用自定义层构建模型
print("使用自定义层构建模型")
net = nn.Sequential(
    MyLinear(64, 8),
    MyLinear(8, 1)
)
print(net(torch.rand(2, 64)))
posted @   荒北  阅读(18)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
点击右上角即可分享
微信分享提示