# 学习笔记——用 PyTorch 搭建 VGGNet 实现 CIFAR-10 图像分类

import torch
import torch.nn as nn
import torch.nn.functional as F


class VGGbase(nn.Module):
    """Small VGG-style CNN for CIFAR-10-like image classification.

    Four convolutional stages (64, 128, 256, 512 channels). Each stage is
    made of 3x3 conv -> BatchNorm -> ReLU units and is followed by a 2x2
    max pool. A single fully connected layer acts as the classifier.

    Spatial sizes produced by the pools:
        28x28 input: 28 -> 14 -> 7 -> 4 -> 2
        32x32 input: 32 -> 16 -> 8 -> 5 -> 2
    In both cases the flattened feature vector has 512 * 2 * 2 elements,
    which is what ``self.fc`` expects.

    Args:
        num_classes: number of output classes (default 10).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Stage 1: 3 x 28 x 28 -> 64 x 14 x 14
        self.conv1 = self._conv_bn_relu(3, 64)
        self.max_pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 64 x 14 x 14 -> 128 x 7 x 7
        self.conv2_1 = self._conv_bn_relu(64, 128)
        self.conv2_2 = self._conv_bn_relu(128, 128)
        self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: 128 x 7 x 7 -> 256 x 4 x 4
        # (padding=1 makes the odd 7x7 map pool up to 4x4 instead of 3x3)
        self.conv3_1 = self._conv_bn_relu(128, 256)
        self.conv3_2 = self._conv_bn_relu(256, 256)
        self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)

        # Stage 4: 256 x 4 x 4 -> 512 x 2 x 2
        self.conv4_1 = self._conv_bn_relu(256, 512)
        self.conv4_2 = self._conv_bn_relu(512, 512)
        self.max_pooling4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier: batch x (512 * 2 * 2) -> batch x num_classes
        self.fc = nn.Linear(512 * 4, num_classes)

    @staticmethod
    def _conv_bn_relu(in_channels: int, out_channels: int) -> nn.Sequential:
        """Return a 3x3 same-padding Conv2d -> BatchNorm2d -> ReLU unit.

        The Sequential layout (conv at index 0, BN at index 1, ReLU at index 2)
        determines the state_dict keys, e.g. ``conv1.0.weight``.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: image tensor of shape (batch, 3, H, W). H = W = 28 or 32 yield
               the 2x2 final feature map the classifier was sized for.

        Returns:
            Log-probabilities of shape (batch, num_classes) (log_softmax over
            dim 1), suitable for ``nn.NLLLoss``.
        """
        out = self.conv1(x)
        out = self.max_pooling1(out)

        out = self.conv2_1(out)
        out = self.conv2_2(out)
        out = self.max_pooling2(out)

        out = self.conv3_1(out)
        out = self.conv3_2(out)
        out = self.max_pooling3(out)

        out = self.conv4_1(out)
        out = self.conv4_2(out)
        out = self.max_pooling4(out)

        # batch x 512 x 2 x 2 -> batch x 2048
        out = out.flatten(1)
        out = self.fc(out)
        return F.log_softmax(out, dim=1)


def VGGNet():
    """Factory: build and return a new VGGbase instance constructed with no arguments."""
    model = VGGbase()
    return model



# posted @ 2021-08-22 20:09  编程coding小白  阅读(368)  评论(0)  编辑  收藏  举报