Dive into Deep Learning (TensorFlow 2): ResNet

  1 import tensorflow as tf
  2 from tensorflow.keras import layers, activations
  3 
  4 
  5 class Residual(tf.keras.Model):
  6   def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
  7     super(Residual, self).__init__(**kwargs)
  8     self.conv1 = layers.Conv2D(num_channels,
  9                                padding='same',
 10                                kernel_size=3,
 11                                strides=strides)
 12     self.conv2 = layers.Conv2D(num_channels, kernel_size=3, padding='same')
 13     if use_1x1conv:
 14       self.conv3 = layers.Conv2D(num_channels, 
 15                                  kernel_size=1,
 16                                  strides=strides)
 17     else:
 18       self.conv3 = None
 19     self.bn1 = layers.BatchNormalization()
 20     self.bn2 = layers.BatchNormalization()
 21 
 22   def call(self, X):
 23     Y = activations.relu(self.bn1(self.conv1(X)))
 24     Y = self.bn2(self.conv2(Y))
 25     if self.conv3:
 26       X = self.conv3(X)
 27     return activations.relu(Y + X)
 28 
 29 
 30 blk = Residual(3)
 31 #tensorflow input shape (n_images, x_shape, y_shape, channels)
 32 X = tf.random.uniform((4, 6, 6, 3))
 33 blk(X).shape
 34 
 35 
 36 blk = Residual(6, use_1x1conv=True, strides=2)
 37 blk(X).shape
 38 
 39 
 40 net = tf.keras.models.Sequential(
 41     [layers.Conv2D(64, kernel_size=7, strides=2, padding='same'),
 42      layers.BatchNormalization(),
 43      layers.Activation('relu'),
 44      layers.MaxPool2D(pool_size=3, strides=2, padding='same')]
 45 )
 46 
 47 
 48 class ResnetBlock(tf.keras.layers.Layer):
 49   def __init__(self, num_channels, num_residuals, first_block=False, **kwargs):
 50     super(ResnetBlock, self).__init__(**kwargs)
 51     self.listLayers = []
 52     for i in range(num_residuals):
 53       if i==0 and not first_block:
 54         self.listLayers.append(Residual(num_channels, use_1x1conv=True, strides=2))
 55       else:
 56         self.listLayers.append(Residual(num_channels))
 57 
 58   def call(self, X):
 59     for layer in self.listLayers.layers:
 60       X = layer(X)
 61     return X
 62 
 63 
# Assemble the full ResNet from its residual stages (the instance built below uses two residual blocks per stage).
 65 
 66 class ResNet(tf.keras.Model):
 67   def __init__(self, num_blocks, **kwargs):
 68     super(ResNet, self).__init__(**kwargs)
 69     self.conv=layers.Conv2D(64, kernel_size=7, strides=2, padding='same')
 70     self.bn = layers.BatchNormalization()
 71     self.relu = layers.Activation('relu')
 72     self.mp = layers.MaxPool2D(pool_size=3, strides=2, padding='same')
 73     self.resnet_block1 = ResnetBlock(64, num_blocks[0], first_block=True)
 74     self.resnet_block2 = ResnetBlock(128, num_blocks[1])
 75     self.resnet_block3 = ResnetBlock(256, num_blocks[2])
 76     self.resnet_block4 = ResnetBlock(512, num_blocks[3])
 77     self.gap = layers.GlobalAvgPool2D()
 78     self.fc = layers.Dense(units=10, activation=tf.keras.activations.softmax)
 79 
 80   def call(self, x):
 81     x = self.conv(x)
 82     x = self.bn(x)
 83     x = self.relu(x)
 84     x = self.mp(x)
 85     x = self.resnet_block1(x)
 86     x = self.resnet_block2(x)
 87     x = self.resnet_block3(x)
 88     x = self.resnet_block4(x)
 89     x = self.gap(x)
 90     x = self.fc(x)
 91     return x
 92 
 93 mynet = ResNet([2, 2, 2, 2])
 94 
 95 
 96 X = tf.random.uniform(shape=(1, 224, 224, 1))
 97 for layer in mynet.layers:
 98   X = layer(X)
 99   print(layer.name, 'output shape:\t', X.shape)
100 
101 
# Train the ResNet on the Fashion-MNIST dataset.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# Add a trailing channel axis (channels-last, as Conv2D expects) and scale
# pixel values to [0, 1]. The sample count is inferred (-1) rather than
# hard-coded, so this also works on a subset of the data.
x_train = x_train.reshape((-1, 28, 28, 1)).astype('float32') / 255
x_test = x_test.reshape((-1, 28, 28, 1)).astype('float32') / 255

# Integer labels + softmax outputs -> sparse categorical cross-entropy.
mynet.compile(loss='sparse_categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

history = mynet.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.2)
# With one metric configured, evaluate() returns [loss, accuracy].
test_scores = mynet.evaluate(x_test, y_test, verbose=2)
print('Test loss:', test_scores[0])
print('Test accuracy:', test_scores[1])

 

posted @ 2020-06-11 22:16  WWBlog  阅读(172)  评论(0)  编辑  收藏  举报