tensorflow的卷积和池化层(二):记实践之cifar10

tensorflow中的卷积和池化层(一)各种卷积类型Convolution这两篇博客中,主要讲解了卷积神经网络的核心层,同时也结合当下流行的Caffe和tf框架做了介绍,本篇博客将接着tensorflow中的卷积和池化层(一)的内容,继续介绍tf框架中卷积神经网络CNN的使用。

因此,接下来将介绍CNN的入门级教程cifar10\100项目。cifar10\100 数据集是由Alex Krizhevsky、Vinod Nair和Geoffrey Hinton收集的,这两个数据集都是从8000万的数据集中挑选出来的。因此构成它们本身的图片是很相似的,而区别在于:

  • cifar-10是由60000张表示10类物体的32*32大小的彩色图片构成,顾名思义,每类刚好6000张,类间数据平衡,而且5000张用于训练,1000张用于测试和验证,那么这个数据集就总共有50000张训练图片,10000张测试图片。那包含的10类如下:airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck。

  • cifar-100是由60000张表示100类物体的32*32大小的彩色图片构成,顾名思义,每类刚好600张,类内数据平衡,而且500张用于训练,100张用于测试和验证,那么这个数据集就总共有50000张训练图片,10000张测试图片。包含的类别可查看官网。

官网地址如下:The CIFAR-10 and CIFAR-100 dataset  里面给我们提供了3种读取数据集的方式。

不管学习何种框架,cifar-10\100都是入门级的经典CNN项目,之所以称之为项目,则是由上述数据集演化的依赖各种不同的工具解决的各类工程问题,此处就是10类或100类的分类问题,因此,这个项目值得大家共同学习。而本篇博客就介绍在tf工具下实现的cifar-10\100项目。

cifar-10的代码结构如下表所示,总共有5个文件,它们的作用如下:

文件作用
cifar10_input.py 读取本地CIFAR-10的二进制文件格式的内容。
cifar10.py 建立CIFAR-10的模型。
cifar10_train.py 在CPU或GPU上训练CIFAR-10的模型。
cifar10_multi_gpu_train.py 在多GPU上训练CIFAR-10的模型。
cifar10_eval.py 评估CIFAR-10模型的预测性能。

给出上述5个文件的百度云地址:cifar-10项目链接   密码: f9ge

由于上述文件中有着详细的注释,因此下面对这些文件需要进一步理解的地方具体说明。

  • cifar10_input.py

这个文件是用来读取官方cifar10\100数据集的,并且是三种数据集方式中的二进制文件格式,也就是一系列data_batch_num.bin(num=1...5)文件和test_batch.bin文件,每一个文件的组织形式都是这样的:

<1 x label><3072 x pixel>
...
<1 x label><3072 x pixel>

那也就是说,每一个batch中每一行都记录着一张图片,第一个字节是这个图片的label,应该是在0-9范围内的整形变量;而另外3072个字节则表示3个通道的像素值,即3*32*32,应该是按照RGB的顺序排列着,即1024R+1024G+1024B。
除此之外,还有另外一个文件,batches.meta.txt,顾名思义,光有label不行,还需要知道每个label代表什么,这个文件里按行存放着每个整形label的表示,并且两者是一一对应的。
这个文件有4个函数:read_cifar10、_generate_image_and_label_batch、distorted_inputs、inputs。
第一个函数read_cifar10的目的在于读取一行data_batch_num.bin中的内容,即读取一张图片,并且获取这张图片的label和按照[height,width,channel]维度组织的像素值。
这里有4个重要的函数:tf.FixedLengthRecordReader、tf.slice、tf.reshape、tf.transpose。其中,
tf.FixedLengthRecordReader是专门用来读取固定长度字节数的二进制文件阅读器;tf.slice函数是tf的切片操作,函数原型如下:
tf.slice(inputs,begin,size,name='')
在begin的位置从inputs上抽取size大小的内容,name有默认值,例如:
tf.slice(record_bytes, [0], [label_bytes])
tf.slice(record_bytes, [label_bytes], [image_bytes])
tf.reshape就是将一个tensor的维度重组,函数原型如下:
tf.reshape(tensor,shape,name=None)

将原来的tensor按照shape的样子重新组织成为新的tensor。例如:

tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), [result.depth, result.height, result.width])

这样就可以把slice得到的那一串字符重新组织成[result.depth, result.height, result.width]大小和维度的tensor了。

tf.transpose就是将一个tensor的维度顺序进行交换,函数原型如下:
tf.transpose(a, perm=None, name='transpose')

将tensor a按照perm的顺序交换变成新的tensor,例如:

 result.uint8image = tf.transpose(depth_major, [1, 2, 0])

这就把[result.depth, result.height, result.width]变成了[result.height, result.width, result.depth]。

第二个函数_generate_image_and_label_batch的目的在于构建一个batch的图片和相应的label。这里有一个非常重要的函数tf.train.shuffle_batch,例如:

images, label_batch = tf.train.shuffle_batch(
      [image, label],
      batch_size=batch_size,
      num_threads=num_preprocess_threads,
      capacity=min_queue_examples + 3 * batch_size,
      min_after_dequeue=min_queue_examples)

这样理解:这个函数构建了一个capacity大小的队列,然后呢,在capacity内随机打乱这些图片,每次从中取大小为batch_size的图片数目出列,同时又添加一部分的图片数目入列,整个过程中在队列里的图片数目不能少于min_after_dequeue个,以保证进出之间达到很好的打乱效果。就这样,这个函数返回一个batch的image和对应的label。

第三个函数distorted_inputs和第四个函数inputs的目的分别是对训练和测试的数据集进行数据增强和预处理,包括crop、Flip、random_brightness、random_contrast、Whitening等等。最后返回的就是经过这些处理后的数据作为模型真正的输入。这里面的函数看看代码即可。

  • cifar10.py

这个文件主要是用来定义cifar10\100模型的,同时定义loss函数,以便训练。这个文件内共定义了10个函数,分别是_activation_summary、_variable_on_cpu、_variable_with_weight_decay、 distorted_inputs、inputs、inference、loss、_add_loss_summaries、train、maybe_download_and_extract。除此之外就是一些超参数的配置,比如batch_size=128,初始化学习率0.1,学习率的衰减因子0.1,学习速率开始下降的周期数350,移动平均衰减量0.9999等等。

函数的功能如下:

_activation_summary函数为激活函数添加summary,方便在tensorboard中可视化相关节点传输的数据。主要是tf.histogram_summary和tf.scalar_summary,这将在后续介绍。

_variable_on_cpu函数的目的是在CPU上创建变量。

_variable_with_weight_decay函数的目的是为了利用高斯分布初始化变量并且需要时添加权重衰减因子weight_decay。

distorted_inputs函数的目的是在cifar10_input.py的基础上构建训练数据集,得到经过数据增强和预处理之后模型的输入数据。

inputs函数的目的是在cifar10_input.py的基础上构建测试数据集,得到经过数据增强和预处理之后模型测试的输入数据。

inference函数的目的是构建CNN网络。

loss函数的目的是定义模型的loss。

 _add_loss_summaries函数的目的是给loss添加summary,以便于可视化。

train函数的目的是训练cifar10模型,代价函数等。

maybe_download_and_extract函数的目的是从指定的网站下载cifar10数据集并解压。

 运行上述训练代码cifar10_train.py,也不是一帆风顺的,错误和修改办法如下:

1. AttributeError: module 'tensorflow.python.ops.image_ops' has no attribute 'random_crop'。

这个错误来自于cifar10_input.py文件中的distorted_image = tf.image.random_crop(reshaped_image, [height, width]),将此句修改为:

distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

2. AttributeError: module 'tensorflow.python.ops.image_ops' has no attribute 'per_image_whitening'。

这个错误来自于cifar10_input.py文件中的 float_image = tf.image.per_image_whitening(distorted_image),将此句修改为:

 float_image = tf.image.per_image_standardization(distorted_image)

3. AttributeError: module 'tensorflow' has no attribute 'image_summary'。

这个错误来自于cifar10_input.py文件中的tf.image_summary('images', images),将此句修改为:

tf.summary.image('images', images)
 tf.summary.scalar('learning_rate', lr)#cifar10.py

注:整个项目类似的地方做修改。

4. AttributeError: module 'tensorflow' has no attribute 'histogram_summary'。

这个错误来自于cifar10.py文件中的 tf.histogram_summary(tensor_name + '/activations', x),将此句修改为:

tf.summary.histogram(tensor_name + '/activations', x)
tf.summary.histogram(var.op.name, var)

注:整个项目类似的地方做修改。

5. AttributeError: module 'tensorflow' has no attribute 'scalar_summary'。

这个错误来自于cifar10.py文件中的 tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x)),将此句修改为:

tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))

6. AttributeError: module 'tensorflow' has no attribute 'mul'。

这个错误之前有说过,改为multiply即可。

7. ValueError: Tried to convert 'tensor' to a tensor and failed. Error: Argument must be a dense tensor: range(0, 128) - got shape [128], but wanted []。

这个错误定位在cifar10.py的这句代码上:

 indices = tf.reshape(range(FLAGS.batch_size), [FLAGS.batch_size, 1])

将此句改为:

 indices = tf.reshape(list(range(FLAGS.batch_size)), [FLAGS.batch_size, 1])

8. ValueError: Shapes (2, 128, 1) and () are incompatible。

这个错误是cifar10.py中的 concated = tf.concat(1, [indices, sparse_labels])触发的,此句修改为:

 concated = tf.concat([indices, sparse_labels], 1)

9. ValueError: Only call `softmax_cross_entropy_with_logits` with named arguments (labels=..., logits=..., ...)。

这个错误来自于softmax_cross_entropy_with_logits这个函数,新的tf版本更新了这个函数,函数原型:

tf.nn.softmax_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, dim=-1, name=None)

将其修改为:

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
      logits=logits, labels=dense_labels, name='cross_entropy_per_example')

10. TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

提示了,应该将 if grad: 修改为 if grad is not None。

11. AttributeError: module 'tensorflow' has no attribute 'merge_all_summaries'。

此处错误来自于cifar10_train.py中的 summary_op = tf.merge_all_summaries(),将其修改为:

 summary_op = tf.summary.merge_all()

12. AttributeError: module 'tensorflow.python.training.training' has no attribute 'SummaryWriter'。

这个错误来自于cifar10_train.py的

summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

将其修改为:

summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

13. WARNING:tensorflow:Passing a `GraphDef` to the SummaryWriter is deprecated. Pass a `Graph` object instead, such as `sess.graph`.

这个警告来自于cifar10_train.py的

summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

将其修改为:

summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                            sess.graph)

到这里为止,才能成功的运行这个例子,上述错误都是由于tf高低版本不兼容导致的,代码本身没有问题,只是在高版本的tf做了修改,比如本机的版本是1.7.0,某些低版本应该没有问题。

以下是运行的log

2018-05-05 13:13:40.782676: step 0, loss = 4.68 (0.1 examples/sec; 918.057 sec/batch)
2018-05-05 13:14:37.995829: step 10, loss = 4.66 (19.7 examples/sec; 6.510 sec/batch)
2018-05-05 13:15:38.571923: step 20, loss = 4.64 (19.3 examples/sec; 6.618 sec/batch)
2018-05-05 13:16:37.660061: step 30, loss = 4.62 (20.4 examples/sec; 6.260 sec/batch)
2018-05-05 13:17:35.194066: step 40, loss = 4.60 (22.7 examples/sec; 5.639 sec/batch)
2018-05-05 13:18:36.177244: step 50, loss = 4.58 (22.5 examples/sec; 5.699 sec/batch)
2018-05-05 13:19:37.775057: step 60, loss = 4.57 (20.9 examples/sec; 6.122 sec/batch)
2018-05-05 13:20:38.255898: step 70, loss = 4.55 (21.0 examples/sec; 6.081 sec/batch)
2018-05-05 13:21:39.074639: step 80, loss = 4.53 (18.5 examples/sec; 6.929 sec/batch)
2018-05-05 13:22:42.469230: step 90, loss = 4.51 (21.9 examples/sec; 5.858 sec/batch)
2018-05-05 13:23:43.102476: step 100, loss = 4.50 (20.5 examples/sec; 6.236 sec/batch)
2018-05-05 13:24:53.920811: step 110, loss = 4.48 (19.1 examples/sec; 6.708 sec/batch)
2018-05-05 13:25:55.722164: step 120, loss = 4.46 (21.0 examples/sec; 6.097 sec/batch)
2018-05-05 13:26:58.607399: step 130, loss = 4.44 (20.8 examples/sec; 6.153 sec/batch)
2018-05-05 13:27:56.598621: step 140, loss = 4.42 (19.4 examples/sec; 6.589 sec/batch)
2018-05-05 13:28:57.043367: step 150, loss = 4.41 (20.9 examples/sec; 6.117 sec/batch)
2018-05-05 13:30:00.026865: step 160, loss = 4.39 (19.3 examples/sec; 6.640 sec/batch)
2018-05-05 13:30:57.701242: step 170, loss = 4.38 (19.2 examples/sec; 6.677 sec/batch)
2018-05-05 13:31:54.940464: step 180, loss = 4.36 (24.6 examples/sec; 5.210 sec/batch)
2018-05-05 13:32:54.969103: step 190, loss = 4.34 (20.9 examples/sec; 6.119 sec/batch)
2018-05-05 13:33:57.856344: step 200, loss = 4.32 (20.2 examples/sec; 6.340 sec/batch)
2018-05-05 13:35:03.489890: step 210, loss = 4.31 (21.5 examples/sec; 5.966 sec/batch)

上面括号里面的两个数字表示的是每秒跑了多少张图片和多少秒跑了一个batch,两者相乘约等于一个batch的图片数目128。可以从代码中看出来,如下代码所示:

if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

由于电脑配置低,跑起来太慢了,20000次,很耗时,所以最终的结果就不贴了。

 

posted on 2018-05-05 13:48  greathuman  阅读(6717)  评论(0编辑  收藏  举报