[Tensorflow] RNN - 03. MultiRNNCell for Digit Prediction

Ref: http://blog.csdn.net/u014595019/article/details/52759104

 

Time: 2min

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(55000, 784)
Iter0, step 5, training accuracy 0.257812
Iter0, step 10, training accuracy 0.320312
Iter0, step 15, training accuracy 0.523438
Iter0, step 20, training accuracy 0.554688
Iter0, step 25, training accuracy 0.515625
Iter0, step 30, training accuracy 0.484375
Iter0, step 35, training accuracy 0.554688
Iter0, step 40, training accuracy 0.679688
Iter0, step 45, training accuracy 0.71875
Iter0, step 50, training accuracy 0.742188
Iter0, step 55, training accuracy 0.671875
Iter0, step 60, training accuracy 0.742188
Iter0, step 65, training accuracy 0.75
Iter0, step 70, training accuracy 0.742188
Iter0, step 75, training accuracy 0.804688
Iter0, step 80, training accuracy 0.789062
Iter0, step 85, training accuracy 0.875
Iter0, step 90, training accuracy 0.859375
Iter0, step 95, training accuracy 0.875
Iter0, step 100, training accuracy 0.835938
Iter0, step 105, training accuracy 0.84375
Iter0, step 110, training accuracy 0.859375
Iter0, step 115, training accuracy 0.867188
Iter0, step 120, training accuracy 0.875
Iter0, step 125, training accuracy 0.875
Iter0, step 130, training accuracy 0.898438
Iter0, step 135, training accuracy 0.90625
Iter0, step 140, training accuracy 0.875
Iter0, step 145, training accuracy 0.84375
Iter0, step 150, training accuracy 0.90625
Iter0, step 155, training accuracy 0.90625
Iter0, step 160, training accuracy 0.914062
Iter0, step 165, training accuracy 0.914062
Iter0, step 170, training accuracy 0.828125
Iter0, step 175, training accuracy 0.914062
Iter0, step 180, training accuracy 0.898438
Iter0, step 185, training accuracy 0.921875
Iter0, step 190, training accuracy 0.914062
Iter0, step 195, training accuracy 0.929688
Iter0, step 200, training accuracy 0.914062
Iter0, step 205, training accuracy 0.921875
Iter0, step 210, training accuracy 0.929688
Iter0, step 215, training accuracy 0.929688
Iter0, step 220, training accuracy 0.929688
Iter0, step 225, training accuracy 0.914062
Iter0, step 230, training accuracy 0.914062
Iter0, step 235, training accuracy 0.929688
Iter0, step 240, training accuracy 0.945312
Iter0, step 245, training accuracy 0.914062
Iter0, step 250, training accuracy 0.9375
Iter0, step 255, training accuracy 0.953125
Iter0, step 260, training accuracy 0.953125
Iter0, step 265, training accuracy 0.953125
Iter0, step 270, training accuracy 0.921875
Iter0, step 275, training accuracy 0.929688
Iter0, step 280, training accuracy 0.890625
Iter0, step 285, training accuracy 0.945312
Iter0, step 290, training accuracy 0.929688
Iter0, step 295, training accuracy 0.945312
Iter0, step 300, training accuracy 0.914062
Iter0, step 305, training accuracy 0.929688
Iter0, step 310, training accuracy 0.929688
Iter0, step 315, training accuracy 0.945312
Iter0, step 320, training accuracy 0.960938
Iter0, step 325, training accuracy 0.914062
Iter0, step 330, training accuracy 0.945312
Iter0, step 335, training accuracy 0.921875
Iter0, step 340, training accuracy 0.929688
Iter0, step 345, training accuracy 0.921875
Iter0, step 350, training accuracy 0.9375
Iter0, step 355, training accuracy 0.953125
Iter0, step 360, training accuracy 0.953125
Iter0, step 365, training accuracy 0.9375
Iter0, step 370, training accuracy 0.953125
Iter0, step 375, training accuracy 0.953125
Iter0, step 380, training accuracy 0.9375
Iter0, step 385, training accuracy 0.945312
Iter0, step 390, training accuracy 0.960938
Iter0, step 395, training accuracy 0.921875
Iter0, step 400, training accuracy 0.960938
Iter0, step 405, training accuracy 0.960938
Iter0, step 410, training accuracy 0.96875
Iter0, step 415, training accuracy 0.96875
Iter0, step 420, training accuracy 0.945312
Iter0, step 425, training accuracy 0.921875
Iter1, step 430, training accuracy 0.953125
Iter1, step 435, training accuracy 0.984375
Iter1, step 440, training accuracy 0.921875
Iter1, step 445, training accuracy 0.976562
Iter1, step 450, training accuracy 0.945312
Iter1, step 455, training accuracy 0.976562
Iter1, step 460, training accuracy 0.921875
Iter1, step 465, training accuracy 0.976562
Iter1, step 470, training accuracy 0.945312
Iter1, step 475, training accuracy 0.960938
Iter1, step 480, training accuracy 0.976562
Iter1, step 485, training accuracy 0.945312
Iter1, step 490, training accuracy 0.976562
Iter1, step 495, training accuracy 0.96875
Iter1, step 500, training accuracy 0.992188
Iter1, step 505, training accuracy 0.953125
Iter1, step 510, training accuracy 0.960938
Iter1, step 515, training accuracy 0.9375
Iter1, step 520, training accuracy 0.945312
Iter1, step 525, training accuracy 0.945312
Iter1, step 530, training accuracy 0.96875
Iter1, step 535, training accuracy 0.976562
Iter1, step 540, training accuracy 0.929688
Iter1, step 545, training accuracy 0.976562
Iter1, step 550, training accuracy 0.96875
Iter1, step 555, training accuracy 0.945312
Iter1, step 560, training accuracy 0.984375
Iter1, step 565, training accuracy 0.921875
Iter1, step 570, training accuracy 0.945312
Iter1, step 575, training accuracy 0.96875
Iter1, step 580, training accuracy 0.953125
Iter1, step 585, training accuracy 0.953125
Iter1, step 590, training accuracy 0.945312
Iter1, step 595, training accuracy 0.945312
Iter1, step 600, training accuracy 0.984375
Iter1, step 605, training accuracy 0.9375
Iter1, step 610, training accuracy 0.953125
Iter1, step 615, training accuracy 0.960938
Iter1, step 620, training accuracy 0.976562
Iter1, step 625, training accuracy 0.96875
Iter1, step 630, training accuracy 0.953125
Iter1, step 635, training accuracy 0.992188
Iter1, step 640, training accuracy 0.929688
Iter1, step 645, training accuracy 0.960938
Iter1, step 650, training accuracy 0.984375
Iter1, step 655, training accuracy 0.953125
Iter1, step 660, training accuracy 0.960938
Iter1, step 665, training accuracy 0.984375
Iter1, step 670, training accuracy 0.953125
Iter1, step 675, training accuracy 0.96875
Iter1, step 680, training accuracy 0.984375
Iter1, step 685, training accuracy 0.976562
Iter1, step 690, training accuracy 0.992188
Iter1, step 695, training accuracy 0.96875
Iter1, step 700, training accuracy 0.953125
Iter1, step 705, training accuracy 0.960938
Iter1, step 710, training accuracy 0.960938
Iter1, step 715, training accuracy 0.929688
Iter1, step 720, training accuracy 0.976562
Iter1, step 725, training accuracy 0.96875
Iter1, step 730, training accuracy 0.960938
Iter1, step 735, training accuracy 0.976562
Iter1, step 740, training accuracy 0.984375
Iter1, step 745, training accuracy 0.976562
Iter1, step 750, training accuracy 0.96875
Iter1, step 755, training accuracy 0.960938
Iter1, step 760, training accuracy 0.945312
Iter1, step 765, training accuracy 0.96875
Iter1, step 770, training accuracy 0.953125
Iter1, step 775, training accuracy 0.921875
Iter1, step 780, training accuracy 0.96875
Iter1, step 785, training accuracy 0.96875
Iter1, step 790, training accuracy 0.96875
Iter1, step 795, training accuracy 0.960938
Iter1, step 800, training accuracy 0.976562
Iter1, step 805, training accuracy 0.96875
Iter1, step 810, training accuracy 0.984375
Iter1, step 815, training accuracy 0.96875
Iter1, step 820, training accuracy 0.976562
Iter1, step 825, training accuracy 0.984375
Iter1, step 830, training accuracy 0.976562
Iter1, step 835, training accuracy 0.992188
Iter1, step 840, training accuracy 0.976562
Iter1, step 845, training accuracy 0.960938
Iter1, step 850, training accuracy 0.992188
Iter1, step 855, training accuracy 0.960938
Iter2, step 860, training accuracy 0.953125
Iter2, step 865, training accuracy 0.96875
Iter2, step 870, training accuracy 0.976562
Iter2, step 875, training accuracy 0.96875
Iter2, step 880, training accuracy 0.984375
Iter2, step 885, training accuracy 0.960938
Iter2, step 890, training accuracy 0.960938
Iter2, step 895, training accuracy 0.976562
Iter2, step 900, training accuracy 0.984375
Iter2, step 905, training accuracy 1
Iter2, step 910, training accuracy 0.976562
Iter2, step 915, training accuracy 0.96875
Iter2, step 920, training accuracy 0.960938
Iter2, step 925, training accuracy 0.992188
Iter2, step 930, training accuracy 1
Iter2, step 935, training accuracy 0.984375
Iter2, step 940, training accuracy 0.96875
Iter2, step 945, training accuracy 0.976562
Iter2, step 950, training accuracy 0.976562
Iter2, step 955, training accuracy 0.976562
Iter2, step 960, training accuracy 0.984375
Iter2, step 965, training accuracy 0.976562
Iter2, step 970, training accuracy 0.960938
Iter2, step 975, training accuracy 0.984375
Iter2, step 980, training accuracy 0.976562
Iter2, step 985, training accuracy 0.953125
Iter2, step 990, training accuracy 0.960938
Iter2, step 995, training accuracy 0.992188
Iter2, step 1000, training accuracy 0.960938
Iter2, step 1005, training accuracy 1
Iter2, step 1010, training accuracy 0.96875
Iter2, step 1015, training accuracy 0.953125
Iter2, step 1020, training accuracy 0.984375
Iter2, step 1025, training accuracy 0.960938
Iter2, step 1030, training accuracy 0.96875
Iter2, step 1035, training accuracy 0.953125
Iter2, step 1040, training accuracy 0.984375
Iter2, step 1045, training accuracy 0.984375
Iter2, step 1050, training accuracy 0.976562
Iter2, step 1055, training accuracy 0.976562
Iter2, step 1060, training accuracy 0.96875
Iter2, step 1065, training accuracy 0.984375
Iter2, step 1070, training accuracy 1
Iter2, step 1075, training accuracy 0.976562
Iter2, step 1080, training accuracy 0.976562
Iter2, step 1085, training accuracy 0.984375
Iter2, step 1090, training accuracy 0.984375
Iter2, step 1095, training accuracy 0.96875
Iter2, step 1100, training accuracy 0.976562
Iter2, step 1105, training accuracy 0.960938
Iter2, step 1110, training accuracy 0.976562
Iter2, step 1115, training accuracy 0.96875
Iter2, step 1120, training accuracy 0.976562
Iter2, step 1125, training accuracy 0.992188
Iter2, step 1130, training accuracy 0.992188
Iter2, step 1135, training accuracy 0.992188
Iter2, step 1140, training accuracy 0.945312
Iter2, step 1145, training accuracy 0.984375
Iter2, step 1150, training accuracy 0.992188
Iter2, step 1155, training accuracy 0.984375
Iter2, step 1160, training accuracy 0.96875
Iter2, step 1165, training accuracy 0.96875
Iter2, step 1170, training accuracy 0.976562
Iter2, step 1175, training accuracy 0.960938
Iter2, step 1180, training accuracy 1
Iter2, step 1185, training accuracy 0.984375
Iter2, step 1190, training accuracy 0.992188
Iter2, step 1195, training accuracy 0.976562
Iter2, step 1200, training accuracy 0.96875
Iter2, step 1205, training accuracy 0.984375
Iter2, step 1210, training accuracy 0.976562
Iter2, step 1215, training accuracy 0.992188
Iter2, step 1220, training accuracy 0.992188
Iter2, step 1225, training accuracy 0.96875
Iter2, step 1230, training accuracy 0.992188
Iter2, step 1235, training accuracy 0.976562
Iter2, step 1240, training accuracy 0.976562
Iter2, step 1245, training accuracy 0.984375
Iter2, step 1250, training accuracy 0.960938
Iter2, step 1255, training accuracy 0.992188
Iter2, step 1260, training accuracy 0.984375
Iter2, step 1265, training accuracy 0.992188
Iter2, step 1270, training accuracy 0.992188
Iter2, step 1275, training accuracy 0.976562
Iter2, step 1280, training accuracy 0.976562
Iter2, step 1285, training accuracy 0.976562
Iter3, step 1290, training accuracy 0.984375
Iter3, step 1295, training accuracy 0.984375
Iter3, step 1300, training accuracy 1
Iter3, step 1305, training accuracy 0.992188
Iter3, step 1310, training accuracy 1
Iter3, step 1315, training accuracy 0.984375
Iter3, step 1320, training accuracy 0.992188
Iter3, step 1325, training accuracy 0.992188
Iter3, step 1330, training accuracy 0.960938
Iter3, step 1335, training accuracy 0.96875
Iter3, step 1340, training accuracy 0.976562
Iter3, step 1345, training accuracy 0.984375
Iter3, step 1350, training accuracy 0.984375
Iter3, step 1355, training accuracy 0.984375
Iter3, step 1360, training accuracy 0.976562
Iter3, step 1365, training accuracy 0.992188
Iter3, step 1370, training accuracy 0.96875
Iter3, step 1375, training accuracy 0.96875
Iter3, step 1380, training accuracy 0.976562
Iter3, step 1385, training accuracy 0.96875
Iter3, step 1390, training accuracy 0.992188
Iter3, step 1395, training accuracy 0.984375
Iter3, step 1400, training accuracy 0.976562
Iter3, step 1405, training accuracy 0.992188
Iter3, step 1410, training accuracy 0.992188
Iter3, step 1415, training accuracy 0.953125
Iter3, step 1420, training accuracy 0.984375
Iter3, step 1425, training accuracy 0.984375
Iter3, step 1430, training accuracy 0.984375
Iter3, step 1435, training accuracy 0.984375
Iter3, step 1440, training accuracy 0.96875
Iter3, step 1445, training accuracy 0.96875
Iter3, step 1450, training accuracy 0.984375
Iter3, step 1455, training accuracy 0.976562
Iter3, step 1460, training accuracy 0.984375
Iter3, step 1465, training accuracy 0.984375
Iter3, step 1470, training accuracy 1
Iter3, step 1475, training accuracy 0.984375
Iter3, step 1480, training accuracy 0.992188
Iter3, step 1485, training accuracy 0.992188
Iter3, step 1490, training accuracy 0.992188
Iter3, step 1495, training accuracy 0.976562
Iter3, step 1500, training accuracy 0.984375
Iter3, step 1505, training accuracy 1
Iter3, step 1510, training accuracy 0.984375
Iter3, step 1515, training accuracy 0.984375
Iter3, step 1520, training accuracy 0.992188
Iter3, step 1525, training accuracy 0.960938
Iter3, step 1530, training accuracy 0.984375
Iter3, step 1535, training accuracy 1
Iter3, step 1540, training accuracy 0.976562
Iter3, step 1545, training accuracy 0.984375
Iter3, step 1550, training accuracy 0.984375
Iter3, step 1555, training accuracy 0.992188
Iter3, step 1560, training accuracy 0.976562
Iter3, step 1565, training accuracy 0.984375
Iter3, step 1570, training accuracy 0.992188
Iter3, step 1575, training accuracy 0.976562
Iter3, step 1580, training accuracy 0.992188
Iter3, step 1585, training accuracy 0.992188
Iter3, step 1590, training accuracy 0.976562
Iter3, step 1595, training accuracy 0.992188
Iter3, step 1600, training accuracy 0.976562
Iter3, step 1605, training accuracy 0.96875
Iter3, step 1610, training accuracy 0.984375
Iter3, step 1615, training accuracy 0.984375
Iter3, step 1620, training accuracy 0.976562
Iter3, step 1625, training accuracy 0.976562
Iter3, step 1630, training accuracy 0.992188
Iter3, step 1635, training accuracy 0.976562
Iter3, step 1640, training accuracy 0.984375
Iter3, step 1645, training accuracy 1
Iter3, step 1650, training accuracy 0.992188
Iter3, step 1655, training accuracy 0.976562
Iter3, step 1660, training accuracy 0.984375
Iter3, step 1665, training accuracy 0.992188
Iter3, step 1670, training accuracy 0.96875
Iter3, step 1675, training accuracy 0.992188
Iter3, step 1680, training accuracy 0.976562
Iter3, step 1685, training accuracy 0.992188
Iter3, step 1690, training accuracy 1
Iter3, step 1695, training accuracy 0.960938
Iter3, step 1700, training accuracy 0.945312
Iter3, step 1705, training accuracy 0.976562
Iter3, step 1710, training accuracy 0.984375
Iter3, step 1715, training accuracy 0.992188
Iter4, step 1720, training accuracy 0.984375
Iter4, step 1725, training accuracy 0.992188
Iter4, step 1730, training accuracy 0.992188
Iter4, step 1735, training accuracy 1
Iter4, step 1740, training accuracy 1
Iter4, step 1745, training accuracy 0.984375
Iter4, step 1750, training accuracy 1
Iter4, step 1755, training accuracy 0.96875
Iter4, step 1760, training accuracy 0.96875
Iter4, step 1765, training accuracy 0.984375
Iter4, step 1770, training accuracy 0.992188
Iter4, step 1775, training accuracy 0.992188
Iter4, step 1780, training accuracy 0.976562
Iter4, step 1785, training accuracy 0.992188
Iter4, step 1790, training accuracy 1
Iter4, step 1795, training accuracy 0.992188
Iter4, step 1800, training accuracy 0.992188
Iter4, step 1805, training accuracy 0.976562
Iter4, step 1810, training accuracy 0.984375
Iter4, step 1815, training accuracy 0.976562
Iter4, step 1820, training accuracy 0.984375
Iter4, step 1825, training accuracy 0.992188
Iter4, step 1830, training accuracy 0.984375
Iter4, step 1835, training accuracy 0.992188
Iter4, step 1840, training accuracy 0.984375
Iter4, step 1845, training accuracy 0.976562
Iter4, step 1850, training accuracy 0.984375
Iter4, step 1855, training accuracy 0.984375
Iter4, step 1860, training accuracy 0.984375
Iter4, step 1865, training accuracy 0.96875
Iter4, step 1870, training accuracy 0.992188
Iter4, step 1875, training accuracy 0.976562
Iter4, step 1880, training accuracy 1
Iter4, step 1885, training accuracy 0.976562
Iter4, step 1890, training accuracy 0.976562
Iter4, step 1895, training accuracy 0.96875
Iter4, step 1900, training accuracy 0.984375
Iter4, step 1905, training accuracy 0.992188
Iter4, step 1910, training accuracy 0.992188
Iter4, step 1915, training accuracy 1
Iter4, step 1920, training accuracy 0.960938
Iter4, step 1925, training accuracy 0.984375
Iter4, step 1930, training accuracy 0.992188
Iter4, step 1935, training accuracy 0.984375
Iter4, step 1940, training accuracy 0.992188
Iter4, step 1945, training accuracy 1
Iter4, step 1950, training accuracy 1
Iter4, step 1955, training accuracy 0.984375
Iter4, step 1960, training accuracy 0.992188
Iter4, step 1965, training accuracy 0.984375
Iter4, step 1970, training accuracy 0.992188
Iter4, step 1975, training accuracy 0.984375
Iter4, step 1980, training accuracy 0.984375
Iter4, step 1985, training accuracy 1
Iter4, step 1990, training accuracy 0.984375
Iter4, step 1995, training accuracy 0.984375
Iter4, step 2000, training accuracy 0.96875
GPU usage log: (screenshot omitted)

 

Code analysis: 

# coding: utf-8

# **tensorflow version: 1.2.1**
# 
# Through this example you can learn how to implement a single-layer LSTM and a multi-layer LSTM, the format of the input and output data, and how to add a dropout layer to an RNN.
#
# From: https://github.com/yongyehuang/Tensorflow-Tutorial

# In[3]:


import tensorflow as tf
import numpy as np
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data

# Let GPU memory grow on demand
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# First load the data and take a look at its shape
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
print(mnist.train.images.shape)


#  ** I. First, set the hyperparameters used by the model **

# In[4]:

lr = 1e-3
input_size    = 28   # the input at each time step is 28-dimensional: we feed one image row per step, and each row has 28 pixels
timestep_size = 28   # the sequence length is 28: each prediction requires feeding in 28 rows
hidden_size   = 256  # width of the hidden layer (number of LSTM units)
layer_num     = 2    # number of stacked LSTM layers
class_num     = 10   # number of output classes; for regression this would be 1

_X = tf.placeholder(tf.float32, [None, 784])
y  = tf.placeholder(tf.float32, [None, class_num])
# We want to use different batch sizes for training and testing, so batch_size is a placeholder too
batch_size = tf.placeholder(tf.int32, [])       # note: the dtype must be tf.int32; e.g. batch_size = 128
keep_prob  = tf.placeholder(tf.float32, [])


#  ** II. Build the LSTM model (a plain RNN would be built the same way) **

# In[5]:

# Restore the 784-pixel character data back into a 28 * 28 image
# The next few steps are the key to implementing an RNN / LSTM
####################################################################
# # **Step 1: the RNN input shape = (batch_size, timestep_size, input_size)
X = tf.reshape(_X, [-1, 28, 28])

# # **Step 2: define one LSTM cell; you only need to specify hidden_size, and it automatically matches the dimension of the input X
# lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)

# # **Step 3: add a dropout layer; usually only output_keep_prob is set
# lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob)

# # **Step 4: call MultiRNNCell to build the multi-layer LSTM
# mlstm_cell = rnn.MultiRNNCell([lstm_cell] * layer_num, state_is_tuple=True)
# mlstm_cell = rnn.MultiRNNCell([lstm_cell for _ in range(layer_num)] , state_is_tuple=True)

# In tf 1.0.0 you could create the multi-layer lstm with the three steps above, but in tf 1.2.1 it should be created as follows
def lstm_cell():
    cell = rnn.LSTMCell(hidden_size, reuse=tf.get_variable_scope().reuse)
    return rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)

mlstm_cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(layer_num)], state_is_tuple=True)
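# A quick aside (my note, not from the original post): in TF >= 1.2, writing
# `[lstm_cell] * layer_num` reuses one cell object for every layer, so layer 2 would try to
# reuse layer 1's variables and fail once its input width (hidden_size) no longer matches
# layer 1's (input_size). Creating a fresh cell per layer, as above, avoids this; the same
# factory pattern also allows different widths per layer -- a sketch with hypothetical sizes:
def make_cell(num_units):
    cell = rnn.LSTMCell(num_units, reuse=tf.get_variable_scope().reuse)
    return rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)
# e.g. a 256-unit layer feeding a 128-unit layer:
# stacked_cell = rnn.MultiRNNCell([make_cell(n) for n in (256, 128)], state_is_tuple=True)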

# **Step 5: initialize the state with all zeros
init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)
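# Side note (my addition, not from the original post): for a MultiRNNCell built from LSTMCells,
# init_state is a tuple of layer_num LSTMStateTuple(c, h) pairs, one per layer, each of shape
# [batch_size, hidden_size].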

# **Step 6, option 1: call dynamic_rnn() to run the network we have built
# ** when time_major==False, outputs.shape = [batch_size, timestep_size, hidden_size]
# ** so we can take h_state = outputs[:, -1, :] as the final output
# ** state.shape = [layer_num, 2, batch_size, hidden_size],
# ** or, equivalently, take h_state = state[-1][1] as the final output
# ** the final output has shape [batch_size, hidden_size]
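# Side note (my illustration, not from the original post): a tiny self-contained check of the
# dynamic_rnn shapes described above, built in its own graph so it does not disturb the model here.
def _check_dynamic_rnn_shapes():
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float32, [None, 28, 28])
        stack = rnn.MultiRNNCell([rnn.LSTMCell(256) for _ in range(2)], state_is_tuple=True)
        outs, st = tf.nn.dynamic_rnn(stack, x, dtype=tf.float32, time_major=False)
        with tf.Session(graph=g) as s:
            s.run(tf.global_variables_initializer())
            o, last = s.run([outs, st], {x: np.zeros([4, 28, 28], np.float32)})
    print(o.shape)           # (4, 28, 256) == [batch_size, timestep_size, hidden_size]
    print(last[-1].h.shape)  # (4, 256): the top layer's h, same values as o[:, -1, :]
# _check_dynamic_rnn_shapes()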
# # outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)
# # h_state = state[-1][1]

# *************** To understand how the LSTM works more clearly, we implement Step 6 ourselves ***************
# If you look at the docs you will find that every RNNCell provides a __call__() method,
# which we can use to unroll the LSTM over the time steps by hand.

# **Step 6, option 2: unroll the computation step by step
outputs = list()
state = init_state
with tf.variable_scope('RNN'):
    for timestep in range(timestep_size):
        if timestep > 0:
            tf.get_variable_scope().reuse_variables()
        # here `state` holds the state of every LSTM layer
        (cell_output, state) = mlstm_cell(X[:, timestep, :], state)
        outputs.append(cell_output)
h_state = outputs[-1]


# ** III. Finally, set up the loss function and the optimizer, then train and test **

# In[ ]:

############################################################################
# The part below is essentially the same as the multi-layer CNN MNIST classifier written earlier.
# Note that the same batch_size must also be fed at test time.
# The LSTM output above is a [hidden_size] tensor; for classification we still need a softmax layer.
# First define the softmax weight matrix and bias
# out_W = tf.placeholder(tf.float32, [hidden_size, class_num], name='out_Weights')
# out_bias = tf.placeholder(tf.float32, [class_num], name='out_bias')
# Start training and testing
W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)
bias = tf.Variable(tf.constant(0.1, shape=[class_num]), dtype=tf.float32)
y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)

# loss and evaluation ops
cross_entropy = -tf.reduce_mean(y * tf.log(y_pre))
train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

sess.run(tf.global_variables_initializer())
for i in range(2000):
    _batch_size = 128
    batch = mnist.train.next_batch(_batch_size)
    if (i + 1) % 5 == 0:
        train_accuracy = sess.run(accuracy, feed_dict={
            _X: batch[0], y: batch[1], keep_prob: 1.0, batch_size: _batch_size})
        print("Iter%d, step %d, training accuracy %g" % (
            mnist.train.epochs_completed, (i + 1), train_accuracy))
    sess.run(train_op, feed_dict={_X: batch[0], y: batch[1], keep_prob: 0.5, batch_size: _batch_size})

# compute accuracy on the test set
print("test accuracy %g" % sess.run(accuracy, feed_dict={
    _X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0,
    batch_size: mnist.test.images.shape[0]}))

# We trained for less than 5 epochs in total and already reach about 0.98 accuracy on the test set,
# so the LSTM is clearly quite effective at this digit-classification task.
# Moreover, predicting all 10000 test images at once used only 725 MiB of GPU memory, while the earlier
# two-layer CNN needed 8721 MiB for the same 10000 images -- a full 12x difference!
# This is mainly because in an RNN/LSTM the weight matrices are shared across every time step; if you work
# through the LSTM structure introduced earlier, you will see that the whole network has very few parameters.
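# A rough sanity check on that claim (my arithmetic, not from the original post): an LSTM layer
# has 4 gate blocks, each a [n_input + n_hidden, n_hidden] weight slice plus an n_hidden bias.
def lstm_param_count(n_input, n_hidden):
    return 4 * ((n_input + n_hidden) * n_hidden + n_hidden)

# layer 1 (28 -> 256) + layer 2 (256 -> 256) + the softmax head:
# lstm_param_count(28, 256) + lstm_param_count(256, 256) + (256 * 10 + 10) == 819722
# i.e. well under a million parameters, reused identically at all 28 time steps.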
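# One more side note (my suggestion, not in the original): -tf.reduce_mean(y * tf.log(y_pre))
# can produce NaNs once a softmax probability underflows to exactly 0; computing the loss from
# logits is the numerically safer pattern, sketched here against the variables defined above:
# logits = tf.matmul(h_state, W) + bias
# cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# y_pre = tf.nn.softmax(logits)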


# ## IV. Visualizing how the LSTM does the classification

# After all, LSTMs are mostly used for sequence problems -- text, sequence prediction and the like --
# so unlike CNNs it is hard to see directly how the features change in each layer.
# Here I want to use visualization to help you understand how the LSTM classifies an image, step by step.

# In[ ]:

# shapes of the hand-unrolled results
_batch_size = 5
X_batch, y_batch = mnist.test.next_batch(_batch_size)
print(X_batch.shape, y_batch.shape)
_outputs, _state = sess.run([outputs, state], feed_dict={
    _X: X_batch, y: y_batch, keep_prob: 1.0, batch_size: _batch_size})
print('_outputs.shape =', np.asarray(_outputs).shape)
print('arr_state.shape =', np.asarray(_state).shape)
# As you can see (outputs here is the hand-unrolled list, so it is stacked over time first):
# outputs.shape = [timestep_size, batch_size, hidden_size]
# state.shape   = [layer_num, 2, batch_size, hidden_size]

# Below I picked out a digit 3

# In[ ]:

import matplotlib.pyplot as plt

# In[ ]:

print(mnist.train.labels[4])

# Let's first see what this digit looks like -- its upper half looks rather like a 2

# In[ ]:

X3 = mnist.train.images[4]
img3 = X3.reshape([28, 28])
plt.imshow(img3, cmap='gray')
plt.show()
# Now let's see what the per-class probabilities look like as the image is fed in row by row during classification

# In[14]:

X3.shape = [-1, 784]
y_batch = mnist.train.labels[4]     # use the label of the same image (index 4) that we are visualizing
y_batch.shape = [-1, class_num]

X3_outputs = np.array(sess.run(outputs, feed_dict={_X: X3, y: y_batch, keep_prob: 1.0, batch_size: 1}))
print(X3_outputs.shape)
X3_outputs.shape = [28, hidden_size]
print(X3_outputs.shape)
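# Side note (my addition): sess.run(outputs, ...) with batch_size 1 returns a list of 28 arrays,
# each of shape (1, 256); np.array stacks them into (28, 1, 256), and the reshape above flattens
# that into (28, hidden_size) -- one hidden-state vector per image row.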


# In[15]:

h_W    = sess.run(W)     # no feed_dict is needed just to read out the trained variables
h_bias = sess.run(bias)
h_bias.shape = [-1, 10]

bar_index = range(class_num)
for i in range(X3_outputs.shape[0]):
    plt.subplot(7, 4, i + 1)
    X3_h_state = X3_outputs[i, :].reshape([-1, hidden_size])
    pro = sess.run(tf.nn.softmax(tf.matmul(X3_h_state, h_W) + h_bias))
    # pro.shape: (1, 10); an example pro[0] for one time step:
    # [ 4.75662528e-05 1.90045666e-05 8.20193236e-05 9.71286136e-06
    #   8.26372998e-05 2.28238772e-04 9.99474943e-01 2.17880233e-06
    #   5.12166080e-05 2.49308982e-06]
    plt.bar(bar_index, pro[0], width=0.2, align='center')
    plt.axis('off')
plt.show()
# In the figure above, to make the changing bars easier to see, I removed the axes; each row shows 4 subplots,
# 7 rows in all, illustrating how the model's recognition evolves as it reads the image row by row.
# You can see that from only the first few rows of pixels the model cannot tell what the digit is at all;
# as it sees more and more pixels, it finally becomes essentially certain that the digit is a 3.

 

posted @ 2017-10-01 20:51 郝壹贰叁