17:全连接网络的创建以及参数优化(创建,训练,测试[计算准确度],可视化)以及Flatten层的自定义代码

全连接层非常重要,学习后面的各种网络模型都会用到,比如:cNN,RNN等等。但是一般向以上的模型输入的都是四维张量,故通过卷积和LSTM之后输出是四维张量,但是全连接层需要输入二维张量。故需要用到打平层(Flatten层),将后面的3个维度打平,才能输入到全连接层(nn.liner层)。由于pytorch未提供Flatten层,但是后面又非常常用,故这里我将其定义好了,以便大家参考使用。

Flatten层:作用:将卷积层等输出的四维张量转换维二维张量。为全连接层,提供二维的张量。

 

 

1:全连接网络的创建

【注】

(1):自己创建的网络结构需要继承nn.Module

(2):需要实现forward()函数,不需要实现backward(),因为nn.Module提供了backward()函数的实现。(pytorch的autograd包会自动完成向后求导的过程。)

2:代码的实现

(2.1)步骤一:全连接网络的创建以及前向传播的实现

 

 [注]nn.Sequential(

)类似于一个容器,容器中可以添加任何继承自nn.Module的类。也可以添加自己创建的类。

[注]self.model继承自nn.Module故可以使用self.model(x)调用model.forward()函数。

[注]关于全连接层的详细讲解,可以参考博客https://blog.csdn.net/zerone_zjp/article/details/108625099,个人认为该作者讲的不错。

 

 [注]:nn.ReLU为类风格的API,F.relu()为函数风格的API。

两种风格的不同:

对于类风格的API必须先进行实例化,再进行调用。并且其内部参数w,b必须通过para方法来进行访问。

对于函数风格的API可以自己进行过程的管理,仅仅使用了gpu加速的功能。

(2)步骤2:train

 [注]optimizer()函数可以实现对w和b的更新。学习率为learning_rate。

 (3)test(也即求准确率或精度)

 

 

 

 [注]argmax(dim=)函数可以实现在指定维度取最大值所在的索引。无论是对未经过激活函数的预测值还是对经过激活函数的预测值输出的最大值索引都是相同的(激活函数不改变函数的单调性)。

 

[注]上图实现了accuracy的计算:

步骤为:

(1)pred=logits.argmax(dim=)          这里的pred为预测的最大值所在的索引

(2)correct=pred.eq(target).float().sum().item()                            这里的target为真实的最大值所在的索引 ,sum()可以统计预测正确的个数

(3)accuracy=correct/len(test_loader,dataset)                              正确的个数除以总的数量为准确率(精度)

 [注]

精确率为precision,accuracy为准确率,召回率为recall。

指标计算:
精确度=TT/(TT+TF)--判断正样本中真正正样本的比例
准确率=(TT+FF)/(T+F)--判断正确的比重
召回率=TT/(TT+FT)--正确判断正例的比重

漏报率=FT/(TT+FT)--多少个正例被漏判了
虚警率=TF/(TT+TF)--反映被判为正例样本中,有多少个是负例

(4)可视化(Visdom,tensorboardX)这里只介绍Visdom

【注】

Visdom可以接收tensor类型的数据(实际也是需要对tensor类型进行转化,只是进行了打包),tensorboardX需将tensor类型的数据搬运到cpu上,然后转化成numpy数据才能够进行可视化。

Visdom更新更加实时,tensorboardX大概30秒更新一次

tensorboardX会将数据写到监听文件中,导致监听文件非常大。

(4.1.1)Visdom的安装以及开启监听进程

linux安装:pip install visdom

        开启:python -m visdom.server

window安装:可以使用linux下的安装步骤。如果出现问题:

                    可以从github下载visdom的源码,解压后cd visdom-master 

                    pip install -e.

      cd ../..(退回到原目录)

      python -m visdom.server

                    复制地址打开浏览器即可。

 (4.1.2)画曲线

 

 

 

 [注]

win是一个小窗口其值为小窗口的ID,envs是一个工程。

opts=dict(title='')是设置窗口的名字,便于查看

因为每一个点都是添加在最后面故:需要使用update=‘append’

 [注]除了画图可以直接收tensor类型的数据,画曲线还是要传入numpy类型的数据。

 

 

 

 [注]legend为图标

 

 

 

 

posted @ 2021-08-01 19:42  收购阿里巴巴  阅读(1276)  评论(0编辑  收藏  举报