保存
点击查看代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
vgg16_false = torchvision.models.vgg16(pretrained=False)
torch.save(vgg16_false, "vgg16_method1.pth")
torch.save(vgg16_false.state_dict(), "vgg16_method2.pth")
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.covn1 = Conv2d(1, 1, 3)
def forward(self, input):
input = self.covn1(input)
return input
test = Test()
torch.save(test, "test.pth")
print(test)
读取
点击查看代码
import torch
import torchvision
from torch import nn
from test_model_save import *
model = torch.load("vgg16_method1.pth")
model = torchvision.models.vgg16(pretrained=False)
model.load_state_dict(torch.load("vgg16_method2.pth"))
model = torch.load("test.pth")
print(model)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】