Mnist字符识别-神经网络实现(TF框架)
Mnist字符识别-神经网络实现(TF框架)
该段代码即贴即用,先贴一下代码,有空的时候再写个注释解析。这是大三时写的代码,特别适合新手入门,不过现在大家都用 PyTorch 了。
电脑上用的 TensorFlow 版本是 1.13.1,用 CPU 跑也挺快的。之前用 GPU 跑了半小时,准确率能达到 98% 左右。
代码
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot
import matplotlib.pyplot as plt
import numpy as np
# Load the MNIST dataset from the "MNIST_data" directory with one-hot labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# Seed NumPy's global RNG (np.random.shuffle in getBatch draws from it).
# NOTE(review): no TensorFlow seed is set in the visible code, so the
# tf.random_normal weight initialisation is presumably not reproducible.
seed=547
np.random.seed(seed)
# Number of training epochs run by train_1000 / train_all.
epoch_time = 20;
# Learning rate handed to the gradient-descent optimizer.
ALPHY = 0.5
# Number of examples per gradient step.
batch_size = 10
# Steps per epoch when iterating over the whole training set.
n_batch_all = mnist.train.num_examples // batch_size
# Steps per epoch when training on the 1000-example subset.
n_batch = 1000 // batch_size
# Network input: a batch of 784-value rows.
x = tf.placeholder(tf.float32,[None,784])
# Target labels: a batch of 10-value rows (one-hot, per the loader above).
y = tf.placeholder(tf.float32,[None,10])
def xavier_init(size):
    """Return a random-normal initial value tensor of shape ``size``.

    The standard deviation is ``1 / sqrt(size[0] / 2)``, i.e. it shrinks as
    the first dimension of the requested shape grows.
    """
    fan_in = size[0]
    stddev = 1. / tf.sqrt(fan_in / 2.)
    return tf.random_normal(shape=size, stddev=stddev)
# Hidden layer: 784 inputs -> 30 sigmoid units.
W1 =tf.Variable(xavier_init([784, 30]))
B1 = tf.Variable(tf.zeros([30]))
L1 = tf.nn.sigmoid(tf.matmul(x,W1) + B1)
# Output layer: 30 hidden units -> 10 outputs.
W2 =tf.Variable(xavier_init([30, 10]))
B2 = tf.Variable(tf.zeros([10]))
# Raw (pre-activation) outputs; the loss below consumes these directly.
logit_prediction = tf.matmul(L1,W2) + B2
# Sigmoid-squashed outputs, used only for the argmax in the accuracy metric.
prediction = tf.nn.sigmoid(logit_prediction)
# MSE loss function (alternative, currently disabled)
# loss = tf.reduce_mean(tf.square(y - prediction))
# Cross-entropy loss function
# NOTE(review): no reduce_mean/reduce_sum is applied here, so `loss` is
# presumably a per-element tensor rather than a scalar -- confirm that
# minimizing it unreduced is intended.
loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit_prediction,labels=y)
# Single gradient-descent step with learning rate ALPHY.
train_setup = tf.train.GradientDescentOptimizer(ALPHY).minimize(loss)
init = tf.global_variables_initializer()
# A sample counts as correct when the index of the largest output matches
# the index of the largest label entry.
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# Fraction of correctly classified samples in the fed batch.
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
def getBatch(inputs, size=10):
    """Draw a random mini-batch from a packed feature/label array.

    Each row of ``inputs`` holds 784 feature values followed by 10 label
    values. The rows are shuffled in place (``inputs`` itself is modified),
    then the first ``size`` rows are split back into features and labels.

    Args:
        inputs: 2-D NumPy array with at least 794 columns.
        size: number of rows to return; defaults to 10, the value that was
            previously hard-coded.

    Returns:
        Tuple ``(fina_x, fina_y)``: columns 0..783 and columns 784..793 of
        the selected rows.
    """
    # Shuffling whole rows keeps each feature vector paired with its label.
    np.random.shuffle(inputs)
    batch = inputs[:size]
    fina_x = batch[:, :784]
    fina_y = batch[:, 784:794]
    return fina_x, fina_y
def draw(train, text):
    """Plot two per-epoch accuracy curves and save them to ``accuracy.jpg``.

    Args:
        train: sequence of accuracies, plotted with the legend label
            ``'train_1000'``.
        text: sequence of accuracies of the same length, plotted with the
            legend label ``'train_all'``.
    """
    # Take the x axis from the data itself rather than the global epoch
    # count, so the plot always matches what was actually recorded.
    epochs = range(len(train))
    tick_labels = [str(i) for i in epochs]
    plt.plot(epochs, train, marker='o', mec='r', mfc='w', label='train_1000')
    plt.plot(epochs, text, marker='*', ms=10, label='train_all')
    plt.legend()
    plt.xticks(epochs, tick_labels, rotation=1)
    plt.margins(0)
    plt.subplots_adjust(bottom=0.10)
    plt.xlabel('epoch')
    plt.ylabel("accuracy")
    plt.yticks([0, 0.5, 1])
    plt.savefig('accuracy.jpg', dpi=900)
def train_1000():
    """Train on a fixed 1000-example subset, recording test accuracy per epoch.

    Re-initialises all variables, draws 1000 training examples once, then for
    each of ``epoch_time`` epochs runs ``n_batch`` gradient steps on
    mini-batches sampled by ``getBatch`` and evaluates ``accuracy`` on the
    full test set. Uses the module-level ``sess``, which must be an open
    session when this is called.

    Returns:
        NumPy float32 array of length ``epoch_time`` holding the test
        accuracy measured after each epoch.
    """
    sess.run(init)
    X_mb, Y_mb = mnist.train.next_batch(1000)
    # Pack features and labels side by side so getBatch can shuffle rows
    # while keeping every example paired with its label. Plain NumPy is used
    # so that no extra ops are added to the TensorFlow graph.
    inputs = np.concatenate([X_mb, Y_mb], axis=1).astype(np.float32)
    train = np.zeros(epoch_time, dtype=np.float32)
    for epoch in range(epoch_time):
        for _ in range(n_batch):
            fina_x, fina_y = getBatch(inputs)
            sess.run(train_setup, feed_dict={x: fina_x, y: fina_y})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        train[epoch] = acc
        print("Iter" + str(epoch) + ", Testing Accuracy=" + str(acc))
    return train
def train_all():
    """Train on the full training set, recording test accuracy per epoch.

    Re-initialises all variables, then for each of ``epoch_time`` epochs runs
    ``n_batch_all`` gradient steps on batches of ``batch_size`` examples and
    evaluates ``accuracy`` on the full test set. Uses the module-level
    ``sess``, which must be an open session when this is called.

    Returns:
        NumPy float32 array of length ``epoch_time`` holding the test
        accuracy measured after each epoch.
    """
    sess.run(init)
    text = np.zeros(epoch_time, dtype=np.float32)
    for epoch in range(epoch_time):
        for _ in range(n_batch_all):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_setup, feed_dict={x: batch_xs, y: batch_ys})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        text[epoch] = acc
        print("Iter" + str(epoch) + ", Testing Accuracy=" + str(acc))
    # The accuracies were previously collected but never returned, leaving
    # the caller with None to plot.
    return text
# Run both experiments inside one session; the training functions reach it
# through the module-level name `sess`.
with tf.Session() as sess:
    p1 = train_1000()
    p2 = train_all()
# Plotting does not use the session, so it happens after the block.
draw(p1, p2)
结果
Iter0, Testing Accuracy=0.0982
Iter1, Testing Accuracy=0.2913
Iter2, Testing Accuracy=0.2973
Iter3, Testing Accuracy=0.3493
Iter4, Testing Accuracy=0.4311
Iter5, Testing Accuracy=0.3789
Iter6, Testing Accuracy=0.49
Iter7, Testing Accuracy=0.4547
Iter8, Testing Accuracy=0.4079
Iter9, Testing Accuracy=0.4748
Iter10, Testing Accuracy=0.564
Iter11, Testing Accuracy=0.5026
Iter12, Testing Accuracy=0.6053
Iter13, Testing Accuracy=0.6379
Iter14, Testing Accuracy=0.5863
Iter15, Testing Accuracy=0.6443
Iter16, Testing Accuracy=0.6487
Iter17, Testing Accuracy=0.5809
Iter18, Testing Accuracy=0.6616
Iter19, Testing Accuracy=0.6465
Iter0, Testing Accuracy=0.7625
Iter1, Testing Accuracy=0.864
Iter2, Testing Accuracy=0.8596
Iter3, Testing Accuracy=0.8694
Iter4, Testing Accuracy=0.9028
Iter5, Testing Accuracy=0.9046
Iter6, Testing Accuracy=0.902
Iter7, Testing Accuracy=0.9021
Iter8, Testing Accuracy=0.8874
Iter9, Testing Accuracy=0.9192
Iter10, Testing Accuracy=0.9175
Iter11, Testing Accuracy=0.9226
Iter12, Testing Accuracy=0.9233
Iter13, Testing Accuracy=0.9156
Iter14, Testing Accuracy=0.93
Iter15, Testing Accuracy=0.9251
Iter16, Testing Accuracy=0.9232
Iter17, Testing Accuracy=0.9176
Iter18, Testing Accuracy=0.9287
Iter19, Testing Accuracy=0.9273
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步