1604读写文件

点击查看代码
import torch
from torch import nn
from torch.nn import functional as F

# 加载和保存张量
print("加载和保存张量")
x = torch.arange(4)
torch.save(x, 'x-file')

x2 = torch.load('x-file')
print(x2)

# 存储张量列表,读回内存
print("存储张量列表,读回内存")
y = torch.zeros(4)
torch.save([x, y], 'x-files')
x2, y2 = torch.load('x-files')
print((x2, y2))

# 写入或读取从字符串映射到张量的字典
print("写入或读取从字符串映射到张量的字典")
mydict = {'x' : x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
print(mydict2)

# 加载和保存模型参数
# 存权重,不存定义
print("加载和保存模型参数")
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, X):
        return self.output(F.relu(self.hidden(X)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
print(Y)
# 将模型的参数存储到 mlp.params 文件中
torch.save(net.state_dict(), 'mlp.params')

# 实例化原始多层感知机模型的一个备份,直接读取文件中存储的参数
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
print(clone)
# 由于两个实例具有相同的模型参数,在输入相同的X时, 两个实例的计算结果应该相同。
Y_clone = clone(X)
print(Y_clone == Y)
posted @   荒北  阅读(18)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 三行代码完成国际化适配,妙~啊~
· .NET Core 中如何实现缓存的预热?
点击右上角即可分享
微信分享提示