17-神经网络-延迟初始化
使用torch.nn.LazyLinear(output)实现延迟初始化
import torch import torch.nn as nn class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.fc1 = nn.LazyLinear(128) # 输入维度设置为 None,表示延迟初始化 self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 10) # 输出维度为 10 def forward(self, x): x = torch.relu(self.fc1(x)) # 第一次调用 fc1 时才会初始化 x = torch.relu(self.fc2(x)) x = self.fc3(x) return x # 实例化模型 model = MyModel() # 打印模型参数,可以看到参数还没有初始化 print(model.fc1.weight) # 输出:Parameter containing: # [torch.FloatTensor of size (None, 128)] # 准备一个输入数据,输入维度为 20 input_data = torch.randn(10, 20) # 通过模型传递输入数据,触发参数初始化 output = model(input_data) # 打印模型参数,可以看到参数已经初始化了 print(model.fc1.weight) # 输出:Parameter containing: # [torch.FloatTensor of size (20, 128)]
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)