PyTorch transfer learning with MobileNet (part 1)

The previous post covered how to build the parameter (state) dict; this one covers how to do the transfer, layer by layer. The code could still be tuned further, but here it is for now.

import torch
import torch.nn as nn
from torch import optim
import visdom
from torch.utils.data import DataLoader
from MobileNet.mobilenet_v1 import MobileNet
from MobileNet.iris_csv import Iris

batch_size = 16
base_learning_rate = 1e-4
epochs = 10

torch.manual_seed(1234)
vis = visdom.Visdom()

train_db = Iris('/root/demo', 64, 128, 'train')
validation_db = Iris('/root/demo', 64, 128, 'validation')
test_db = Iris('/root/demo', 64, 128, 'test')

train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True, num_workers=4)
validation_loader = DataLoader(validation_db, batch_size=batch_size, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batch_size, num_workers=2)

def evaluate(model, loader):
    # top-1 accuracy of the model over an entire loader
    model.eval()
    correct = 0
    total_num = len(loader.dataset)
    for x, y in loader:
        # x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    model.train()
    return correct / total_num

def adapt_weights(pthfile, module):
    # helper: load only the checkpoint entries whose keys also exist in the
    # module's state_dict, i.e. a partial, layer-wise transfer
    module_dict = module.state_dict()
    pretrained_dict = torch.load(pthfile)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in module_dict}
    module_dict.update(pretrained_dict)
    module.load_state_dict(module_dict)

def main():
    mod = MobileNet(35)
    mod_dict = mod.state_dict()
    # the new layer has no pretrained weights, so initialize it explicitly
    nn.init.kaiming_normal_(mod.upchannel.weight, nonlinearity='relu')
    nn.init.constant_(mod.upchannel.bias, 0.1)
    # keep only the checkpoint entries whose keys also exist in this model,
    # then load the merged dict (the same logic as adapt_weights above)
    pretrained_dict = torch.load('/root/tf_to_torch.pth')
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in mod_dict}
    mod_dict.update(pretrained_dict)
    mod.load_state_dict(mod_dict)
    # freeze everything except the last two state_dict entries
    # (the final layer's weight and bias)
    freeze_list = list(mod.state_dict().keys())[0:-2]
    # print(freeze_list)
    for name, param in mod.named_parameters():
        if name in freeze_list:
            param.requires_grad = False
        if param.requires_grad:
            print(name)  # the parameters that will actually be trained
    # only pass the still-trainable parameters to the optimizer
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, mod.parameters()), lr=base_learning_rate)
    fun_loss = nn.CrossEntropyLoss()
    vis.line([0.], [-1], win='train_loss', opts=dict(title='train_loss'))
    vis.line([0.], [-1], win='validation_acc', opts=dict(title='validation_acc'))
    global_step = 0
    best_epoch, best_acc = 0, 0
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            logits = mod(x)
            # print(logits.shape)
            loss = fun_loss(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            vis.line([loss.item()], [global_step], win='train_loss', update='append')
            global_step += 1


        if epoch % 1 == 0:  # validate every epoch; raise the interval if needed
            val_acc = evaluate(mod, validation_loader)
            vis.line([val_acc], [global_step], win='validation_acc', update='append')
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(mod.state_dict(), 'best.pth')

    print('best acc', best_acc, 'best epoch', best_epoch)

if __name__ == '__main__':
    main()
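
Note that the adapt_weights helper is defined but never called; main() inlines the same partial-loading logic. If you prefer the helper, a minimal usage sketch (same checkpoint path as above):

mod = MobileNet(35)
adapt_weights('/root/tf_to_torch.pth', mod)  # load only the keys present in both dicts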

The /root parts are just paths on my machine; adjust them to your own project. freeze_list holds the key names of the layers whose parameters should not be updated: write in the parameter names of whichever layers you don't want trained, then use

for name, param in mod.named_parameters():

to get every parameter name and the parameter itself from the parameter dict. If the name is in freeze_list, freeze it; otherwise its parameters keep updating. The frozen layers are then used purely as a feature extractor.
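
For example, to see which names you could put into freeze_list and to freeze everything except the classifier head, a minimal sketch (assuming mod is the MobileNet instance from above, and that the classifier's parameter names start with 'fc'; check the printed names for your own model):

# print every parameter name so you can decide what to freeze
for name, param in mod.named_parameters():
    print(name, param.shape, param.requires_grad)

# freeze all layers except the classifier head ('fc' is an assumption;
# substitute whatever prefix your final layer actually uses)
for name, param in mod.named_parameters():
    if not name.startswith('fc'):
        param.requires_grad = False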

posted @ 2020-05-07 20:59  daremosiranaihana