Pytorch系列:(四)IO操作

首先注意pytorch中模型保存有两种格式,pth和pkl,其中,pth是pytorch默认格式,pkl还支持pickle库,不过一般如果没有特殊需求的时候,推荐使用默认pth格式保存

pytorch中有两种数据保存方法,一种是存储整个模型,一种只存储参数

方法一:存储整个模型

#保存

torch.save(model1, 'net.pth')

#读取

model1 = torch.load('net.pth')

方法二:存储模型参数

#保存

torch.save(model.state_dict(), 'checkpoint.pth')

#提取

state_dict = torch.load('checkpoint.pth')

model.load_state_dict(state_dict)

state_dict说明

state_dict 包含了模型使用的所有参数(Parameter类型),如果自定义的模型参数没有用Parameter封装,那么不会出现在state_dict中, 所以使用的时候,自定义参数一定不要忘记使用Parameter进行封装。

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.w1 = torch.randn(10,2)
        self.w2 = nn.Parameter(torch.randn(2,1))
        self.l1 = nn.Linear(10,1)

    def forward(self,x):
        pass 


net = MLP()

net.state_dict()

输出,可以发现只有w2和l1

OrderedDict([('w2',
              tensor([[0.9826],
                      [0.4665]])),
             ('l1.weight',
              tensor([[ 0.3098,  0.0985, -0.2566, -0.1024,  0.0449, -0.1681, -0.1743,  0.2985,
                       -0.0644, -0.0181]])),
             ('l1.bias', tensor([-0.2871]))])

中间状态保存

在训练的时候,可以保存训练中的中间状态,只需要把参数都保存到state字典中就可以了。 例如,在断点续传任务中,可以把epoch,模型状态,优化器状态,初始learning rate 等进行保存。

state = {
          'state_dict': net.state_dict(),
          'optimizer': optim.optimizer.state_dict(),
          'lr_base': optim.lr_base
          'epoch': epoch
        }
            
torch.save(
            state,
            self.CKPTS_PATH +
            'ckpt_' + self.VERSION +
            '/epoch'+ str(epoch) +
            '.pkl'
          )

加载

state = torch.load(
                    self.CKPTS_PATH +
                    'ckpt_' + self.VERSION +
                    '/epoch'+ str(epoch) +
                    '.pkl'
                   )  

 
net.load_state_dict(state['state_dict'])

optim.optimizer.load_state_dict(state['optimizer'])
optim.lr_base = state['lr_base']
start_epoch = state['epoch']

posted @ 2021-05-06 22:31  Neo0oeN  阅读(327)  评论(0编辑  收藏  举报