Boundary Aware PoolNet(2):BASNet模型与代码介绍
Boundary Aware PoolNet = PoolNet + BASNet,即使用BASNet中的Deep Supervision策略和Hybrid Loss改进PoolNet。
为理解Boundary Aware PoolNet,我们并不需要学习整个BASNet,只需要了解其中的Deep Supervision策略和Hybrid Loss即可。
本文将简单介绍BASNet的模型结构,重点介绍其Deep Supervision和Hybrid Loss的理论和代码实现。
相关文章汇总:
BASNet
传送门
- BASNet论文:https://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html
- BASNet代码:https://github.com/xuebinqin/BASNet
- BASNet论文阅读笔记:https://zhuanlan.zhihu.com/p/355420066
BASNet结构
如上图所示,BASNet模型包括Predict Module和Refine Module。
-
Predict Module
a U-Net-like densely supervised Encoder-Decoder network,作用是predict saliency map from input images。
其实这个Encoder-Decoder结构和FPN(特征金字塔网络)没什么区别吧。
-
Refine Module
refines the resulting saliency map of the prediction module by learning the residuals between the saliency map and the ground truth。
基于上述的2个Module,BASNet使用Deep Supervision(上图中的Sup1-8)和Hybrid Loss进行模型训练。
代码
Predict Module的代码在文件./model/BASNet.py
中类BASNet
中,Refine Module的代码在文件./model/BASNet.py
中类RefUnet
中。
Deep Supervision
直白来讲,Deep Supervision即使用神经网络中多个层的Loss之和进行梯度下降。
如前文中BASNet结构图所示,BASNet作者计算了Predict Module中的7层和Refine Module中的最后1层的Loss并进行求和,然后进行梯度下降,以此实现Deep Supervision。在计算边路输出时,需要进行上采样和卷积使得边路输出的尺寸、通道数与输入相同。
Deep Supervision的代码在文件./model/BASNet.py
的类BASNet
的函数forward()
中,可知类BASNet
在forward()
时返回了8个边路输出,后继计算这8层的Hybrid Loss并求和进行梯度下降。
Hybrid Loss
直白来讲,Hybrid Loss即在计算损失时使用BCE Loss、SSIM Loss、IOU Loss这3个损失之和而非只使用BCE损失函数。
Hybrid Loss的代码在文件./basnet_trin.py
中的函数muti_bce_loss_fusion()
中。该函数的输入为BASNet的8个边路输出和输入对应的标注,该函数使用函数bce_ssim_loss()
计算1个边路输出与标注的3种Loss之和。
Github(github.com):@chouxianyu
Github Pages(github.io):@臭咸鱼
知乎(zhihu.com):@臭咸鱼
博客园(cnblogs.com):@臭咸鱼
B站(bilibili.com):@绝版臭咸鱼
微信公众号:@臭咸鱼
转载请注明出处,欢迎讨论和交流!