2022CVPR_SNR-Aware Low-light Image Enhancement(SNR)

一. motivation

对于区域含有较高的峰值信噪比(SNR)则可利用局部信息进行图像增强,但是对于区域含有很低的SNR,则无法利用局部信息,此时非局部信息就非常重要

二. contribution

提出(1)一个利用卷积结构设计的短分支方便利用局部信息,一个利用transformer结构设计的长分支方便利用全局信息

(2)在transformer中利用SNR引导的自注意self-attention, 利用SNR进行长短分支的特征融合

三.Network

 1. 对于低光照的图片首先采用公式2获得SNR Map

 (1)Ig:是低光图片          :是经过cv.blur进行均值滤波后的图像

 (2) 对Ig和Ig' 取得灰度图进行绝对值相减得到噪声N

 (3)SNR(mask):均值滤波后的图像与噪声相除得到S

2. 先进行浅层特征提取

 3.  对于fea进行深层特征提取,fea进行两个分支,一个分支(短分支)进行卷积块(更容易捕获局部信息)进行残差连接,另外一个分支(长分支)进行SNR引导的transformer结构(更容易捕获全局信息)

短分支结构:

普通的两个卷积操作+跳连接形成一个残差块,如此循环6次得到经过短分支结构的Fs

长分支结构

(1)首先利用F.unfold展成n个patch,以后放入transformer中进行计算

(2)transformer结构就是归一化+ 多头 self attention + FFN(前馈网络)组成

(3)SNR主要体现在多头自注意中

softmax:Softmax函数计算每个元素的指数函数和,然后对所有元素求和以得到归一化的概率分布。

 σ 是一个非常小的负标量−1e9。当S’为0时 (1-S')σ = σ 此时再进行softmax后就会得到0,这意味着SNR很低的地方不会计算到注意力中,当S'=1,(1-S')σ = 0从而实现SNR引导的自注意计算

SNR引导的特征融合

再利用mask:进行特征融合

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

详解:

 transformer包括Attention和FeedForward

对于Attention,对输入的特征进行归一化分别赋给q,k,v

q*k的转置是查看patch之间的相似度

对于mask的操作: 将mask分成s’个pacth,之后按照dim=2取得了mask的均值,采用公式,如果mask 的均值<0.5 则 取值为0,便于后面attention操作

      

 对于mask取值为0的地方,attn中填充为很大的负数,在计算softmax这部分像素值不会发生变化

 FeedForward: 首先归一化,之后全连接+relu,再归一化,一个跳连接。最后输出经过transformer结构的fea_unfold

 

再利用mask:进行特征融合

 ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

 

四. 两个损失

1.CharbonnierLoss2:用Charbonnier Loss来近似 l_{1} 损失来提高模型的性能

 大佬链接:https://blog.csdn.net/weixin_43135178/article/details/120865709

class CharbonnierLoss2(nn.Module):
    """Charbonnier Loss (L1)"""

    def __init__(self, eps=1e-6):
        super(CharbonnierLoss2, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        loss = torch.mean(torch.sqrt(diff * diff + self.eps))
        return loss

2. perceptual loss:用VGG19提取特征(关于内容的一个损失)

 对于nn.L1Loss中reduction的说明:https://blog.csdn.net/qq_39450134/article/details/121745209

import torchvision
class VGG19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().cuda()
        # self.criterion = nn.L1Loss()
        self.criterion = nn.L1Loss(reduction='sum')
        self.criterion2 = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            # print(x_vgg[i].shape, y_vgg[i].shape)
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss

 

posted @ 2023-05-04 15:25  helloWorldhelloWorld  阅读(490)  评论(0编辑  收藏  举报