CNN(Pytorch版)实现GTA5的自动驾驶——第二节(torchvision的model使用)
目录
这里我不想涉及太多CNN基础介绍,因为内容太多了,如果有兴趣可以参考以下链接学习
因为torchvision已经包含了一些model,所以不必在意网络架构的设计,只需要调用即可
以Alexnet为例
import torchvision.models as models
alexnet = models.alexnet()
print(alexnet) #通过print查看网络结构
我们可以看到在 classfier下的最后一个Linear的input=4066, output=1000 在本系列中,我们需要的output是9.通过使用add_module
来增加
# 给alexnet增加一个模块,名为our_output, 输入为1000(也就是上一层的输出),输出为9(本系列需要的结果)
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
完整代码如下
import torchvision.models as models
import torch.nn as nn
alexnet = models.alexnet()
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
print(alexnet)
输出结果,可以看到新增的一层
把alexnet封装成一个函数,方便我们后续调用
def get_alex():
alexnet = models.alexnet()
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
return alexnet
按照同样的方法,可以构建如下的网络(截止到2021年12月13日torchvision提供的model)
上述提供的网络太多了,我选择了其中的几个网络。并用构建alexnet的方法构建了几个网络
import torch
from torchvision import models
import torch.nn as nn
from conf import config
cf = config()
def get_alex():
alexnet = models.alexnet(pretrained=True)
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
return alexnet
def get_res18():
resnet18 = models.resnet18(pretrained=True)
resnet18.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
return resnet18
def get_widerest():
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
wide_resnet50_2.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
return wide_resnet50_2
def build_model(model_name):
if model_name == 'resnet18':
model = get_res18().to(cf.DEVICE)
elif model_name == 'alexnet':
model = get_alex().to(cf.DEVICE)
elif model_name == 'wide_resnet50_2':
model = get_widerest().to(cf.DEVICE)
return model
参考链接
- torchvision官方链接:https://pytorch.org/vision/stable/models.html#classification