Resnet——深度残差网络(二)
基于上一篇resnet网络结构进行实战。
再贴一下 ResNet 的基本结构,方便与代码进行对比:
ResNet 的自定义类如下(保存为 resnet.py,供后面的训练脚本通过 from resnet import resnet18 导入):
import tensorflow as tf
from tensorflow import keras


class BasicBlock(keras.layers.Layer):
    """Residual block: conv-BN-ReLU-conv-BN on the main path plus a shortcut.

    Args:
        filter_num: Number of output channels of both 3x3 convolutions.
        stride: Stride of the first 3x3 convolution. When it is not 1 the
            shortcut uses a 1x1 convolution with the same stride so that the
            two branches have matching shapes when they are added.
    """

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()
        # padding='same' only keeps the spatial size when stride == 1; with a
        # larger stride the output size is ceil(input_size / stride).
        self.conv1 = keras.layers.Conv2D(
            filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = keras.layers.BatchNormalization()
        self.relu = keras.layers.Activation('relu')

        self.conv2 = keras.layers.Conv2D(
            filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = keras.layers.BatchNormalization()

        # NOTE: the attribute name (sic, "dowmsample") is left unchanged so
        # that checkpoints saved with the original class still load.
        if stride != 1:
            self.dowmsample = keras.Sequential()
            self.dowmsample.add(
                keras.layers.Conv2D(filter_num, (1, 1), strides=stride))
        else:
            # Identity shortcut: the input is added back unchanged.
            self.dowmsample = lambda x: x

    def call(self, inputs, training=None):
        """Apply the block to `inputs`.

        Args:
            inputs: Input feature map.
            training: Forwarded to the batch-normalization layers so they
                use batch statistics when training and moving averages
                otherwise.

        Returns:
            ReLU of (main path output + shortcut output).
        """
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        identity = self.dowmsample(inputs)

        output = keras.layers.add([out, identity])
        return tf.nn.relu(output)


class ResNet(keras.Model):
    """ResNet made of a stem, four stages of BasicBlocks and a classifier.

    Args:
        layer_dims: Sequence of four ints, the number of BasicBlocks in each
            stage (e.g. [2, 2, 2, 2] for ResNet-18).
        num_classes: Number of output units of the final Dense layer.
    """

    def __init__(self, layer_dims, num_classes=100):
        super(ResNet, self).__init__()
        # Stem: initial convolution before the residual stages.
        self.stem = keras.Sequential([
            keras.layers.Conv2D(64, (3, 3), strides=(1, 1)),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu'),
            keras.layers.MaxPool2D(
                pool_size=(2, 2), strides=(1, 1), padding='same'),
        ])

        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)

        # Global average pooling collapses the spatial dimensions so the
        # result can be fed to the fully connected classifier.
        self.avgpool = keras.layers.GlobalAveragePooling2D()
        self.fc = keras.layers.Dense(num_classes)

    def call(self, inputs, training=None):
        """Return the class logits (no softmax applied) for `inputs`.

        Args:
            inputs: Batch of input images.
            training: Forwarded to the stem and every stage so their
                batch-normalization layers run in the right mode.
        """
        x = self.stem(inputs, training=training)

        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        x = self.avgpool(x)
        return self.fc(x)

    def build_resblock(self, filter_num, blocks, stride=1):
        """Build one stage of `blocks` BasicBlocks.

        Only the first block uses `stride`; the remaining blocks use
        stride 1.
        """
        res_blocks = keras.Sequential()
        res_blocks.add(BasicBlock(filter_num, stride))

        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, 1))

        return res_blocks


def resnet18(num_classes=100):
    """Return a ResNet with two BasicBlocks in each of its four stages.

    Args:
        num_classes: Number of output classes (default 100).
    """
    return ResNet([2, 2, 2, 2], num_classes=num_classes)
训练过程如下:
import os

# Must be set before TensorFlow is imported, otherwise it has no effect.
# '2' hides INFO and WARNING messages from the TensorFlow C++ backend.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras

from resnet import resnet18

NUM_CLASSES = 100  # CIFAR-100 has 100 fine-grained labels
BATCH_SIZE = 64


def preprocess(x, y):
    """Scale pixel values from [0, 255] to [-1, 1] and cast labels to int32."""
    x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1
    y = tf.cast(y, dtype=tf.int32)
    return x, y


(x, y), (x_test, y_test) = keras.datasets.cifar100.load_data()
# Labels are loaded with shape (N, 1); drop the trailing axis.
y = tf.squeeze(y, axis=1)
y_test = tf.squeeze(y_test, axis=1)
print(x.shape, y.shape, x_test.shape, y_test.shape)

train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(1000).map(preprocess).batch(BATCH_SIZE)

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
# Build the evaluation pipeline from the test split (not from train_db).
test_db = test_db.map(preprocess).batch(BATCH_SIZE)


def main():
    """Train ResNet-18 on CIFAR-100 and print test accuracy after each epoch."""
    model = resnet18()
    model.build(input_shape=(None, 32, 32, 3))
    optimizer = keras.optimizers.Adam(learning_rate=1e-3)
    model.summary()

    for epoch in range(50):
        for step, (x, y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                # training=True makes BatchNormalization use batch statistics
                # and update its moving averages.
                logits = model(x, training=True)
                # One-hot depth must match the number of classes.
                y_onehot = tf.one_hot(y, depth=NUM_CLASSES)
                loss = tf.losses.categorical_crossentropy(
                    y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)

            gradient = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(
                zip(gradient, model.trainable_variables))

            if step % 100 == 0:
                print(epoch, step, 'loss:', float(loss))

        # Evaluate on the test split once per epoch.
        total_num = 0
        total_correct = 0
        for x, y in test_db:
            # training=False makes BatchNormalization use moving averages.
            logits = model(x, training=False)
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)

            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
            correct = tf.reduce_sum(correct)

            total_num += x.shape[0]
            # Accumulate as a Python int so acc prints as a plain float.
            total_correct += int(correct)

        acc = total_correct / total_num
        print("acc:", acc)


if __name__ == '__main__':
    main()
打印网络结构和参数量如下: