P68 03——简单的神经网络实线手写数字识别
http://bilibili.com/video/BV184411Q7Ng?p=68
return:返回的是一个小批次样本的损失值列表。
one-hot编码:一个热编码。例如阿拉伯数字4编码成列向量:[0 0 0 0 1 0 0 0 0 0]
准确率怎么来?
答:在第一批次(mini-batch)训练之后,更新一下网络的权重。在第二批次及以后批次训练的时候,用预测的标签和真实的标签相比较,即可得到本批次的准确率。
其实第一批次也有准确率,第一批次的初始权重是随机初始化的。
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #FLAGS=tf.app.flags.FlAGS #tf.app.flags.DEFINE_integer("is_train",1,"指定程序是预测还是训练") #is_train=1 #is_train的值设为1,代表是训练 is_train=0 #is_train的值设为0,代表是训练之后的预测 def full_connected(): mnist=input_data.read_data_sets("./data/mnist/input_data/",one_hot=True) #建立数据的占位符, x [None,784] y_true [None,10] with tf.variable_scope("data"): x=tf.placeholder(tf.float32,[None,784]) y_true=tf.placeholder(tf.int32,[None,10]) #2、建立一个全连接层的神经网络w [784,10] b [10] with tf.variable_scope("fc_model"): # 随机初始化权重和偏置 weight=tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0,name='w')) bias=tf.Variable(tf.constant(0.0,shape=[10])) #预测None个样本的输出结果[None, 784]*[784,10]+[10]=[None,10] y_predict=tf.matmul(x,weight)+bias #3、求出所有样本的损失 with tf.variable_scope("soft_cross"): #求平均的交叉熵损失 loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict)) #4、梯度下降求损失 with tf.variable_scope("optimizer"): train_op=tf.train.GradientDescentOptimizer(0.1).minimize(loss) #5、计算准确率 with tf.variable_scope("acc"): equal_list=tf.equal(tf.argmax(y_true,1),tf.argmax(y_predict,1)) #equal_list is None个样本 [1,0,1,1,0,0,0,1,1,0...] accuracy=tf.reduce_mean(tf.cast(equal_list,tf.float32)) # 收集变量,单个数字值的收集 tf.summary.scalar("losses",loss)#收集损失 tf.summary.scalar("acc",accuracy)#收集准确率 # 高维度变量收集 tf.summary.histogram("weightes",weight) tf.summary.histogram("biases",bias) # 定义一个初始化的变量op init_op=tf.global_variables_initializer() # 定义一个合并变量的op merged=tf.summary.merge_all() # 创建一个saver saver=tf.train.Saver() #开启会话去训练 with tf.Session() as sess: #初始化变量 sess.run(init_op) #建立events文件,然后写入 filewriter=tf.summary.FileWriter("./tmp/summary/test/",graph=sess.graph) if is_train==1: #迭代不熟去训练,更新参数进行预测 for i in range(2000): #取出样本的特征值和目标值 mnist_x,mnist_y=mnist.train.next_batch(50) #运行train_op训练 sess.run(train_op,feed_dict={x:mnist_x,y_true:mnist_y}) # 写入每步训练的值 summary=sess.run(merged,feed_dict={x:mnist_x,y_true:mnist_y}) filewriter.add_summary(summary,i) print("训练第%d步,准确率为:%f"%(i,sess.run(accuracy,feed_dict={x:mnist_x,y_true:mnist_y}))) #保存模型 saver.save(sess,"./tmp/ckpt/fc_model") else: # 加载模型,主要是加载训练之后保存在模型中的weight和bias saver.restore(sess,"./tmp/ckpt/fc_model") # 如果是0,做预测 for i in range(100): # 每次测试一张图片 x_test,y_test=mnist.test.next_batch(1) print("第%d张图片,手写数字图片目标是:%d,预测结果是:%d"%( i, tf.argmax(y_test,1).eval(), tf.argmax(sess.run(y_predict,feed_dict={x:x_test,y_true:y_test}),1).eval() )) return None if __name__ == '__main__': full_connected()
第0张图片,手写数字图片目标是:2,预测结果是:2 第1张图片,手写数字图片目标是:9,预测结果是:9 第2张图片,手写数字图片目标是:9,预测结果是:9 第3张图片,手写数字图片目标是:2,预测结果是:2 第4张图片,手写数字图片目标是:0,预测结果是:0 第5张图片,手写数字图片目标是:2,预测结果是:2 第6张图片,手写数字图片目标是:7,预测结果是:7 第7张图片,手写数字图片目标是:7,预测结果是:7 第8张图片,手写数字图片目标是:0,预测结果是:0 第9张图片,手写数字图片目标是:0,预测结果是:0 第10张图片,手写数字图片目标是:6,预测结果是:7 第11张图片,手写数字图片目标是:5,预测结果是:5 第12张图片,手写数字图片目标是:7,预测结果是:7 第13张图片,手写数字图片目标是:0,预测结果是:0 第14张图片,手写数字图片目标是:6,预测结果是:6 第15张图片,手写数字图片目标是:3,预测结果是:3 第16张图片,手写数字图片目标是:5,预测结果是:8 第17张图片,手写数字图片目标是:1,预测结果是:1 第18张图片,手写数字图片目标是:7,预测结果是:7 第19张图片,手写数字图片目标是:5,预测结果是:5 第20张图片,手写数字图片目标是:6,预测结果是:6 第21张图片,手写数字图片目标是:6,预测结果是:6 第22张图片,手写数字图片目标是:5,预测结果是:6 第23张图片,手写数字图片目标是:7,预测结果是:7 第24张图片,手写数字图片目标是:9,预测结果是:8 第25张图片,手写数字图片目标是:5,预测结果是:5 第26张图片,手写数字图片目标是:6,预测结果是:6 第27张图片,手写数字图片目标是:1,预测结果是:6 第28张图片,手写数字图片目标是:1,预测结果是:1 第29张图片,手写数字图片目标是:7,预测结果是:7 第30张图片,手写数字图片目标是:0,预测结果是:0 第31张图片,手写数字图片目标是:1,预测结果是:1 第32张图片,手写数字图片目标是:4,预测结果是:9 第33张图片,手写数字图片目标是:3,预测结果是:5 第34张图片,手写数字图片目标是:6,预测结果是:6 第35张图片,手写数字图片目标是:5,预测结果是:3 第36张图片,手写数字图片目标是:5,预测结果是:5 第37张图片,手写数字图片目标是:1,预测结果是:1 第38张图片,手写数字图片目标是:7,预测结果是:7 第39张图片,手写数字图片目标是:8,预测结果是:2 第40张图片,手写数字图片目标是:7,预测结果是:7 第41张图片,手写数字图片目标是:7,预测结果是:7 第42张图片,手写数字图片目标是:7,预测结果是:4 第43张图片,手写数字图片目标是:3,预测结果是:3 第44张图片,手写数字图片目标是:4,预测结果是:4 第45张图片,手写数字图片目标是:1,预测结果是:1 第46张图片,手写数字图片目标是:3,预测结果是:3 第47张图片,手写数字图片目标是:5,预测结果是:5 第48张图片,手写数字图片目标是:0,预测结果是:0 第49张图片,手写数字图片目标是:1,预测结果是:1 第50张图片,手写数字图片目标是:2,预测结果是:2 第51张图片,手写数字图片目标是:0,预测结果是:0 第52张图片,手写数字图片目标是:9,预测结果是:9 第53张图片,手写数字图片目标是:7,预测结果是:7 第54张图片,手写数字图片目标是:5,预测结果是:5 第55张图片,手写数字图片目标是:4,预测结果是:4 第56张图片,手写数字图片目标是:3,预测结果是:3 第57张图片,手写数字图片目标是:1,预测结果是:1 第58张图片,手写数字图片目标是:9,预测结果是:9 第59张图片,手写数字图片目标是:5,预测结果是:5 第60张图片,手写数字图片目标是:9,预测结果是:9 第61张图片,手写数字图片目标是:8,预测结果是:8 第62张图片,手写数字图片目标是:0,预测结果是:0 第63张图片,手写数字图片目标是:3,预测结果是:3 第64张图片,手写数字图片目标是:6,预测结果是:2 第65张图片,手写数字图片目标是:1,预测结果是:1 第66张图片,手写数字图片目标是:4,预测结果是:4 第67张图片,手写数字图片目标是:8,预测结果是:8 第68张图片,手写数字图片目标是:6,预测结果是:6 第69张图片,手写数字图片目标是:4,预测结果是:4 第70张图片,手写数字图片目标是:1,预测结果是:1 第71张图片,手写数字图片目标是:1,预测结果是:1 第72张图片,手写数字图片目标是:3,预测结果是:3 第73张图片,手写数字图片目标是:1,预测结果是:1 第74张图片,手写数字图片目标是:4,预测结果是:4 第75张图片,手写数字图片目标是:4,预测结果是:4 第76张图片,手写数字图片目标是:7,预测结果是:7 第77张图片,手写数字图片目标是:3,预测结果是:8 第78张图片,手写数字图片目标是:4,预测结果是:4 第79张图片,手写数字图片目标是:8,预测结果是:8 第80张图片,手写数字图片目标是:7,预测结果是:7 第81张图片,手写数字图片目标是:0,预测结果是:0 第82张图片,手写数字图片目标是:2,预测结果是:2 第83张图片,手写数字图片目标是:0,预测结果是:0 第84张图片,手写数字图片目标是:0,预测结果是:5 第85张图片,手写数字图片目标是:1,预测结果是:1 第86张图片,手写数字图片目标是:0,预测结果是:0 第87张图片,手写数字图片目标是:9,预测结果是:9 第88张图片,手写数字图片目标是:9,预测结果是:9 第89张图片,手写数字图片目标是:6,预测结果是:6 第90张图片,手写数字图片目标是:0,预测结果是:0 第91张图片,手写数字图片目标是:4,预测结果是:4 第92张图片,手写数字图片目标是:7,预测结果是:7 第93张图片,手写数字图片目标是:1,预测结果是:1 第94张图片,手写数字图片目标是:8,预测结果是:8 第95张图片,手写数字图片目标是:1,预测结果是:1 第96张图片,手写数字图片目标是:2,预测结果是:2 第97张图片,手写数字图片目标是:4,预测结果是:4 第98张图片,手写数字图片目标是:1,预测结果是:1 第99张图片,手写数字图片目标是:1,预测结果是:1
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
2019-12-18 墨卡托投影、横轴墨卡托投影和通用横轴墨卡托投影
2019-12-18 大地基准面
2019-12-18 我国的大地原点
2019-12-18 《诗经》