23、Keras实战CIFAR10
1、call函数
call()的本质是将一个类变成一个函数(使这个类的实例可以像函数一样调用)
1 class A(object): 2 def __init__(self, name, age): 3 self.name = name 4 self.age = age 5 6 def __call__(self,gender): #对象变成函数之后调用call,可以增加参数 7 print('my name is %s' % self.name) 8 print('my age is %s' % self.age) 9 print('my age is %s' % gender) 10 if __name__ == '__main__': 11 a = A('jack', 26) 12 a("man")
输出:
my name is jack my age is 26 my age is man
2、
CIFAR10是包含有10类的图片,每类图片中是32x32的同种类别的图片,CIFAR10的图片比较模糊,一般的网络的检测acc为50%左右就已经很好了。
import tensorflow as tf from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics from tensorflow import keras import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' def preprocess(x, y): # [0~255] => [-1~1] x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1. y = tf.cast(y, dtype=tf.int32) return x,y batchsz = 128 # [50k, 32, 32, 3], [10k, 1] (x, y), (x_val, y_val) = datasets.cifar10.load_data() #得到训练集和测试集 y = tf.squeeze(y) #对y的维度进行约简,[50k, 1,10],相当于列表转换为矩阵 y_val = tf.squeeze(y_val) y = tf.one_hot(y, depth=10) # [50k, 10] y_val = tf.one_hot(y_val, depth=10) # [10k, 10] print('datasets:', x.shape, y.shape, x_val.shape, y_val.shape, x.min(), x.max()) train_db = tf.data.Dataset.from_tensor_slices((x,y)) train_db = train_db.map(preprocess).shuffle(10000).batch(batchsz) #对数据进行预处理,并进行打散处理 test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val)) test_db = test_db.map(preprocess).batch(batchsz) sample = next(iter(train_db)) #测试训练集的shape是否如我们所愿 print('batch:', sample[0].shape, sample[1].shape) #创建自定义网络不能直接使用layers.Dense class MyDense(layers.Layer): # 代替标准的layers.Dense() def __init__(self, inp_dim, outp_dim): super(MyDense, self).__init__() #对继承的父类进行初始化,相当于layers.Layer.__init__() self.kernel = self.add_variable('w', [inp_dim, outp_dim]) # self.bias = self.add_variable('b', [outp_dim]) def call(self, inputs, training=None): #call()的本质是将一个类变成一个函数(使这个类的实例可以像函数一样调用) x = inputs @ self.kernel return x class MyNetwork(keras.Model): def __init__(self): super(MyNetwork, self).__init__() self.fc1 = MyDense(32*32*3, 256) #因为call函数将MyDense产生的对象fc变成了函数,下次调用的时候可以直接()并且可以添加参数 self.fc2 = MyDense(256, 128) self.fc3 = MyDense(128, 64) self.fc4 = MyDense(64, 32) self.fc5 = MyDense(32, 10) def call(self, inputs, training=None): """ :param inputs: [b, 32, 32, 3] :param training: :return: """ x = tf.reshape(inputs, [-1, 32*32*3]) # [b, 32*32*3] => [b, 256] x = self.fc1(x) x = tf.nn.relu(x) # [b, 256] => [b, 128] x = self.fc2(x) x = tf.nn.relu(x) # [b, 128] => [b, 64] x = self.fc3(x) x = tf.nn.relu(x) # [b, 64] => [b, 32] x = self.fc4(x) x = tf.nn.relu(x) # [b, 32] => [b, 10] x = self.fc5(x) return x network = MyNetwork() #实例化一个对象 network.compile(optimizer=optimizers.Adam(lr=1e-3), #网络对象的compile有三个功能 loss=tf.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy']) network.fit(train_db, epochs=15, validation_data=test_db, validation_freq=1) #validation_data用来做测试,validation_freq测试频率 network.evaluate(test_db) network.save_weights('ckpt/weights.ckpt') #保存网络的权值w del network #删除现在在训练的网络 print('saved to ckpt/weights.ckpt') network = MyNetwork() #新建一个网络 network.compile(optimizer=optimizers.Adam(lr=1e-3), loss=tf.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy']) network.load_weights('ckpt/weights.ckpt') #对新建的网路加载权值 print('loaded weights from file.') network.evaluate(test_db)
输出:
391/391 [==============================] - 6s 14ms/step - loss: 1.7377 - accuracy: 0.3842 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/15
1/391 [..............................] - ETA: 2:50 - loss: 1.5974 - accuracy: 0.4531
8/391 [..............................] - ETA: 23s - loss: 1.5496 - accuracy: 0.4590
15/391 [>.............................] - ETA: 13s - loss: 1.5670 - accuracy: 0.4547
22/391 [>.............................] - ETA: 10s - loss: 1.5520 - accuracy: 0.4677
29/391 [=>............................] - ETA: 8s - loss: 1.5520 - accuracy: 0.4663
36/391 [=>............................] - ETA: 7s - loss: 1.5564 - accuracy: 0.4657
43/391 [==>...........................] - ETA: 6s - loss: 1.5496 - accuracy: 0.4664
49/391 [==>...........................] - ETA: 6s - loss: 1.5449 - accuracy: 0.4667
56/391 [===>..........................] - ETA: 5s - loss: 1.5292 - accuracy: 0.4720
63/391 [===>..........................] - ETA: 5s - loss: 1.5350 - accuracy: 0.4695
70/391 [====>.........................] - ETA: 4s - loss: 1.5364 - accuracy: 0.4675
74/391 [====>.........................] - ETA: 4s - loss: 1.5332 - accuracy: 0.4670
80/391 [=====>........................] - ETA: 4s - loss: 1.5360 - accuracy: 0.4647
87/391 [=====>........................] - ETA: 4s - loss: 1.5367 - accuracy: 0.4652
90/391 [=====>........................] - ETA: 4s - loss: 1.5403 - accuracy: 0.4628
97/391 [======>.......................] - ETA: 4s - loss: 1.5380 - accuracy: 0.4631
103/391 [======>.......................] - ETA: 4s - loss: 1.5347 - accuracy: 0.4634
109/391 [=======>......................] - ETA: 3s - loss: 1.5362 - accuracy: 0.4627
115/391 [=======>......................] - ETA: 3s - loss: 1.5336 - accuracy: 0.4641
121/391 [========>.....................] - ETA: 3s - loss: 1.5330 - accuracy: 0.4635
127/391 [========>.....................] - ETA: 3s - loss: 1.5344 - accuracy: 0.4629
134/391 [=========>....................] - ETA: 3s - loss: 1.5289 - accuracy: 0.4645
140/391 [=========>....................] - ETA: 3s - loss: 1.5322 - accuracy: 0.4636
144/391 [==========>...................] - ETA: 3s - loss: 1.5334 - accuracy: 0.4629
151/391 [==========>...................] - ETA: 3s - loss: 1.5296 - accuracy: 0.4641
158/391 [===========>..................] - ETA: 2s - loss: 1.5298 - accuracy: 0.4640
165/391 [===========>..................] - ETA: 2s - loss: 1.5271 - accuracy: 0.4651
171/391 [============>.................] - ETA: 2s - loss: 1.5257 - accuracy: 0.4651
178/391 [============>.................] - ETA: 2s - loss: 1.5235 - accuracy: 0.4656
185/391 [=============>................] - ETA: 2s - loss: 1.5264 - accuracy: 0.4645
192/391 [=============>................] - ETA: 2s - loss: 1.5257 - accuracy: 0.4648
199/391 [==============>...............] - ETA: 2s - loss: 1.5247 - accuracy: 0.4656
206/391 [==============>...............] - ETA: 2s - loss: 1.5228 - accuracy: 0.4654
213/391 [===============>..............] - ETA: 2s - loss: 1.5206 - accuracy: 0.4663
219/391 [===============>..............] - ETA: 1s - loss: 1.5203 - accuracy: 0.4661
226/391 [================>.............] - ETA: 1s - loss: 1.5170 - accuracy: 0.4675
233/391 [================>.............] - ETA: 1s - loss: 1.5160 - accuracy: 0.4676
240/391 [=================>............] - ETA: 1s - loss: 1.5159 - accuracy: 0.4675
246/391 [=================>............] - ETA: 1s - loss: 1.5172 - accuracy: 0.4665
253/391 [==================>...........] - ETA: 1s - loss: 1.5182 - accuracy: 0.4661
260/391 [==================>...........] - ETA: 1s - loss: 1.5181 - accuracy: 0.4660
267/391 [===================>..........] - ETA: 1s - loss: 1.5182 - accuracy: 0.4663
274/391 [====================>.........] - ETA: 1s - loss: 1.5175 - accuracy: 0.4664
281/391 [====================>.........] - ETA: 1s - loss: 1.5178 - accuracy: 0.4667
287/391 [=====================>........] - ETA: 1s - loss: 1.5174 - accuracy: 0.4664
293/391 [=====================>........] - ETA: 1s - loss: 1.5165 - accuracy: 0.4668
296/391 [=====================>........] - ETA: 1s - loss: 1.5163 - accuracy: 0.4669
302/391 [======================>.......] - ETA: 0s - loss: 1.5156 - accuracy: 0.4667
305/391 [======================>.......] - ETA: 0s - loss: 1.5164 - accuracy: 0.4662
311/391 [======================>.......] - ETA: 0s - loss: 1.5154 - accuracy: 0.4665
315/391 [=======================>......] - ETA: 0s - loss: 1.5150 - accuracy: 0.4664
325/391 [=======================>......] - ETA: 0s - loss: 1.5123 - accuracy: 0.4670
331/391 [========================>.....] - ETA: 0s - loss: 1.5127 - accuracy: 0.4668
334/391 [========================>.....] - ETA: 0s - loss: 1.5125 - accuracy: 0.4669
344/391 [=========================>....] - ETA: 0s - loss: 1.5117 - accuracy: 0.4673
356/391 [==========================>...] - ETA: 0s - loss: 1.5100 - accuracy: 0.4684
369/391 [===========================>..] - ETA: 0s - loss: 1.5086 - accuracy: 0.4687
383/391 [============================>.] - ETA: 0s - loss: 1.5074 - accuracy: 0.4697
此外,可以通过改变网络每层的输入和输出的维度来改变网络的收敛速度。