pytorch预训练模型

1.加载预训练模型:

只加载模型,不加载预训练参数:resnet18 = models.resnet18(pretrained=False)

print resnet18 打印模型结构

resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))加载预先下载好的预训练参数到resnet18

print resnet18 打印的还是模型结构

note: cnn = resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))是错误的,这样cnn将是nonetype

pre_dict = resnet18.state_dict()按键值对将模型参数加载到pre_dict

print for k, v in pre_dict.items(): 打印模型参数

for k, v in pre_dict.items():

  print k

打印模型每层命名

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 

note:model是自己定义好的模型,将pretrained_dict和model_dict中命名一致的层加入pretrained_dict(包括参数)

加载模型和预训练参数:resnet34 = models.resnet34(pretrained=True)

 

reference:

1.

http://blog.csdn.net/VictoriaW/article/details/72821329

 

2.

vgg16 = models.vgg16(pretrained=True)

pretrained_dict = vgg16.state_dict()

model_dict = model.state_dict()

# 1. filter out unnecessary keys

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# 2. overwrite entries in the existing state dict

model_dict.update(pretrained_dict)

# 3. load the new state dict

model.load_state_dict(model_dict)

 

posted @ 2017-12-17 23:13  zhanghouyu  阅读(11872)  评论(0编辑  收藏  举报