在pytorch中保存模型或模型参数

在 PyTorch 中,我们可以使用 torch.save 函数将 PyTorch 模型保存到文件。这个函数接受两个参数:要保存的对象(通常是模型),以及文件路径。

保存模型参数

import torch
import torch.nn as nn

# 假设有一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 5)

model = SimpleModel()

# 这里可以进行模型的训练
# training step......

# 定义保存路径
save_path = 'simple_model.pth'

# 使用 torch.save 保存模型
torch.save(model.state_dict(), save_path)

在上面的例子中,model.state_dict() 用于获取模型的状态字典(包含模型的所有参数)。然后,torch.save 函数将这个状态字典保存到指定的文件路径('simple_model.pth')。

再次需要用到模型时可以调用参数:

# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimpleModel().to(device)
model.load_state_dict(torch.load('simple_model.pth'))
model.eval()

保存整个模型

如果想保存整个模型(包括模型的架构和参数),而不仅仅是参数,我们可以直接传递整个模型对象给 torch.save

# 定义保存路径
torch.save(model, save_path)

要加载已保存的模型,可以使用 torch.load 函数:

loaded_model = torch.load(save_path)

这将加载模型的状态字典或整个模型,具体取决于你保存模型时使用的方法。

请注意,加载模型时,确保你的代码中定义了模型的类(例如,SimpleModel)以便正确加载模型的架构。

posted @   落魄统计佬  阅读(377)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
点击右上角即可分享
微信分享提示