Mnist字符识别-神经网络实现(TF框架)

Mnist字符识别-神经网络实现(TF框架)

该段代码复制即可运行。先贴出代码，有空时再补充注释解析。这是大三时写的代码，特别适合新手入门；不过现在大家更多使用 PyTorch 了。

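运行环境为 TensorFlow 1.13.1（代码使用 `tf.placeholder`、`tf.Session` 等 TF 1.x API，不能直接在 TensorFlow 2.x 下运行），用 CPU 跑也挺快。之前用 GPU 跑了约半小时，准确率能达到 98% 左右；下方的结果是按代码默认的 20 个 epoch 运行得到的。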

代码

copy
# -*- coding:utf-8 -*-
"""Train a 784-30-10 sigmoid MLP on MNIST with TensorFlow 1.x.

Two experiments are run from the same initialisation op and compared:
training on a fixed 1000-example subset (train_1000) versus training on the
full training set (train_all). Per-epoch test accuracy of both is plotted to
accuracy.jpg.

Requires TensorFlow 1.x: ``tensorflow.examples.tutorials.mnist``,
``tf.placeholder`` and ``tf.Session`` are not available in TensorFlow 2.x.
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot
import matplotlib.pyplot as plt
import numpy as np

# Loads MNIST from (or downloads it into) ./MNIST_data with one-hot labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

seed = 547
np.random.seed(seed)      # seeds NumPy's global RNG (used by getBatch's shuffle)
tf.set_random_seed(seed)  # graph-level seed so weight initialisation is reproducible

epoch_time = 20  # number of training epochs for each experiment
ALPHY = 0.5      # learning rate of plain gradient descent
batch_size = 10
n_batch_all = mnist.train.num_examples // batch_size  # steps per epoch, full training set
n_batch = 1000 // batch_size                          # steps per epoch, 1000-example subset

x = tf.placeholder(tf.float32, [None, 784])  # flattened images, 784 values each
y = tf.placeholder(tf.float32, [None, 10])   # one-hot labels, 10 classes


def xavier_init(size):
    """Return a random-normal initial value for a weight matrix of shape ``size``.

    The standard deviation is ``sqrt(2 / size[0])``, i.e. He/Kaiming scaling
    by fan-in rather than the classic Xavier/Glorot formula; the name is kept
    so existing callers still work.
    """
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)


# Hidden layer: 784 -> 30, sigmoid activation.
W1 = tf.Variable(xavier_init([784, 30]))
B1 = tf.Variable(tf.zeros([30]))
L1 = tf.nn.sigmoid(tf.matmul(x, W1) + B1)

# Output layer: 30 -> 10. The loss consumes the raw logits; `prediction`
# is only used for accuracy (argmax is unchanged by the sigmoid).
W2 = tf.Variable(xavier_init([30, 10]))
B2 = tf.Variable(tf.zeros([10]))
logit_prediction = tf.matmul(L1, W2) + B2
prediction = tf.nn.sigmoid(logit_prediction)

# Alternative MSE loss: tf.reduce_mean(tf.square(y - prediction))
# Cross-entropy loss: per-class sigmoid cross-entropy, summed to a scalar.
# minimize() on the element-wise tensor already differentiated its sum, so
# reduce_sum keeps the same gradients (and the tuned ALPHY) while making the
# loss a proper scalar. reduce_mean would shrink the gradient by a factor of
# batch_size * 10 and require retuning ALPHY.
loss = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=logit_prediction, labels=y))
train_setup = tf.train.GradientDescentOptimizer(ALPHY).minimize(loss)
init = tf.global_variables_initializer()

# Fraction of samples whose highest-scoring class equals the labelled class.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


def getBatch(inputs, size=batch_size):
    """Draw a random mini-batch from ``inputs``.

    ``inputs`` is a NumPy array whose rows are ``[784 image values | 10 label
    values]`` (as built by train_1000). It is shuffled IN PLACE, then the
    first ``size`` rows are split into images and labels. Every call samples
    independently, so the same row may appear in consecutive batches.

    Returns:
        (fina_x, fina_y): image rows (columns 0-783) and label rows
        (columns 784-793) of the sampled batch.
    """
    np.random.shuffle(inputs)
    batch = inputs[:size]
    fina_x = batch[:, :784]
    fina_y = batch[:, 784:794]
    return fina_x, fina_y


def draw(train, text):
    """Plot per-epoch test accuracy of both experiments and save accuracy.jpg.

    Args:
        train: test accuracies returned by train_1000(), one per epoch.
        text: test accuracies returned by train_all(), one per epoch.
    """
    names = [str(i) for i in range(epoch_time)]
    xs = range(len(names))  # named `xs` so it does not shadow the placeholder `x`
    plt.plot(xs, train, marker='o', mec='r', mfc='w', label='train_1000')
    plt.plot(xs, text, marker='*', ms=10, label='train_all')
    plt.legend()
    plt.xticks(xs, names, rotation=1)
    plt.margins(0)
    plt.subplots_adjust(bottom=0.10)
    plt.xlabel('epoch')
    plt.ylabel("accuracy")
    pyplot.yticks([0, 0.5, 1])
    plt.savefig('accuracy.jpg', dpi=900)


def train_1000():
    """Train from a fresh initialisation on a fixed 1000-example training subset.

    Uses the module-level session ``sess``. After every epoch the model is
    evaluated on the full test set and the accuracy is printed.

    Returns:
        np.ndarray of length ``epoch_time`` with the test accuracy per epoch.
    """
    sess.run(init)
    train = np.zeros(epoch_time, dtype=np.float32)
    X_mb, Y_mb = mnist.train.next_batch(1000)
    # Build the [image | label] table directly in NumPy; tf.concat + eval
    # would add a constant op holding the whole subset to the graph per call.
    inputs = np.hstack([X_mb, Y_mb]).astype(np.float32)
    for epoch in range(epoch_time):
        for _ in range(n_batch):
            fina_x, fina_y = getBatch(inputs)
            sess.run(train_setup, feed_dict={x: fina_x, y: fina_y})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        train[epoch] = acc
        print("Iter" + str(epoch) + ", Testing Accuracy=" + str(acc))
    return train


def train_all():
    """Train from a fresh initialisation on the full training set.

    Uses the module-level session ``sess`` and mini-batches of ``batch_size``
    from ``mnist.train.next_batch``. After every epoch the model is evaluated
    on the full test set and the accuracy is printed.

    Returns:
        np.ndarray of length ``epoch_time`` with the test accuracy per epoch.
    """
    sess.run(init)
    text = np.zeros(epoch_time, dtype=np.float32)
    for epoch in range(epoch_time):
        for _ in range(n_batch_all):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_setup, feed_dict={x: batch_xs, y: batch_ys})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        text[epoch] = acc
        print("Iter" + str(epoch) + ", Testing Accuracy=" + str(acc))
    # Must be returned: draw() plots this as the 'train_all' curve.
    return text


if __name__ == "__main__":
    # `sess` is a module-level global here; train_1000/train_all read it.
    with tf.Session() as sess:
        p1 = train_1000()
        p2 = train_all()
    draw(p1, p2)

结果

copy
Iter0, Testing Accuracy=0.0982
Iter1, Testing Accuracy=0.2913
Iter2, Testing Accuracy=0.2973
Iter3, Testing Accuracy=0.3493
Iter4, Testing Accuracy=0.4311
Iter5, Testing Accuracy=0.3789
Iter6, Testing Accuracy=0.49
Iter7, Testing Accuracy=0.4547
Iter8, Testing Accuracy=0.4079
Iter9, Testing Accuracy=0.4748
Iter10, Testing Accuracy=0.564
Iter11, Testing Accuracy=0.5026
Iter12, Testing Accuracy=0.6053
Iter13, Testing Accuracy=0.6379
Iter14, Testing Accuracy=0.5863
Iter15, Testing Accuracy=0.6443
Iter16, Testing Accuracy=0.6487
Iter17, Testing Accuracy=0.5809
Iter18, Testing Accuracy=0.6616
Iter19, Testing Accuracy=0.6465

Iter0, Testing Accuracy=0.7625
Iter1, Testing Accuracy=0.864
Iter2, Testing Accuracy=0.8596
Iter3, Testing Accuracy=0.8694
Iter4, Testing Accuracy=0.9028
Iter5, Testing Accuracy=0.9046
Iter6, Testing Accuracy=0.902
Iter7, Testing Accuracy=0.9021
Iter8, Testing Accuracy=0.8874
Iter9, Testing Accuracy=0.9192
Iter10, Testing Accuracy=0.9175
Iter11, Testing Accuracy=0.9226
Iter12, Testing Accuracy=0.9233
Iter13, Testing Accuracy=0.9156
Iter14, Testing Accuracy=0.93
Iter15, Testing Accuracy=0.9251
Iter16, Testing Accuracy=0.9232
Iter17, Testing Accuracy=0.9176
Iter18, Testing Accuracy=0.9287
Iter19, Testing Accuracy=0.9273
posted @   梁君牧  阅读(96)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
🚀