23、Keras实战CIFAR10

1、call函数

call()的本质是将一个类变成一个函数(使这个类的实例可以像函数一样调用)

 1 class A(object):
 2     def __init__(self, name, age):
 3         self.name = name
 4         self.age = age
 5 
 6     def __call__(self,gender): #对象变成函数之后调用call,可以增加参数
 7         print('my name is %s' % self.name)
 8         print('my age is %s' % self.age)
 9         print('my age is %s' % gender)
10 if __name__ == '__main__':
11     a = A('jack', 26)
12     a("man")

输出:

my name is jack
my age is 26
my age is man

2、

 

 

CIFAR10是包含有10类的图片,每类图片中是32x32的同种类别的图片,CIFAR10的图片比较模糊,一般的网络的检测acc为50%左右就已经很好了。

import  tensorflow as tf
from    tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from     tensorflow import keras
import  os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def preprocess(x, y):
    # [0~255] => [-1~1]
    x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1.
    y = tf.cast(y, dtype=tf.int32)
    return x,y


batchsz = 128
# [50k, 32, 32, 3], [10k, 1]
(x, y), (x_val, y_val) = datasets.cifar10.load_data()  #得到训练集和测试集
y = tf.squeeze(y)  #对y的维度进行约简,[50k, 1,10],相当于列表转换为矩阵
y_val = tf.squeeze(y_val)
y = tf.one_hot(y, depth=10) # [50k, 10]
y_val = tf.one_hot(y_val, depth=10) # [10k, 10]
print('datasets:', x.shape, y.shape, x_val.shape, y_val.shape, x.min(), x.max())


train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsz)  #对数据进行预处理,并进行打散处理
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsz)


sample = next(iter(train_db))   #测试训练集的shape是否如我们所愿
print('batch:', sample[0].shape, sample[1].shape)

#创建自定义网络不能直接使用layers.Dense
class MyDense(layers.Layer):
    # 代替标准的layers.Dense()
    def __init__(self, inp_dim, outp_dim):
        super(MyDense, self).__init__()  #对继承的父类进行初始化,相当于layers.Layer.__init__()

        self.kernel = self.add_variable('w', [inp_dim, outp_dim])
        # self.bias = self.add_variable('b', [outp_dim])

    def call(self, inputs, training=None): #call()的本质是将一个类变成一个函数(使这个类的实例可以像函数一样调用)

        x = inputs @ self.kernel
        return x

class MyNetwork(keras.Model):

    def __init__(self):
        super(MyNetwork, self).__init__()

        self.fc1 = MyDense(32*32*3, 256)  #因为call函数将MyDense产生的对象fc变成了函数,下次调用的时候可以直接()并且可以添加参数
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 32, 32, 3]
        :param training:
        :return:
        """
        x = tf.reshape(inputs, [-1, 32*32*3])
        # [b, 32*32*3] => [b, 256]
        x = self.fc1(x)
        x = tf.nn.relu(x)
        # [b, 256] => [b, 128]
        x = self.fc2(x)
        x = tf.nn.relu(x)
        # [b, 128] => [b, 64]
        x = self.fc3(x)
        x = tf.nn.relu(x)
        # [b, 64] => [b, 32]
        x = self.fc4(x)
        x = tf.nn.relu(x)
        # [b, 32] => [b, 10]
        x = self.fc5(x)

        return x


network = MyNetwork()  #实例化一个对象
network.compile(optimizer=optimizers.Adam(lr=1e-3),   #网络对象的compile有三个功能
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.fit(train_db, epochs=15, validation_data=test_db, validation_freq=1) #validation_data用来做测试,validation_freq测试频率

network.evaluate(test_db)
network.save_weights('ckpt/weights.ckpt') #保存网络的权值w
del network #删除现在在训练的网络
print('saved to ckpt/weights.ckpt')


network = MyNetwork()  #新建一个网络
network.compile(optimizer=optimizers.Adam(lr=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.load_weights('ckpt/weights.ckpt')  #对新建的网路加载权值
print('loaded weights from file.')
network.evaluate(test_db)

输出:

391/391 [==============================] - 6s 14ms/step - loss: 1.7377 - accuracy: 0.3842 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/15

  1/391 [..............................] - ETA: 2:50 - loss: 1.5974 - accuracy: 0.4531
  8/391 [..............................] - ETA: 23s - loss: 1.5496 - accuracy: 0.4590 
 15/391 [>.............................] - ETA: 13s - loss: 1.5670 - accuracy: 0.4547
 22/391 [>.............................] - ETA: 10s - loss: 1.5520 - accuracy: 0.4677
 29/391 [=>............................] - ETA: 8s - loss: 1.5520 - accuracy: 0.4663 
 36/391 [=>............................] - ETA: 7s - loss: 1.5564 - accuracy: 0.4657
 43/391 [==>...........................] - ETA: 6s - loss: 1.5496 - accuracy: 0.4664
 49/391 [==>...........................] - ETA: 6s - loss: 1.5449 - accuracy: 0.4667
 56/391 [===>..........................] - ETA: 5s - loss: 1.5292 - accuracy: 0.4720
 63/391 [===>..........................] - ETA: 5s - loss: 1.5350 - accuracy: 0.4695
 70/391 [====>.........................] - ETA: 4s - loss: 1.5364 - accuracy: 0.4675
 74/391 [====>.........................] - ETA: 4s - loss: 1.5332 - accuracy: 0.4670
 80/391 [=====>........................] - ETA: 4s - loss: 1.5360 - accuracy: 0.4647
 87/391 [=====>........................] - ETA: 4s - loss: 1.5367 - accuracy: 0.4652
 90/391 [=====>........................] - ETA: 4s - loss: 1.5403 - accuracy: 0.4628
 97/391 [======>.......................] - ETA: 4s - loss: 1.5380 - accuracy: 0.4631
103/391 [======>.......................] - ETA: 4s - loss: 1.5347 - accuracy: 0.4634
109/391 [=======>......................] - ETA: 3s - loss: 1.5362 - accuracy: 0.4627
115/391 [=======>......................] - ETA: 3s - loss: 1.5336 - accuracy: 0.4641
121/391 [========>.....................] - ETA: 3s - loss: 1.5330 - accuracy: 0.4635
127/391 [========>.....................] - ETA: 3s - loss: 1.5344 - accuracy: 0.4629
134/391 [=========>....................] - ETA: 3s - loss: 1.5289 - accuracy: 0.4645
140/391 [=========>....................] - ETA: 3s - loss: 1.5322 - accuracy: 0.4636
144/391 [==========>...................] - ETA: 3s - loss: 1.5334 - accuracy: 0.4629
151/391 [==========>...................] - ETA: 3s - loss: 1.5296 - accuracy: 0.4641
158/391 [===========>..................] - ETA: 2s - loss: 1.5298 - accuracy: 0.4640
165/391 [===========>..................] - ETA: 2s - loss: 1.5271 - accuracy: 0.4651
171/391 [============>.................] - ETA: 2s - loss: 1.5257 - accuracy: 0.4651
178/391 [============>.................] - ETA: 2s - loss: 1.5235 - accuracy: 0.4656
185/391 [=============>................] - ETA: 2s - loss: 1.5264 - accuracy: 0.4645
192/391 [=============>................] - ETA: 2s - loss: 1.5257 - accuracy: 0.4648
199/391 [==============>...............] - ETA: 2s - loss: 1.5247 - accuracy: 0.4656
206/391 [==============>...............] - ETA: 2s - loss: 1.5228 - accuracy: 0.4654
213/391 [===============>..............] - ETA: 2s - loss: 1.5206 - accuracy: 0.4663
219/391 [===============>..............] - ETA: 1s - loss: 1.5203 - accuracy: 0.4661
226/391 [================>.............] - ETA: 1s - loss: 1.5170 - accuracy: 0.4675
233/391 [================>.............] - ETA: 1s - loss: 1.5160 - accuracy: 0.4676
240/391 [=================>............] - ETA: 1s - loss: 1.5159 - accuracy: 0.4675
246/391 [=================>............] - ETA: 1s - loss: 1.5172 - accuracy: 0.4665
253/391 [==================>...........] - ETA: 1s - loss: 1.5182 - accuracy: 0.4661
260/391 [==================>...........] - ETA: 1s - loss: 1.5181 - accuracy: 0.4660
267/391 [===================>..........] - ETA: 1s - loss: 1.5182 - accuracy: 0.4663
274/391 [====================>.........] - ETA: 1s - loss: 1.5175 - accuracy: 0.4664
281/391 [====================>.........] - ETA: 1s - loss: 1.5178 - accuracy: 0.4667
287/391 [=====================>........] - ETA: 1s - loss: 1.5174 - accuracy: 0.4664
293/391 [=====================>........] - ETA: 1s - loss: 1.5165 - accuracy: 0.4668
296/391 [=====================>........] - ETA: 1s - loss: 1.5163 - accuracy: 0.4669
302/391 [======================>.......] - ETA: 0s - loss: 1.5156 - accuracy: 0.4667
305/391 [======================>.......] - ETA: 0s - loss: 1.5164 - accuracy: 0.4662
311/391 [======================>.......] - ETA: 0s - loss: 1.5154 - accuracy: 0.4665
315/391 [=======================>......] - ETA: 0s - loss: 1.5150 - accuracy: 0.4664
325/391 [=======================>......] - ETA: 0s - loss: 1.5123 - accuracy: 0.4670
331/391 [========================>.....] - ETA: 0s - loss: 1.5127 - accuracy: 0.4668
334/391 [========================>.....] - ETA: 0s - loss: 1.5125 - accuracy: 0.4669
344/391 [=========================>....] - ETA: 0s - loss: 1.5117 - accuracy: 0.4673
356/391 [==========================>...] - ETA: 0s - loss: 1.5100 - accuracy: 0.4684
369/391 [===========================>..] - ETA: 0s - loss: 1.5086 - accuracy: 0.4687
383/391 [============================>.] - ETA: 0s - loss: 1.5074 - accuracy: 0.4697

此外,可以通过改变网络每层的输入和输出的维度来改变网络的收敛速度。

posted on 2019-12-25 15:42  Luaser  阅读(669)  评论(0编辑  收藏  举报