pytorch快速加载预训练模型参数的方式
pytorch快速加载预训练模型参数的方式
针对的预训练模型是通用的模型,也可以是自定义模型,大多是vgg16 , resnet50 , resnet101 , 等,从官网加载太慢
直接修改源码,改为本地地址
1.直接使用默认程序里的下载方式,往往比较慢;
2.通过修改源代码,使得模型加载已经下载好的参数,修改地方如下:
通过查找自己代码里所调用网络的类,使用pycharm自带的函数查找功能(ctrl+鼠标左键),查看此网络的加载方法,修改model.load_state_dict()函数。
例如:已经下载好的resnet50的参数文件:放在model_urls里面,这样就可以提前下载直接使用。
model_urls = {
'resnet50': '/home/huihua/NewDisk1/pretrain_parameter/resnet50-19c8e357.pth',
}
把模型权重下载至torch的缓存文件夹
在本用户目录下,linux和win有不同
cd .cache/torch/checkpoints
cd /home/team/.torch/models
两种方式,常常是用第二种作为torch模型的缓存文件夹, /home/team/是用户的文件夹
进入文件夹把所需模型权重放入即可自动加载,相比第一种方法简单点。
凤舞九天