CNN(Pytorch版)实现GTA5的自动驾驶——第二节(torchvision的model使用)

这里我不想涉及太多CNN基础介绍,因为内容太多了,如果有兴趣可以参考以下链接学习

  1. 李沐老师的《动手学深度学习》
  2. B站视频《动手学深度学习》

因为torchvision已经包含了一些model,所以不必在意网络架构的设计,只需要调用即可

以Alexnet为例

import torchvision.models as models
alexnet = models.alexnet()
print(alexnet) #通过print查看网络结构

我们可以看到在 classfier下的最后一个Linear的input=4066, output=1000 在本系列中,我们需要的output是9.通过使用add_module来增加

# 给alexnet增加一个模块,名为our_output, 输入为1000(也就是上一层的输出),输出为9(本系列需要的结果)
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))

完整代码如下

import torchvision.models as models
import torch.nn as nn
alexnet = models.alexnet()
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
print(alexnet)

输出结果,可以看到新增的一层

把alexnet封装成一个函数,方便我们后续调用

def get_alex():
    alexnet = models.alexnet()
    alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
    return alexnet

按照同样的方法,可以构建如下的网络(截止到2021年12月13日torchvision提供的model)

上述提供的网络太多了,我选择了其中的几个网络。并用构建alexnet的方法构建了几个网络

import torch
from torchvision import models
import torch.nn as nn
from conf import config

cf = config()


def get_alex():
    alexnet = models.alexnet(pretrained=True)
    alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
    return alexnet
def get_res18():
    resnet18 = models.resnet18(pretrained=True)
    resnet18.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
    return resnet18
def get_widerest():
    wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
    wide_resnet50_2.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
    return wide_resnet50_2

def build_model(model_name):
    if model_name == 'resnet18':
        model = get_res18().to(cf.DEVICE)
    elif model_name == 'alexnet':
        model = get_alex().to(cf.DEVICE)
    elif model_name == 'wide_resnet50_2':
        model = get_widerest().to(cf.DEVICE)
    return model

参考链接

  1. torchvision官方链接:https://pytorch.org/vision/stable/models.html#classification
posted @ 2021-12-13 20:51  Adam_lxd  阅读(449)  评论(0编辑  收藏  举报