pytorch快速加载预训练模型参数的方式

pytorch快速加载预训练模型参数的方式

针对的预训练模型是通用的模型,也可以是自定义模型,大多是vgg16 ,  resnet50 , resnet101 , 等,从官网加载太慢

直接修改源码,改为本地地址

1.直接使用默认程序里的下载方式,往往比较慢;

2.通过修改源代码,使得模型加载已经下载好的参数,修改地方如下:

通过查找自己代码里所调用网络的类,使用pycharm自带的函数查找功能(ctrl+鼠标左键),查看此网络的加载方法,修改model.load_state_dict()函数。

例如:已经下载好的resnet50的参数文件:放在model_urls里面,这样就可以提前下载直接使用。

model_urls = {
'resnet50': '/home/huihua/NewDisk1/pretrain_parameter/resnet50-19c8e357.pth',
}

 

把模型权重下载至torch的缓存文件夹

在本用户目录下,linux和win有不同

cd .cache/torch/checkpoints
cd /home/team/.torch/models
两种方式,常常是用第二种作为torch模型的缓存文件夹, /home/team/是用户的文件夹

进入文件夹把所需模型权重放入即可自动加载,相比第一种方法简单点。

posted @ 2019-04-15 17:02  you-wh  阅读(5333)  评论(0编辑  收藏  举报
Fork me on GitHub