网络模型的保存与读取

保存

点击查看代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d

vgg16_false = torchvision.models.vgg16(pretrained=False)
# 保存方式一,保存模型结构和参数
torch.save(vgg16_false, "vgg16_method1.pth")

# 保存方式二,保存模型参数
torch.save(vgg16_false.state_dict(), "vgg16_method2.pth")

# 易错
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.covn1 = Conv2d(1, 1, 3)
    def forward(self, input):
        input = self.covn1(input)
        return input

test = Test()
torch.save(test, "test.pth")
print(test)



读取

点击查看代码
import torch
import torchvision
from torch import nn
from test_model_save import *


# 读取方式1
model = torch.load("vgg16_method1.pth")
# print(model)

# 读取方式2
model = torchvision.models.vgg16(pretrained=False)
model.load_state_dict(torch.load("vgg16_method2.pth"))
# print(model)

# 易错
model = torch.load("test.pth")
print(model)
posted @   荒北  阅读(41)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示