TensorFlow: Training the MNIST Dataset with a Multilayer Perceptron Model

import warnings
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
warnings.filterwarnings('ignore')

# Load the MNIST dataset
path = './mnist/input_data'
mnist = input_data.read_data_sets(path, one_hot=True)
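
The tensorflow.examples.tutorials.mnist module prints deprecation warnings in newer TensorFlow 1.x releases and is removed in TF 2.x. If it is not available, roughly the same arrays can be obtained from tf.keras.datasets; a minimal sketch (the rest of this post keeps using the read_data_sets() objects, and the names x_train/y_train_onehot etc. are only used here):

import numpy as np

# Load MNIST as NumPy arrays: images are 28x28 uint8, labels are digits 0-9
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Flatten to 784 features and scale to [0, 1], matching what read_data_sets() provides
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
# One-hot encode the labels (the equivalent of one_hot=True above)
y_train_onehot = np.eye(10, dtype='float32')[y_train]
y_test_onehot = np.eye(10, dtype='float32')[y_test]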

# Define the training hyperparameters
learning_rate = 0.01
iterations = 360
batch_number = 100
display_step = 10

# Define the placeholders
# Each image is 28*28, i.e. 784 input features; the label is a digit 0-9, so Y has 10 columns (one-hot)
X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
Y = tf.placeholder(tf.float32, shape=[None, 10], name='Y')

# Number of neurons in each hidden layer
hidden_layer1 = 4096
hidden_layer2 = 2048

# Variable initialization (the initializer t below is only used in the commented-out tf.get_variable variants)
t = tf.truncated_normal_initializer(mean=0.21, stddev=0.1, seed=7)
weight = {
#     'w_layer1': tf.get_variable('w_layer1', [784, hidden_layer1], initializer=t),
    "w_layer1": tf.Variable(tf.random_normal([784, hidden_layer1])),
#     'out_layer': tf.get_variable('out_layer', [hidden_layer1, 10], initializer=t),
    "out_layer": tf.Variable(tf.random_normal([hidden_layer1, 10]))  
}

bias = {
#     'bias1': tf.Variable(tf.truncated_normal([hidden_layer1], mean=0.21, stddev=0.1, seed=7)),
#     'bias_out': tf.Variable(tf.truncated_normal([10], mean=0.21, stddev=0.1, seed=7)),
    'bias1': tf.Variable(tf.random_normal([hidden_layer1])),
    'bias_out': tf.Variable(tf.random_normal([10])),
}
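
Note that tf.random_normal samples from a standard normal (stddev 1.0) by default; with 784 inputs this gives the hidden layer and the logits very large magnitudes (one reason the logits printed further down are in the hundreds), which can slow training. If you want to experiment with a smaller initialization scale, a sketch of the same variables follows; the names weight_small and bias_small are only illustrative, and the rest of the post keeps the definitions above:

weight_small = {
    'w_layer1': tf.Variable(tf.random_normal([784, hidden_layer1], stddev=0.1)),
    'out_layer': tf.Variable(tf.random_normal([hidden_layer1, 10], stddev=0.1)),
}
bias_small = {
    'bias1': tf.Variable(tf.zeros([hidden_layer1])),
    'bias_out': tf.Variable(tf.zeros([10])),
}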

weight['w_layer1'].shape
Out[7]: TensorShape([Dimension(784), Dimension(4096)])

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print(weight['w_layer1'][0].eval())
    print(bias['bias_out'].eval())
Out[8]: 
    [-0.9179551   2.620515    0.84942716 ...  1.2998239  -0.58302623
 -0.2974956 ]
    [ 1.1626408  -0.7264125   0.1215397   0.47657487 -0.20571353  0.42560428
  1.0548028  -2.1683683  -0.8865439  -0.28622103]

# Build the multilayer perceptron model
def multilayer_perceptron(x, w, b):
    hidden_layer_one = tf.matmul(x, w['w_layer1']) + b['bias1']
    hidden_layer_one = tf.nn.relu(hidden_layer_one)
    out_layer = tf.matmul(hidden_layer_one, w['out_layer']) + b['bias_out']
    return out_layer
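
Note that hidden_layer2 is defined above but never used: the model actually built here has a single hidden layer. Adding the second hidden layer would require one more weight/bias pair and one more matmul; a sketch follows (the keys w_layer2 and bias2 and the function name are hypothetical and not defined elsewhere in this post):

# Extra variables for a second hidden layer (not created above)
# weight['w_layer2'] = tf.Variable(tf.random_normal([hidden_layer1, hidden_layer2]))
# weight['out_layer'] = tf.Variable(tf.random_normal([hidden_layer2, 10]))
# bias['bias2'] = tf.Variable(tf.random_normal([hidden_layer2]))

def multilayer_perceptron_two_hidden(x, w, b):
    layer1 = tf.nn.relu(tf.matmul(x, w['w_layer1']) + b['bias1'])
    layer2 = tf.nn.relu(tf.matmul(layer1, w['w_layer2']) + b['bias2'])
    return tf.matmul(layer2, w['out_layer']) + b['bias_out']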

# Call the multilayer perceptron model to get the predicted logits
y_pred = multilayer_perceptron(X, weight, bias)

# Define the loss: tf.reduce_mean() averages the per-example cross-entropy over the batch
cross_entropys = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=y_pred)
cross_entropy = tf.reduce_mean(cross_entropys)
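
softmax_cross_entropy_with_logits_v2 applies softmax to the logits and computes the cross-entropy against the one-hot labels in a single numerically stable op. Spelled out by hand it is roughly equivalent to the following (for illustration only; the fused op is numerically stable and should be preferred, and these names are not used elsewhere):

probs = tf.nn.softmax(y_pred)                           # logits -> class probabilities
manual_ce = -tf.reduce_sum(Y * tf.log(probs), axis=1)   # per-example cross-entropy
manual_cross_entropy = tf.reduce_mean(manual_ce)        # average over the batch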

# Define the optimizer (Adam is used here)
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
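
If you prefer plain gradient descent instead of Adam, only the optimizer class changes (a sketch; the learning rate would likely need retuning):

# optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)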


init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    
    for iteration in range(iterations):
        # Reset the epoch loss and compute the number of batches per epoch
        train_cost = 0
        batch_times = int(mnist.train.num_examples / batch_number)
        for i in range(batch_times):
            # Fetch batch_number images per step
            batch_X, batch_Y = mnist.train.next_batch(batch_number)
            # Run the optimizer and the loss in one call
            # sess.run returns a list matching [optimizer, cross_entropy]; _ is the (None) optimizer result, batch_cost is the loss value
            _, batch_cost = sess.run([optimizer, cross_entropy], feed_dict={X: batch_X, Y: batch_Y})
            # Accumulate the epoch loss: each batch_cost covers batch_number images,
            # so dividing by batch_times gives the average loss over the epoch
            train_cost += batch_cost / batch_times
            
        if iteration % display_step==0:

            # %04d: zero-padded 4-digit decimal integer
            # "{:.9f}".format(train_cost): show train_cost with 9 decimal places
            prediction=tf.equal(tf.argmax(y_pred, 1), tf.argmax(Y,1))
            accuracy=tf.reduce_mean(tf.cast(prediction,"float"))
            print("Epoch:","%04d"%(iteration+1), "Train_cost :","{:.9f}  ".format(train_cost),
                  "accuracy :", sess.run(accuracy,feed_dict={X:mnist.test.images, Y:mnist.test.labels}))
#             print("accuracy :",sess.run(accuracy,feed_dict={X:mnist.test.images, Y:mnist.test.labels}))
    
    # tf.argmax(y_pred, 1): index of the largest value in each row (axis=1), i.e. the predicted class
    # tf.equal(): element-wise comparison; True where the predicted and true classes match, False otherwise
    # The result is a boolean vector of shape [None]
    prediction=tf.equal(tf.argmax(y_pred, 1), tf.argmax(Y,1))
    # tf.cast(): convert the boolean vector to floats (True -> 1.0, False -> 0.0)
    # tf.reduce_mean(): average over all elements
    # The result is the accuracy: the fraction of correctly classified examples
    accuracy=tf.reduce_mean(tf.cast(prediction,"float"))
    
    # Print train and test accuracy
    print("Train :",sess.run(accuracy,feed_dict={X:mnist.train.images,Y:mnist.train.labels}))
    print("Test  :",sess.run(accuracy,feed_dict={X:mnist.test.images,Y:mnist.test.labels}))
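
One detail worth noting: the prediction and accuracy ops above are rebuilt on every logging step and again after training, which keeps adding new nodes to the graph. A cleaner pattern is to define them once, before the session, and just run them inside the loop; a sketch using the same placeholders and model as above (prediction_op and accuracy_op are illustrative names):

# Define the evaluation ops once, outside the Session and the training loop
prediction_op = tf.equal(tf.argmax(y_pred, 1), tf.argmax(Y, 1))
accuracy_op = tf.reduce_mean(tf.cast(prediction_op, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... training loop as above ...
    print("Test :", sess.run(accuracy_op, feed_dict={X: mnist.test.images, Y: mnist.test.labels}))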

To see some details of the training process more clearly, run the following code and inspect its output. Note that tf.global_variables_initializer() re-initializes all the weights here, so the predictions come from an untrained network and mostly disagree with the true labels.

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_X, batch_Y = mnist.train.next_batch(100)
    test_y = sess.run(y_pred, feed_dict={X:batch_X})
    print(type(test_y))
    print(test_y[0])
    print(tf.argmax(test_y, 1))
    print(tf.argmax(test_y, 1).eval())
    print(tf.argmax(batch_Y, 1).eval())
    print("原始的标签数据", batch_Y)
# Output:
<class 'numpy.ndarray'>
[ 766.4617  -133.16173  604.783    758.3629   595.90173 -153.06392
  321.0015  -835.967   -668.4424  -585.8179 ]
Tensor("ArgMax_74:0", shape=(100,), dtype=int64)
[0 4 4 6 4 4 0 2 4 0 4 4 4 0 0 3 4 0 4 4 0 3 4 0 4 0 0 0 4 2 4 4 4 4 4 4 4
 0 2 4 0 0 4 4 0 4 4 0 4 3 4 3 0 2 4 6 4 0 0 0 5 0 4 0 2 0 0 0 0 0 4 0 2 3
 3 4 0 6 4 4 0 4 4 2 3 4 0 4 0 0 4 4 0 0 4 0 4 2 0 4]
[3 8 8 6 4 8 4 5 7 8 0 3 4 3 8 5 8 3 0 9 6 7 5 7 9 8 3 4 0 3 1 8 5 5 3 4 0
 1 8 8 7 9 3 7 8 8 3 2 7 7 2 6 4 4 3 7 8 8 8 0 8 5 0 6 8 6 8 5 7 4 1 4 4 4
 1 9 9 7 5 8 2 3 7 6 6 7 7 4 4 6 7 1 0 1 5 9 8 4 2 0]
Original one-hot labels: [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
