一起学TensorFlow---搭建最简单的全连接网络实现手写数字识别(MNIST)

刚开始学Tensorflow,这里记录学习中的点点滴滴,希望能和大家共同进步。

Cuda和Tensorflow的安装请参考上一篇博客:http://www.cnblogs.com/roboai/p/7768191.html

Tensorflow简单介绍

  我们知道,一维的数据可以用数组表示,二维可以用矩阵表示,那么三维或三维以上呢?比如图像,实际上就是一个三维数据[h,w,c],高、宽、通道数,对于灰度图来说,通道数为1,而对于彩色图像,通道数为3。对于这种三维或三维以上的数据,我们称之为张量(tensor),所以顾名思义,Tensorflow的意思就是张量的流动,Tensorflow将数据打包成一个个张量,由四个维度构成,分别是[batch, height, width, channels],然后在各个节点之间传递。

  节点是Tensorflow里另一重要的概念,对张量的操作称之为节点,一系列的节点构成图。接触过Caffe的朋友可能发现了,这和Caffe里的blob、layer、net是一致的。不同的是,我们需要启动一个会话来计算图,这是Tensorflow的内在机制所决定的。Tensorflow依赖于一个高效的C++后端来进行计算,与后端的这个连接叫做session。一般而言,使用TensorFlow程序的流程是先创建一个图,然后在session中启动它。其思想是先让我们描述一个交互操作图,然后完全将其运行在Python外部。这样做的目的是为了避免频繁切换Python环境和外部环境时需要的开销。如果你想在GPU或者分布式环境中计算时,这一开销会非常可怖,这一开销主要可能是用来进行数据迁移,并不能对计算做出贡献。

  我们构建一个简单的图来说明以上过程,改图包含三个节点(两个源节点和一个矩阵乘法节点),然后启动一个会话计算图得到输出结果,最后需要关闭会话。当然也可以使用with代码块实现自动关闭,效果是一样的。

# coding=utf-8
import tensorflow as tf

# 该图包含3个节点(两个源节点和乘法节点)
matrix1 = tf.constant([[3, 3]])
matrix2 = tf.constant([[2], [2]])
product = tf.matmul(matrix1, matrix2)

# 调用会话启动图
sess = tf.Session()
result = sess.run(product)

# 输出结果并关闭会话
print result
sess.close()

# 使用“with”代码块自动关闭, 该方法更简洁
with tf.Session() as sess:
    result = sess.run(product)
    print result

输出结果为

[[12]]
[[12]]

MNIST数据集

  MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片,也包含每一张图片对应的标签,告诉我们这个是数字几。新建一个get.sh文件,写入以下内容,执行该文件就可以下载该数据集。下载下来的数据集被分成两部分,60000行的训练数据集和10000行的测试数据集。每一张图片包含28X28个像素点,我们可以把图片展开成一个向量,长度是 28x28 = 784。

#!/usr/bin/env sh
# This scripts downloads the mnist data and unzips it.

DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"

echo "Downloading..."

for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
do
    if [ ! -e $fname ]; then
        wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz
    fi
done

Softmax Regression与Cross Entropy

  在本文中,我们将采用最简单的网络来预测输入图片中的数字,整个网络仅由一个Softmax Regression构成,数学模型可以写作\(y=softmax(Wx+b)\)。假设\(y'\)是实际分布,\(y\)是预测分布,Cross Entropy的定义是\(loss=\sum{y'\log{y}}\)。关于Softmax Regression的反向传递及Cross Entropy的物理含义请参考以下两篇博客,这里就不展开写了。

  http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92  

  http://blog.csdn.net/rtygbwwwerr/article/details/50778098

全连接网络实现手写数字识别

  下面终于进入正题了,我们有了数据集,同时也了解了算法流程,剩下的就是写代码实现了。首先是导入包,由于Tensorflow帮我们写了一部分数据读写的程序,我们这里就直接用了。

# coding=utf-8
import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf

# 导入数据, 强烈建议预先下载
mnist = input_data.read_data_sets("data/", one_hot=True)

  这里数据可以用我前面给出的get.sh下载,然后放入data文件夹目录下,我之前是直接用input_data.read_data_sets("data/", one_hot=True)下载的,结果半天下载不下来,所以这里还是建议预先下载吧,用get.sh下载比较快。然后是程序的主要部分。

# 训练集占位符:28*28=784
x = tf.placeholder(tf.float32, [None, 784])
# 初始化参数
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# 输出结果
y = tf.nn.softmax(tf.matmul(x, W) + b)
# 真实值
y_ = tf.placeholder(tf.float32, [None, 10])
# 计算交叉熵
crossEntropy = -tf.reduce_sum(y_*tf.log(y))
# 训练策略
trainStep = tf.train.GradientDescentOptimizer(0.01).minimize(crossEntropy)
# 初始化参数值
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# 开始训练:循环训练1000次
for i in range(1000):
    batchXs, batchYs = mnist.train.next_batch(100)
    sess.run(trainStep, feed_dict={x: batchXs, y_: batchYs})

# 评估模型
correctPrediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correctPrediction, tf.float32))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

  这里用的是占位符的方式传入数据,占位符的尺寸为[None, 784],这里的None表示此张量的第一个维度可以是任何长度的。

  权重值W和偏置量b使用Variable来表示,一个Variable代表一个可修改的张量,存在在Tensorflow的用于描述交互性操作的图中。它们可以用于计算输入值,也可以在计算中被修改。对于各种机器学习应用,一般都会有模型参数,都可以用Variable表示。在这里,我们都用全为零的张量来初始化Wb。

  只需要一行代码就可以实现我们的模型y = tf.nn.softmax(tf.matmul(x, W) + b),同样损失函数也只需要一行代码crossEntropy = -tf.reduce_sum(y_*tf.log(y))。

  以0.01的学习速率,采用梯度下降法最小化交叉熵,对应的代码为trainStep = tf.train.GradientDescentOptimizer(0.01).minimize(crossEntropy)。

  然后初始化参数并训练,定义训练次数为1000,每次随机地选取100图像进行计算。

  最后对得到的模型使用测试数据进行评估,评估结果表明精度达到0.9148(每次都不一样,在91%左右徘徊)。

  至此,我们采用最简单的一个全连接网络实现了一个手写数字识别的网络,剩下的工作是将这个网络及参数保存,采用自己的图片进行识别,进一步感受这个网络的效果,这一部分将在后续的工作中进行。同时我们可以说这个网络过于简单了,91%的识别效果也远远达不到我们的需求,如何进一步提高网络的精度是我们关注的重点。

 


关于会话

  会话(session)提供在图中执行操作的一些方法。一般的模式是:

  1.  建立会话,此时会生成一张空图;
  2.  在会话中添加节点和边,形成一张图;
  3.  执行图

  在调用Session对象的run()方法来执行图时,传入一些Tensor,这个过程叫填充(feed);返回的结果类型根据输入的类型而定,这个过程叫取回(fetch)。

  会话是图交互的桥梁,一个会话可以有多个图,会话可以修改图的结构,也可以往图中注入数据进行计算。因此,会话主要由两个API接口--Extend和Run。Extend操作是在Graph中添加节点和边,Run操作是输入计算的节点和填充必要的数据后,进行计算,并输出运算结果。

关于节点与图

  图中的节点又称为算子,它代表一个操作(Operation,op),一般用来表示施加的数学运算,也可以表示数据输入(feed in)的起点以及输出(push out)的终点,或者是读取/写入持久变量(persistent variable)的终点。

  如果不显式添加一个默认图,系统会自动设置一个全局的默认图。所设置的默认图,在模块范围内定义的节点都将默认加入默认图中。

关于可视化

  可视化时,需要在程序中给必要的节点添加摘要(summary),摘要会收集该节点的数据,并标记上第几步、时间戳等标识,写入事件文件(event file)中。

模型存储与加载

  TensorFLow的API提供了两种方式存储和加载模型:

  (1)生成检查点文件,拓展名一般为.ckpt,通过tf.train.Saver.save()生成。它包含权重和程序中定义的变量,不包含图结构。如果需要在另一个程序中使用,需要重新构建图结构,并告诉TensorFlow如何处理这些权重。

  (2)生成图协议文件,这是一个二进制文件,拓展名一般为.pb,用tf.train.write_graph()保存,只包含图形结构,不包含权重,然后使用tf.import_graph_def()来加载图形。

模型训练之Momentum

  Momentum是模拟物理学中的动量的概念,更新时在一定程度上保留之前的更新方向,利用当前的批次再微调本次的更新参数,因此引入了一个新的变量v(速度),作为前几次梯度的累加。因此,Momentum能够改善训练过程,在下降初期,前后梯度一致时,能够加速学习;在下降的中后期,在局部最小值附近来回震荡时,能够抑制震荡,加快收敛。

 

posted on 2017-11-06 20:15  萝卜丶爱  阅读(1956)  评论(0编辑  收藏  举报

导航