Tensorflow入门实战-mnist手写体识别
1 ''' 2 tensorflow 教程 3 mnist样例 4 ''' 5 import tensorflow as tf 6 from tensorflow.examples.tutorials.mnist import input_data 7 8 #参数设置 9 INPUT_NODE=784 10 OUTPUT_NODE=10 11 LAYER1_NODE=500 12 BATCH_SIZE=100 13 LEARNING_RATE_BASE=0.8 14 LEARNING_RATE_DECAY=0.99 15 REGULARIZATION_RATE=0.0001 16 TRAINING_STEPS=10000 17 MOVEING_AVEARGE_DECAY=0.99 18 19 20 def inference(input_tensor,avg_class,weights1,biases1,weights2,biases2): 21 ''' 22 定义前向计算的过程: 23 avg_class是滑动平均函数,使权重平滑过渡,保留历史数据, 24 为None时,表示普通的参数更新过程 25 ''' 26 if avg_class==None: 27 layer1=tf.nn.relu(tf.matmul(input_tensor,weights1)+biases1) 28 return tf.matmul(layer1,weights2)+biases2 29 else: 30 layer1=tf.nn.relu(tf.matmul(input_tensor,avg_class.average(weights1)+avg_class.average(biases1))) 31 return tf.matmul(layer1,avg_class.average(weights2)+avg_class.average(biases2)) 32 33 def train(mnist): 34 #设置输入变量 placerholder表示占位,开启会话训练的时候需要传入数据 35 x=tf.placeholder(tf.float32,[None,INPUT_NODE],name='x-input') 36 y_=tf.placeholder(tf.float32,[None,OUTPUT_NODE],name='y-input') 37 38 #设置权重变量,variable表示训练时需要自动更新 39 weights1=tf.Variable(tf.random_normal([INPUT_NODE,LAYER1_NODE],stddev=0.1)) 40 biases1=tf.Variable(tf.constant(0.1,shape=[LAYER1_NODE])) 41 weights2=tf.Variable(tf.random_normal([LAYER1_NODE,OUTPUT_NODE],stddev=0.1)) 42 biases2=tf.Variable(tf.constant(0.1,shape=[OUTPUT_NODE])) 43 44 #y=inference(x,None,weights1,biases1,weights2,biases2) 45 46 global_step=tf.Variable(0,trainable=False)#不可更新参数 47 variable_averages=tf.train.ExponentialMovingAverage(MOVEING_AVEARGE_DECAY,global_step)#min(decay,(1+step)/(10+step)) 后面的变量会越来越大,表示参数的更新越来越稳定,大都依赖于历史数据 48 variable_averages_op=variable_averages.apply(tf.trainable_variables()) 49 average_y=inference(x,variable_averages,weights1,biases1,weights2,biases2) 50 51 cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=average_y,labels=tf.argmax(y_,1))#计算图的输出是每个分类的得分,但是要求输入的标签是正确答案的下标 52 cross_entropy_mean=tf.reduce_mean(cross_entropy) 53 54 regularizer=tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 55 regularization=regularizer(weights1)+regularizer(weights2) 56 loss=cross_entropy+regularization 57 58 learning_rate=tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples/BATCH_SIZE,LEARNING_RATE_DECAY)#学习率成阶梯状衰减 每个epoch衰减一次,也就是一整轮数据训练完衰减一次 59 train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step) 60 61 train_op=tf.group(train_step,variable_averages_op)#把反向传播是需要更新的参数打包,不使用滑动平均不需要这句话,因为只更新权重。滑动平均还要利用历史数据更新并更新历史数据 62 63 correct_prediction=tf.equal(tf.argmax(average_y,1),tf.argmax(y_,1)) 64 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 65 66 67 with tf.Session() as sess: 68 tf.global_variables_initializer().run() 69 validate_feed={x:mnist.validation.images,y_:mnist.validation.labels} 70 test_feed={x:mnist.test.images,y_:mnist.test.labels} 71 72 for i in range(TRAINING_STEPS): 73 if i%1000==0: 74 validate_acc=sess.run(accuracy,feed_dict=validate_feed) 75 print('After %d training steps,validation accuracy using average model is %g' %(i,validate_acc)) 76 77 xs,ys=mnist.train.next_batch(BATCH_SIZE) 78 sess.run(train_op,feed_dict={x:xs,y_:ys}) 79 80 test_acc=sess.run(accuracy,feed_dict=test_feed) 81 print('After %d training steps,test accuracy using average model is %g' %(TRAINING_STEPS,test_acc)) 82 83 def main(argv=None): 84 mnist=input_data.read_data_sets("/tmp/data",one_hot=True) 85 train(mnist) 86 87 if __name__ == '__main__': 88 tf.app.run()