PyTorch 迁移学习：AlexNet

话不多说，直接上代码。为了验证 state_dict 的用法，全连接部分我写得和原版 AlexNet 有点不一样：最后一层分类层被单独拿了出来（即 gategory），其余各层的名字和 torchvision 里的 alexnet 保持一致，这样预训练权重就能按键名直接拷贝过来。之后我会试试其他模型的迁移学习，看看有没有更好的办法。字典实在是用得太不习惯了，Python 里我唯一能忍受的就是列表，别的都好难用。

 1 import torch
 2 import torch.nn as nn
 3 from torchvision.models import alexnet
 4 
 5 alex=alexnet(pretrained=True)
 6 # print(alex)
 7 # print(alex.state_dict().keys())
 8 pretrained_dict=alex.state_dict()
 9 weight_0=pretrained_dict['features.3.weight']
10 bias_0=pretrained_dict['features.3.bias']
11 print(weight_0.shape)
12 print(bias_0.shape)
class alex_net(nn.Module):
    """AlexNet backbone laid out like torchvision's ``alexnet``, with a new head.

    ``features`` and ``classifier`` reproduce torchvision's module indices
    (convolutions at features.0/3/6/8/10, linears at classifier.1/4), so their
    state_dict keys coincide with the pretrained model's and weights can be
    copied over key by key. The original final Linear is deliberately left out
    of ``classifier``; instead ``gategory`` maps the 4096-d features to
    ``num_classes`` logits.

    Args:
        num_classes: number of output logits per sample.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv_relu(in_ch, out_ch, k, **conv_kwargs):
            # A Conv2d followed by an in-place ReLU: two consecutive Sequential slots.
            return [nn.Conv2d(in_ch, out_ch, kernel_size=k, **conv_kwargs), nn.ReLU(inplace=True)]

        def max_pool():
            return nn.MaxPool2d(kernel_size=3, stride=2)

        # Slot indices in the comments must not change: they form the state_dict keys.
        self.features = nn.Sequential(
            *conv_relu(3, 64, 11, stride=4, padding=2), max_pool(),   # 0-2
            *conv_relu(64, 192, 5, padding=2), max_pool(),            # 3-5
            *conv_relu(192, 384, 3, padding=1),                       # 6-7
            *conv_relu(384, 256, 3, padding=1),                       # 8-9
            *conv_relu(256, 256, 3, padding=1), max_pool(),           # 10-12
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        hidden = 4096
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 6 * 6, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
        )
        # Task-specific head. The (misspelt) attribute name is kept on purpose:
        # it is the 'gategory.*' prefix of every saved state_dict key.
        self.gategory = nn.Linear(hidden, num_classes)

    def forward(self, input):
        """Map a batch of 3-channel images to ``num_classes`` logits per sample."""
        pooled = self.avgpool(self.features(input))
        flat = torch.flatten(pooled, 1)  # keep the batch dim, flatten 256*6*6
        return self.gategory(self.classifier(flat))
49 
# Build the 5-class model and list its state_dict keys -- these are the names
# the pretrained AlexNet's keys are matched against when transferring weights.
model = alex_net(num_classes=5)
state_keys = model.state_dict().keys()
print(state_keys)
import torch
from torch import optim,nn
import visdom
from torchvision.models import alexnet
from torch.utils.data import DataLoader

from transfer_learning.poke import Pokemonn
from transfer_learning.model import alex_net
# ---- hyper-parameters ----
batch_size = 16
learning_rate = 1e-3
epoches = 10
# device = torch.device('cuda')

# Seed torch's global RNG so weight initialisation and shuffling repeat run to run.
torch.manual_seed(1234)
vis = visdom.Visdom()

# All three splits share one image root; 227 is passed straight to Pokemonn
# (presumably the target image size -- confirm in transfer_learning.poke).
data_root = '/Users/wenyu/Desktop/TorchProject/Pokemann/pokeman'
train_db, validation_db, test_db = (
    Pokemonn(data_root, 227, mode=split) for split in ('train', 'validation', 'test')
)

# Only the training split is shuffled; accuracy does not depend on sample order.
# NOTE(review): with num_workers > 0 under the 'spawn' start method (the macOS
# default since Python 3.8), worker processes re-import this module and re-run
# these top-level statements, including Visdom(); consider moving them into main().
train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True, num_workers=4)
validation_loader, test_loader = (
    DataLoader(db, batch_size=batch_size, num_workers=2) for db in (validation_db, test_db)
)
def evaluate(model, loader, device=None):
    """Return the classification accuracy of ``model`` over every sample in ``loader``.

    The model is put in eval mode for the pass, so Dropout (present in
    alex_net's classifier) is disabled and results are deterministic. Its
    previous train/eval mode is restored afterwards, so calling this between
    training epochs does not leave training running in eval mode.

    Args:
        model: module mapping an input batch to logits; the predicted class is
            the argmax over dim 1.
        loader: iterable of ``(x, y)`` batches with a ``.dataset`` attribute whose
            ``len`` is the total number of samples (e.g. a DataLoader).
        device: optional device to move each batch to before the forward pass;
            ``None`` (default) leaves batches where the loader produced them.

    Returns:
        float: fraction of samples whose argmax prediction equals the label.

    Raises:
        ValueError: if the loader's dataset is empty (accuracy is undefined).
    """
    total_num = len(loader.dataset)
    if total_num == 0:
        raise ValueError('evaluate() needs a non-empty dataset')

    was_training = model.training
    model.eval()
    correct = 0
    try:
        # No gradients are needed anywhere in this pass.
        with torch.no_grad():
            for x, y in loader:
                if device is not None:
                    x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Restore the caller's mode; main() keeps training after validating.
        model.train(was_training)
    return correct / total_num
def _build_model(num_classes=5):
    """Create alex_net and copy in every ImageNet-pretrained AlexNet tensor it can reuse.

    A pretrained tensor is copied only when alex_net has an entry with the same
    state_dict key AND the same shape; everything else (the new ``gategory``
    head) keeps its fresh initialisation.
    """
    model = alex_net(num_classes)
    model_dict = model.state_dict()
    pretrained_dict = alexnet(pretrained=True).state_dict()
    # The shape check guards against a same-named layer whose size was changed,
    # which would otherwise make load_state_dict raise.
    transferable = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and model_dict[k].shape == v.shape
    }
    model_dict.update(transferable)
    model.load_state_dict(model_dict)
    return model


def main():
    """Fine-tune alex_net, keep the best-validation weights, and report test accuracy.

    Training loss is plotted every step and validation accuracy every epoch in
    visdom. The weights with the highest validation accuracy are saved to
    'best.mdl', reloaded after training, and evaluated on the test loader.
    """
    model = _build_model(5)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    fun_loss = nn.CrossEntropyLoss()
    vis.line([0.], [-1], win='train_loss', opts=dict(title='train_loss'))
    vis.line([0.], [-1], win='validation_acc', opts=dict(title='validation_acc'))

    global_step = 0
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; with 0 and a strict '>', an all-zero validation run would never
    # create 'best.mdl' and the torch.load below would fail (or load a stale file).
    best_epoch, best_acc = 0, -1.0
    for epoch in range(epoches):
        # Validation switches to eval mode; turn Dropout back on for training.
        model.train()
        for x, y in train_loader:
            logits = model(x)
            loss = fun_loss(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            vis.line([loss.item()], [global_step], win='train_loss', update='append')
            global_step += 1

        model.eval()
        val_acc = evaluate(model, validation_loader)
        # Plot every epoch, not only the epochs that set a new best.
        vis.line([val_acc], [global_step], win='validation_acc', update='append')
        if val_acc > best_acc:
            best_acc, best_epoch = val_acc, epoch
            torch.save(model.state_dict(), 'best.mdl')

    print('best acc', best_acc, 'best epoch', best_epoch)
    model.load_state_dict(torch.load('best.mdl'))
    print('load from ckpt')

    model.eval()
    test_acc = evaluate(model, test_loader)
    print(test_acc)


# Entry point: start training only when this file is run as a script, not when it is imported.
if __name__ == '__main__':
    main()

训练这部分大部分是跟着龙龙老师写的，数据集也是他的。我只是想简单地验证一下迁移学习怎么用，之后会再做 MobileNet。龙龙老师的 PyTorch 课讲得真的非常浅显易懂，但迁移学习这块讲得不是很全面，想深入学的话还需要再多看看。

 

posted @ 2020-04-23 17:46  daremosiranaihana  阅读(1218)  评论(0)  编辑  收藏  举报