保存模型有两种方式,方式不同,在调用模型的时候也不同

我更建议用torch.jit。。。这样不需要在写模型的参数

torch.save

1
2
3
4
5
6
7
8
9
10
11
12
13
14
保存模型:
import torch
import torch.nn as nn
 
# 假设 model 是你的 PyTorch 模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)
 
model = SimpleModel()
 
# 保存模型到文件
torch.save(model.state_dict(), 'model.pth')<br>解释:<br><code>model.state_dict()</code> 返回模型的参数字典,<code>torch.save</code> 将这个字典保存到名为 <code>model.pth</code> 的文件中。

  

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
调用模型:
import torch
import torch.nn as nn
 
# 假设 model 是你的 PyTorch 模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)
 
model = SimpleModel()
 
# 加载模型参数
model.load_state_dict(torch.load('model.pth'))
 
# 将模型设为评估模式(如果是测试模型)
model.eval()
outputs = model(data.float())

  

torch.jit.script

TorchScript — PyTorch 2.1 documentation

torch.jit 模块是 PyTorch 中的即时(just-in-time)编译模块,提供了一种将 PyTorch 模型转换为脚本(script)或 Torch 脚本(TorchScript)的方法。Torch 脚本是一种中间表示形式,可以在不依赖 Python 解释器的情况下在 PyTorch 中运行。

可以将整个模型保存为一个 Torch 脚本文件,而不仅仅是模型的参数。这样做可以更轻松地保存和加载整个模型。

1
2
3
4
5
6
7
8
9
10
11
12
保存模型:import torch
import torch.jit
 
# model 是我的 PyTorch 模型
class SimpleModel(torch.nn.Module):
    def forward(self, x):
        return x + 1
 
model = SimpleModel()
 
# 将模型转换为 Torch 脚本
scripted_model = torch.jit.script(model)# 保存 Torch 脚本到文件 scripted_model.save("scripted_model.pt")

 

1
2
3
4
# 调用模型
model = torch.jit.load("scripted_model.pt")# 将模型设为评估模式(如果是测试模型)
model.eval()
outputs = model(data.float())

  

 

posted on   黑逍逍  阅读(167)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!



点击右上角即可分享
微信分享提示