17-神经网络-延迟初始化

使用torch.nn.LazyLinear(output)实现延迟初始化

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.LazyLinear(128)  # 输入维度设置为 None,表示延迟初始化
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)  # 输出维度为 10

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 第一次调用 fc1 时才会初始化
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 实例化模型
model = MyModel()

# 打印模型参数,可以看到参数还没有初始化
print(model.fc1.weight)  # 输出:Parameter containing:
                         # [torch.FloatTensor of size (None, 128)]

# 准备一个输入数据,输入维度为 20
input_data = torch.randn(10, 20)

# 通过模型传递输入数据,触发参数初始化
output = model(input_data)

# 打印模型参数,可以看到参数已经初始化了
print(model.fc1.weight)  # 输出:Parameter containing:
                         # [torch.FloatTensor of size (20, 128)]
posted @ 2024-08-25 11:18  不是孩子了  阅读(20)  评论(0编辑  收藏  举报