用TensorFlow实现简单的卷积神经网络
# MNIST training: a two-convolution-layer CNN built with the TensorFlow 1.x graph API.

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

# NOTE(review): plt and np are not referenced in this script; presumably they
# were meant for plotting the collected `losses` / `acc` lists.

# Load MNIST from 'MNIST_data/' with one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)


def weight_variable(shape, stddev=0.1):
    """Create a trainable weight variable initialised from a truncated normal.

    Args:
        shape: Shape of the variable, e.g. [height, width, in_ch, out_ch]
            for a conv kernel or [in_units, out_units] for a dense layer.
        stddev: Standard deviation of the truncated normal initialiser.
            Defaults to 0.1, the value this helper always used before.

    Returns:
        A ``tf.Variable`` holding the initial weights.
    """
    initial = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial)


def bias_variable(shape, value=0.1):
    """Create a trainable bias variable filled with a constant.

    A small positive bias is a common choice in front of ReLU so units do not
    start out permanently inactive.

    Args:
        shape: Shape of the bias, normally [num_units].
        value: Constant initial value. Defaults to 0.1, the value this helper
            always used before.

    Returns:
        A ``tf.Variable`` holding the initial biases.
    """
    initial = tf.constant(value, shape=shape)
    return tf.Variable(initial)


def conv(x, w):
    """2-D convolution with stride 1 and 'SAME' padding (height/width preserved)."""
    return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')


def max_pool(x):
    """2x2 max pooling, stride 2, 'SAME' padding: halves height and width (rounding up)."""
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


# Inputs: flattened 28x28 images (784 values) and one-hot labels over 10 classes.
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
# Reshape to [batch, height, width, channels] with a single channel for conv2d.
x_image = tf.reshape(x, [-1, 28, 28, 1])

# Conv layer 1 -> pool 1: 5x5 kernels, 1 -> 32 channels; 28x28 -> 14x14.
w_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool(h_conv1)

# Conv layer 2 -> pool 2: 5x5 kernels, 32 -> 64 channels; 14x14 -> 7x7.
w_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv(h_pool1, w_conv2) + b_conv2)
h_pool2 = max_pool(h_conv2)

# Fully connected layer: flatten the 7x7x64 feature maps into 1024 ReLU units.
w_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool_flat, w_fc1) + b_fc1)

# Dropout layer. Despite its name, keep_drop is the *keep probability* passed
# to tf.nn.dropout; it is fed at run time (the training loop feeds 0.5 for
# updates and 1.0 for evaluation).
keep_drop = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_drop)

# Output layer: raw logits for the 10 classes (softmax is applied inside the loss).
w_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y = tf.matmul(h_fc1_drop, w_fc2) + b_fc2

# Loss: mean softmax cross-entropy between the logits and the one-hot labels.
# NOTE: TF >= 1.5 deprecates this op in favour of
# softmax_cross_entropy_with_logits_v2; because y_ is a placeholder (no
# gradient can flow into the labels) the two are equivalent here.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)

# Accuracy: fraction of examples whose arg-max prediction matches the label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)
# Training metrics appended by the training loop every 100 steps.
losses = []
acc = []
# Train for 2000 steps on mini-batches of 50 images. Every 100 steps, record
# the accuracy and loss on the current batch with dropout disabled.
for i in range(2000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        # Fetch both metrics in one sess.run: one forward pass instead of
        # running the network twice on the same batch.
        train_accuracy, loss_tmp = sess.run(
            [accuracy, loss],
            feed_dict={x: batch[0], y_: batch[1], keep_drop: 1.0})
        print('step %d,training accuracy %g' % (i, train_accuracy))
        acc.append(train_accuracy)
        losses.append(loss_tmp)
    # Parameter update with dropout active (keep probability 0.5).
    sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_drop: 0.5})

# Evaluate on the test set in fixed-size chunks. Feeding every test image in
# one run materialises the conv activations for all of them at once, which can
# exhaust (GPU) memory. Correct predictions are counted exactly per chunk, so
# the final figure equals whole-set accuracy even if the last chunk is short.
test_images = mnist.test.images
test_labels = mnist.test.labels
num_test = len(test_images)
eval_batch_size = 1000
num_correct = 0
for start in range(0, num_test, eval_batch_size):
    chunk = slice(start, start + eval_batch_size)
    hits = sess.run(correct_prediction,
                    feed_dict={x: test_images[chunk],
                               y_: test_labels[chunk],
                               keep_drop: 1.0})
    num_correct += int(hits.sum())
# float() keeps true division even under Python 2.
print("test accuracy", float(num_correct) / num_test)
参考文章:
1.https://www.cnblogs.com/willnote/p/6874699.html
作者:舟华520
出处:https://www.cnblogs.com/xfzh193/
本文以学习、分享、研究交流为主，欢迎转载，请注明作者及出处！