THE MNIST DATABASE of handwritten digits: study notes
Official site
http://yann.lecun.com/exdb/mnist/
Download the four archives
Import the packages
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Download the data and store it under MNIST_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# input_data calls a maybe_download function to make sure the data has been fetched.
# It checks whether the files are already present and skips the download if they are.
# Training-set images
training = mnist.train.images
# Training-set labels
trainlabel = mnist.train.labels
# Test-set images
testing = mnist.test.images
# Test-set labels
testlabel = mnist.test.labels
print('Congratulations, MNIST is ready')
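It is worth a quick sanity check that all three splits loaded correctly. A minimal sketch, assuming the loader's default split sizes of 55000 / 5000 / 10000:
# Each split exposes num_examples; with default settings this prints 55000 5000 10000
print(mnist.train.num_examples, mnist.validation.num_examples, mnist.test.num_examples)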
Alternatively:
# IMPORTS
import os, urllib.request

# PROVIDE YOUR DOWNLOAD DIRECTORY HERE
datapath = 'MNISTData/'

# CREATING DOWNLOAD DIRECTORY
if not os.path.exists(datapath):
    os.makedirs(datapath)

# URLS TO DOWNLOAD FROM
urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz']

for url in urls:
    filename = url.split('/')[-1]  # GET FILENAME
    if os.path.exists(datapath + filename):
        print(filename, 'already exists')  # CHECK IF FILE EXISTS
    else:
        print('Downloading', filename)
        urllib.request.urlretrieve(url, datapath + filename)  # DOWNLOAD FILE
print('All files are available')
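A truncated download would only surface as an error later, during extraction, so it can help to verify the archives right away. A minimal sketch reusing the urls and datapath variables above; gzip stores a CRC at the end of each stream, so fully reading a file catches corruption:
import gzip
for url in urls:
    filename = datapath + url.split('/')[-1]
    with gzip.open(filename, 'rb') as f:
        f.read()  # raises an error if the archive is truncated or corrupt
print('All archives decompress cleanly')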
Extract the four archives and delete them, keeping only the extracted files
import os, gzip, shutil

# PROVIDE YOUR DOWNLOAD DIRECTORY HERE
datapath = 'MNISTData/'

# LISTING ALL ARCHIVES IN THE DIRECTORY
files = os.listdir(datapath)
for file in files:
    if file.endswith('gz'):
        print('Extracting', file)
        with gzip.open(datapath + file, 'rb') as f_in:
            # Drop the .gz suffix for the output filename
            with open(datapath + file.split('.')[0], 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
print('Extraction Complete')

# OPTIONAL: REMOVE THE ARCHIVES
for file in files:
    if file.endswith('gz'):  # only remove the archives, not the extracted files
        print('Removing', file)
        os.remove(datapath + file)
print('All archives removed')
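The extracted files use the simple IDX format described on the official site: a big-endian header (two zero bytes, a type byte, a dimension-count byte, then one 32-bit size per dimension) followed by the raw unsigned bytes. If you want to inspect them without TensorFlow, here is a minimal reader sketch assuming that layout:
import struct
import numpy as np

def read_idx(path):
    # Parse the IDX header, then map the remaining bytes onto the stated shape
    with open(path, 'rb') as f:
        _, dtype, ndim = struct.unpack('>HBB', f.read(4))
        shape = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)

images = read_idx('MNISTData/train-images-idx3-ubyte')  # shape (60000, 28, 28)
labels = read_idx('MNISTData/train-labels-idx1-ubyte')  # shape (60000,)
print(images.shape, labels.shape)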
Reading the MNIST image data with TensorFlow
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
# Font settings for rendering non-ASCII text in plots
plt.rcParams['font.sans-serif'] = ['SimHei']  # display CJK labels correctly
plt.rcParams['axes.unicode_minus'] = False  # display the minus sign correctly
# MNISTData is a folder (in the current directory) holding the MNIST data; it is created and the data downloaded if missing
mnist = input_data.read_data_sets("MNISTData",one_hot=True)
# Training, test, and validation sets
# print("*"*100)
# print(mnist.train.images)
# print("*"*100)
# print(mnist.train.labels)
# print("*"*100)
# print(mnist.test.images)
# print("*"*100)
# print(mnist.test.labels)
# print("*"*100)
# print(mnist.validation.images)
# print("*"*100)
# print(mnist.validation.labels)
print("*"*100)
print(mnist.train.images.shape)  # 55000 rows by 784 columns: 55000 training images, each flattened to 784 pixels
print(mnist.train.labels.shape)
# Fetch the second image (index 1)
image = mnist.train.images[1, :]
# Restore the flat 784-vector to its 28x28 resolution
image = image.reshape(28, 28)
# Print the corresponding label: a length-10 one-hot vector, since one_hot=True
print(mnist.train.labels[1])
plt.figure()
plt.xlabel("abscissa")
plt.ylabel("ordinate")
plt.imshow(image)
plt.show()
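Since one_hot=True, each label is a length-10 vector with a single 1, and np.argmax recovers the digit. A small sketch (using the mnist object loaded above) that spot-checks the first nine training samples together with their decoded labels:
import numpy as np
fig, axes = plt.subplots(3, 3)
for i, ax in enumerate(axes.flat):
    ax.imshow(mnist.train.images[i].reshape(28, 28), cmap='gray')
    ax.set_title('label: %d' % np.argmax(mnist.train.labels[i]))
    ax.axis('off')  # hide the ticks; only the pixels matter here
plt.show()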
Training the neural network
# coding:utf-8
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets("MNISTData", one_hot=True)
# Placeholders for the flattened images and their one-hot labels
input = tf.placeholder(tf.float32, [None, 784])
# Reshape to NHWC layout: [batch, height, width, channels]
input_image = tf.reshape(input, [-1, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, 10])
# input is the image batch, filter is the convolution kernel
def conv2d(input, filter):
    return tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')

# Pooling layer: 2x2 max-pooling with stride 2 halves each spatial dimension
def max_pool(input):
    return tf.nn.max_pool(input, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# Initialize a convolution kernel or weight matrix
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

# Initialize the biases
def bias_variable(shape):
    return tf.Variable(tf.zeros(shape))
# [filter_height, filter_width, in_channels, out_channels]
# Define the convolution kernel
filter = [3, 3, 1, 32]
filter_conv1 = weight_variable(filter)
b_conv1 = bias_variable([32])
# Convolution layer: convolve, pass through ReLU, then max-pool
h_conv1 = tf.nn.relu(conv2d(input_image, filter_conv1) + b_conv1)
h_pool1 = max_pool(h_conv1)
# After one 2x2 max-pool the 28x28 maps are 14x14, hence 14*14*32 flattened features
h_flat = tf.reshape(h_pool1, [-1, 14 * 14 * 32])
W_fc1 = weight_variable([14 * 14 * 32, 768])
b_fc1 = bias_variable([768])
# A ReLU here keeps the two fully connected layers from collapsing into a single linear map
h_fc1 = tf.nn.relu(tf.matmul(h_flat, W_fc1) + b_fc1)
W_fc2 = weight_variable([768, 10])
b_fc2 = bias_variable([10])
y_hat = tf.matmul(h_fc1, W_fc2) + b_fc2
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_hat))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(10000):
        batch_x, batch_y = mnist.train.next_batch(50)
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={input: batch_x, y: batch_y})
            print("step %d, train accuracy %g" % (i, train_accuracy))
        train_step.run(feed_dict={input: batch_x, y: batch_y})
        # sess.run(train_step, feed_dict={input: batch_x, y: batch_y})
    print("test accuracy %g" % accuracy.eval(feed_dict={input: mnist.test.images, y: mnist.test.labels}))