filename = 'cvae_' + str(epoch+1) + '.pkl'
save_path = save_dir / Path(filename)
states = {}
states['model'] = cvae.state_dict() # 模型参数
states['z_dim'] = args.z_dim
states['x_dim'] = args.x_dim
states['s_dim'] = args.s_dim
states['optim'] = cvae.state_dict()
torch.save(states, str(save_path)) #检查点:将states字典存放在save_path文件下
保存和加载模型的时候,配对的函数:
对于仅保存state_dict()的方式,那保存和加载模型的方式为:
保存:torch.save(model.state_dict(), PATH)
加载:model.laod_state_dict(torch.load(PATH))
一般加载模型是在训练完成后用模型做测试,这时候加载模型记得要加上model.eval(),把模型切换到evaluation模式,这时候会调整dropout和bactch的模式。
对于保存和加载整个模型的情况:
torch.save(model, PATH)
model = torch.load(PATH)
可以看到,前面的model.load_state_dict()和这里的不同,前面的情况需要你先定义一个模型,然后再load_state_dict()
但是这里load整个模型,会把模型的定义一起load进来。完成了模型的定义和加载参数的两个过程。
详细代码
1 def save(self):
2 save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
3
4 if not os.path.exists(save_dir):
5 os.makedirs(save_dir)
6
7 torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
8 torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
9
10 with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
11 pickle.dump(self.train_hist, f)
12 # 使用方法:对模型初始化以后,使用以下方法,加载模型的参数,以至于不用再对数据集进行训练
13 def load(self):
14 save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
15
16 self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
17 self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))
note:
pickle.dump(obj, file, [,protocol]) 序列化对象,将对象obj保存到文件file中去。self.train_hist用于存放模型文件
pickle.load(file) 反序列化对象,将文件中的数据解析为一个python对象。file中有read()接口和readline()接口