# Dive into Deep Learning (TF2) -- ResNet
import tensorflow as tf
from tensorflow.keras import layers, activations


class Residual(tf.keras.Model):
    """Residual block: two 3x3 convolutions with batch norm, plus a skip
    connection (identity, or a strided 1x1 conv when shape must change).

    Args:
        num_channels: number of output channels of both 3x3 convolutions.
        use_1x1conv: if True, transform the shortcut with a 1x1 conv so its
            channel count / spatial size matches the main path.
        strides: stride of the first 3x3 conv (and of the 1x1 shortcut conv).
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.conv1 = layers.Conv2D(num_channels,
                                   padding='same',
                                   kernel_size=3,
                                   strides=strides)
        self.conv2 = layers.Conv2D(num_channels, kernel_size=3, padding='same')
        if use_1x1conv:
            # Shortcut projection: matches channels/stride of the main path.
            self.conv3 = layers.Conv2D(num_channels,
                                       kernel_size=1,
                                       strides=strides)
        else:
            self.conv3 = None
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()

    def call(self, X):
        Y = activations.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        # ReLU is applied after adding the skip connection.
        return activations.relu(Y + X)


class ResnetBlock(tf.keras.layers.Layer):
    """A stage of `num_residuals` Residual blocks.

    The first block of every stage except the very first halves the spatial
    size (stride 2) and switches the channel count via a 1x1 shortcut conv.
    """

    def __init__(self, num_channels, num_residuals, first_block=False, **kwargs):
        super(ResnetBlock, self).__init__(**kwargs)
        self.listLayers = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                self.listLayers.append(
                    Residual(num_channels, use_1x1conv=True, strides=2))
            else:
                self.listLayers.append(Residual(num_channels))

    def call(self, X):
        # BUG FIX: `self.listLayers` is a plain (Keras-tracked) Python list,
        # not a Model, so it has no `.layers` attribute -- iterate it directly.
        for layer in self.listLayers:
            X = layer(X)
        return X


# Add all residual blocks to ResNet; each stage here uses two residual blocks.
class ResNet(tf.keras.Model):
    """ResNet-18-style network: 7x7 stem, four residual stages, global
    average pooling, and a 10-way softmax head (for Fashion-MNIST).

    Args:
        num_blocks: sequence of four ints, the number of Residual blocks in
            each of the four stages (e.g. [2, 2, 2, 2] for ResNet-18).
    """

    def __init__(self, num_blocks, **kwargs):
        super(ResNet, self).__init__(**kwargs)
        self.conv = layers.Conv2D(64, kernel_size=7, strides=2, padding='same')
        self.bn = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.mp = layers.MaxPool2D(pool_size=3, strides=2, padding='same')
        self.resnet_block1 = ResnetBlock(64, num_blocks[0], first_block=True)
        self.resnet_block2 = ResnetBlock(128, num_blocks[1])
        self.resnet_block3 = ResnetBlock(256, num_blocks[2])
        self.resnet_block4 = ResnetBlock(512, num_blocks[3])
        self.gap = layers.GlobalAvgPool2D()
        # NOTE: softmax activation here pairs with the (non-logits)
        # 'sparse_categorical_crossentropy' loss used in main().
        self.fc = layers.Dense(units=10, activation=tf.keras.activations.softmax)

    def call(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.mp(x)
        x = self.resnet_block1(x)
        x = self.resnet_block2(x)
        x = self.resnet_block3(x)
        x = self.resnet_block4(x)
        x = self.gap(x)
        x = self.fc(x)
        return x


def main():
    """Demo and training driver (formerly top-level notebook cells).

    Guarded under __main__ so importing this module no longer downloads
    Fashion-MNIST and trains for 5 epochs as a side effect.
    """
    # Residual block sanity checks.
    blk = Residual(3)
    # TensorFlow input shape: (n_images, height, width, channels)
    X = tf.random.uniform((4, 6, 6, 3))
    print(blk(X).shape)

    blk = Residual(6, use_1x1conv=True, strides=2)
    print(blk(X).shape)

    # Stand-alone stem, for illustration only (ResNet builds its own stem).
    net = tf.keras.models.Sequential(
        [layers.Conv2D(64, kernel_size=7, strides=2, padding='same'),
         layers.BatchNormalization(),
         layers.Activation('relu'),
         layers.MaxPool2D(pool_size=3, strides=2, padding='same')])

    # Two residual blocks per stage -> ResNet-18.
    mynet = ResNet([2, 2, 2, 2])

    # Trace the per-layer output shapes on a dummy 224x224 grayscale image.
    X = tf.random.uniform(shape=(1, 224, 224, 1))
    for layer in mynet.layers:
        X = layer(X)
        print(layer.name, 'output shape:\t', X.shape)

    # Train ResNet on the Fashion-MNIST dataset.
    (x_train, y_train), (x_test, y_test) = \
        tf.keras.datasets.fashion_mnist.load_data()
    x_train = x_train.reshape((60000, 28, 28, 1)).astype('float32') / 255
    x_test = x_test.reshape((10000, 28, 28, 1)).astype('float32') / 255

    mynet.compile(loss='sparse_categorical_crossentropy',
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])

    history = mynet.fit(x_train, y_train, batch_size=64, epochs=5,
                        validation_split=0.2)
    test_scores = mynet.evaluate(x_test, y_test, verbose=2)


if __name__ == "__main__":
    main()