深度残差网络(ResNet)原理与实现(tensorflow2.x)
ResNet原理
深层网络在学习任务中取得了超越人眼的准确率,但是,经过实验表明,模型的性能和模型的深度并非成正比,是由于模型的表达能力过强,反而在测试数据集中性能下降。ResNet的核心是,为了防止梯度弥散或爆炸,让信息流经快捷连接到达浅层。
更正式的讲,输入\(x\)通过卷积层,得到特征变换后的输出\(F(x)\),与输入\(x\)进行对应元素的相加运算,得到最终输出\(H(x)\):
\[H(x) = x + F(x)
\]
VGG模块和残差模块对比如下:
为了能够满足输入\(x\)与卷积层的输出\(F(x)\)能够相加运算,需要输入\(x\)的 shape 与\(F(x)\)的shape 完全一致。当出现 shape 不一致时,一般通过Conv2D进行变换,该Conv2D的核为1×1,步幅为2。
ResNet实现
使用tensorflow2.3实现ResNet
模型创建
import numpy as np
import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
import os
import math
"""
用于控制模型层数
"""
#残差块数
n = 3
depth = n * 9 + 1
def resnet_layer(inputs,
num_filters=16,
kernel_size=3,
strides=1,
activation='relu',
batch_normalization=True,
conv_first=True):
"""2D Convolution-Batch Normalization-Activation stack builder
Arguments:
inputs (tensor): 输入
num_filters (int): 卷积核个数
kernel_size (int): 卷积核大小
activation (string): 激活层
batch_normalization (bool): 是否使用批归一化
conv_first (bool): conv-bn-active(True) or bn-active-conv (False)层堆叠次序
Returns:
x (tensor): 输出
"""
conv = keras.layers.Conv2D(num_filters,
kernel_size=kernel_size,
strides=strides,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=keras.regularizers.l2(1e-4))
x = inputs
if conv_first:
x = conv(x)
if batch_normalization:
x = keras.layers.BatchNormalization()(x)
if activation is not None:
x = keras.layers.Activation(activation)(x)
else:
if batch_normalization:
x = keras.layers.BatchNormalization()(x)
if activation is not None:
x = keras.layers.Activation(activation)(x)
x = conv(x)
return x
def resnet(input_shape,depth,num_classes=10):
"""ResNet
Arguments:
input_shape (tensor): 输入尺寸
depth (int): 网络层数
num_classes (int): 预测类别数
Return:
model (Model): 模型
"""
if (depth - 2) % 6 != 0:
raise ValueError('depth should be 6n+2')
#超参数
num_filters = 16
num_res_blocks = int((depth - 2) / 6)
inputs = keras.layers.Input(shape=input_shape)
x = resnet_layer(inputs=inputs)
for stack in range(3):
for res_block in range(num_res_blocks):
strides = 1
if stack > 0 and res_block == 0:
strides = 2
y = resnet_layer(inputs=x,num_filters=num_filters,
strides=strides)
y = resnet_layer(inputs=y,num_filters=num_filters,
activation=None)
if stack > 0 and res_block == 0:
x = resnet_layer(inputs=x,
num_filters=num_filters,
kernel_size=1,
strides=strides,
activation=None,
batch_normalization=False)
x = keras.layers.add([x,y])
x = keras.layers.Activation('relu')(x)
num_filters *= 2
x = keras.layers.AveragePooling2D(pool_size=8)(x)
x = keras.layers.Flatten()(x)
outputs = keras.layers.Dense(num_classes,activation='softmax',
kernel_initializer='he_normal')(x)
model = keras.Model(inputs=inputs,outputs=outputs)
return model
model = resnet_v1(input_shape=input_shape,depth=depth)
数据加载
#加载数据
(x_train,y_train),(x_test,y_test) = keras.datasets.cifar10.load_data()
#计算类别数
num_labels = len(np.unique(y_train))
#转化为one-hot编码
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
#预处理
input_shape = x_train.shape[1:]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
模型编译
#超参数
batch_size = 64
epochs = 200
#编译模型
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc'])
model.summary()
模型训练
model.fit(x_train,y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_test,y_test),
shuffle=True)
测试模型
scores = model.evaluate(x_test,y_test,batch_size=batch_size,verbose=0)
print('Test loss: ',scores[0])
print('Test accuracy: ',scores[1])
训练过程
Epoch 104/200
782/782 [==============================] - ETA: 0s - loss: 0.2250 - acc: 0.9751
Epoch 00104: val_acc did not improve from 0.91140
782/782 [==============================] - 15s 19ms/step - loss: 0.2250 - acc: 0.9751 - val_loss: 0.4750 - val_acc: 0.9090
learning rate: 0.0001
Epoch 105/200
781/782 [============================>.] - ETA: 0s - loss: 0.2206 - acc: 0.9754
Epoch 00105: val_acc did not improve from 0.91140
782/782 [==============================] - 16s 20ms/step - loss: 0.2206 - acc: 0.9754 - val_loss: 0.4687 - val_acc: 0.9078
learning rate: 0.0001
Epoch 106/200
782/782 [==============================] - ETA: 0s - loss: 0.2160 - acc: 0.9769
Epoch 00106: val_acc did not improve from 0.91140
782/782 [==============================] - 15s 20ms/step - loss: 0.2160 - acc: 0.9769 - val_loss: 0.4886 - val_acc: 0.9053