tensorflow2.3 实现简单的单层神经网络_mnist数据集手写数字识别

实现流程

 

1、准备数据

2、全连接结果计算

3、损失优化(梯度下降)

4、模型评估(计算准确性)

5、加入tensorboard图

6、使用训练后的模型进行预测

 

 1 def full_connect():
 2     #使用占位符时,tersorflow2.X以上会出现tf.placeholder() is not compatible with eager execution报错,需要加下面这段语,避免程序报此错误。
 3     tf.compat.v1.disable_eager_execution()
 4     #获取真实的数据
 5     mnist = input_data.read_data_sets("./tmp/mnist/", one_hot=True)
 6     #1、建立数据的占位符 ,X[None,784] y_true [None,10]
 7     with tf.compat.v1.variable_scope('data'):
 8         x=tf.compat.v1.placeholder(tf.float32,[None,784])
 9         y_true=tf.compat.v1.placeholder(tf.int32,[None,10])
10 
11     #2、建立一个全链接层的神经网络 w[784,10],b=[10]
12     with tf.compat.v1.variable_scope('fc_model'):
13         #随机初始化权重和偏置,权重和偏置后面会跟着训练自动优化
14         weight=tf.Variable(tf.compat.v1.random_normal([784,10],mean=0.0,stddev=1.0),name='weight')
15         bias=tf.Variable(tf.constant(0.0,shape=[10]))
16         #预测Nonew个样本的输出结果matrix [None,784]*[784*10]+[10]=[None,10],即矩阵[None,784]样本的特征*权重[784,10]+偏置[10]=预测结果[None,10]
17         y_predict=tf.matmul(x,weight)*bias
18     #计算交叉熵损失
19     with tf.compat.v1.variable_scope('soft_cross'):
20         #返回交叉熵的列表结果,对交叉熵求平均值
21         loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))
22 
23     #梯度下降求出损失
24     with tf.compat.v1.variable_scope('optimizer'):
25         train_op=tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)
26     #5、计算准确率,预测准确置为1
27     with tf.compat.v1.variable_scope('acc'):
28         #equal_list None个样本[1,0,1,1,.....]
29         equal_list=tf.equal(tf.argmax(y_true,1),tf.argmax(y_predict,1))
30         accuray=tf.reduce_mean(tf.cast(equal_list,tf.float32))
31     #收集变量,单个数字值收集
32     tf.compat.v1.summary.scalar("losses",loss)
33     tf.compat.v1.summary.scalar("acc", accuray)
34 
35     #高纬度变量收集
36     tf.compat.v1.summary.histogram('weight',weight)
37     tf.compat.v1.summary.histogram('biases',bias)
38 
39     #定义一个合并的op
40     merged=tf.compat.v1.summary.merge_all()
41 
42     #因为有变量,故要定义初始化变量的op
43     init_op=tf.compat.v1.global_variables_initializer()
44     #开启回话去训练
45     with tf.compat.v1.Session() as sess:
46         #初始化变量
47         sess.run(init_op)
48         filewriter=tf.compat.v1.summary.FileWriter('./tmp/summary/test/',graph=sess.graph)
49         #迭代步数去训练 ,更新参数预测
50         for i in range(2000):
51             mnist_x,mnist_y=mnist.train.next_batch(50)
52             #feed_dict实时提供的数据 x训练集,y为真实的目标值
53             #运行op训练
54             sess.run(train_op,feed_dict={x:mnist_x,y_true:mnist_y})
55             #写入每步训练的值
56             summary=sess.run(merged,feed_dict={x:mnist_x,y_true:mnist_y})
57             filewriter.add_summary(summary,i)
58 
59             print('训练第%d步,准确率为:%f'%(i,sess.run(accuray,feed_dict={x:mnist_x,y_true:mnist_y})))
60     return None

注意:在tensorflow2.X版本,如果出现报No module named 'tensorflow.examples.tutorials' ,手动下载tutorials文件包,并放到本地电脑tersorflow/examples目录下。

下载链接:https://share.weiyun.com/fpYSBj4X 密码:qu73et

 

出现报错tensorflow报AttributeError: __enter__,将tf.compat.v1.Session后面加上括号()

 

 

在以上的代码基础上增加一下代码(红色字体)

FLAGS=tf.compat.v1.flags.FLAGS
tf.compat.v1.flags.DEFINE_integer("is_train",1,'指定程序是预测还是训练')


 #开启回话去训练
    with tf.compat.v1.Session() as sess:
        #初始化变量
        sess.run(init_op)
        filewriter=tf.compat.v1.summary.FileWriter('./tmp/summary/test/',graph=sess.graph)
        #迭代步数去训练 ,更新参数预测
        if FLAGS.is_train ==1:
            for i in range(2000):
                mnist_x,mnist_y=mnist.train.next_batch(50)
                #feed_dict实时提供的数据 x训练集,y为真实的目标值
                #运行op训练
                sess.run(train_op,feed_dict={x:mnist_x,y_true:mnist_y})
                #写入每步训练的值
                summary=sess.run(merged,feed_dict={x:mnist_x,y_true:mnist_y})
                filewriter.add_summary(summary,i)

                print('训练第%d步,准确率为:%f'%(i,sess.run(accuray,feed_dict={x:mnist_x,y_true:mnist_y})))
            #保存模型
            saver.save(sess,"./tmp/ckpt/tc_model")
        else:
            #加载模型,如果不加载模型,则参数不会被新的覆盖
            saver.restore(sess,'./tmp/ckpt/tc_model')
            #如果是0,做出预测
            for i in range(100):
                #每次测试一张图片[0,0,0,1,0,..0]
                x_test,y_test=mnist.test.next_batch(1)
                print('第%d张图片,手写数字目标是%d,预测结果是:%d'%(i,
                                                   tf.argmax(y_test,1).eval(),
                                                   #图片与预测值的概率,再求其最大概率值
                                                  tf.argmax( sess.run(y_predict,feed_dict={x:x_test,y_true:y_test}),1).eval()))


    return None

 

 

 

cmd运行程序结果如下:

 

posted @ 2020-11-04 15:22  hisweetyGirl  阅读(409)  评论(0编辑  收藏  举报