MWCNN中使用的haar小波变换 pytorch

 

1.原理

小波变换的计算方法:

1)一维信号:

例如:有a=[5,7,6,8]四个数,并使用b[4]数组来保存结果.

        则一级Haar小波变换的结果为:

        b[0]=(a[0]+a[1])/2,                       b[2]=(a[0]-a[1])/2

        b[1]=(a[2]+a[3])/2,                       b[3]=(a[2]-a[3])/2

⚠️计算差均值时也有看见a[1]-a[0]的,只要保持一致应该都可以

 

由此可知,Haar变换采用的原理是:

A)低频采用和均值,即b[0]和b[1],和均值中均值存储了图像的整体信息

B)高频采用差均值,即b[2]和b[3],用于记录图像的细节信息,这样在重构时能够恢复图像的全部信息

 因此上面的例子中

b[0] = (5+7)/2 = 6 , b[1] = (6+8)/2 = 6, b[2] = (5-7)/2 = -1, b[3] = (6-8)/2 = -1 

 

如果要继续进行多级的小波变换:

  如上图可见是对低频的信息继续进行haar小波变换

 

2)二维

对于二维haar小波,我们通常一次分解形成了整体图像,水平细节,垂直细节,对角细节。首先我们按照一维haar小波分解的原理,按照行顺序对行进行处理,然后按照列顺序对行处理结果进行同样的处理

用图像表述如图所示:图中a表示原图,图b表示经过一级小波变换的结果,h1 表示水平反向的细节,v1 表示竖直方向的细节,c1表示对角线方向的细节,b表示下2采样的图像。图c中表示继续进行了三次Haar小波变换的结果:

详细过程经过下面的代码来解释

 

2.实现

1)

代码:https://github.com/lpj0/MWCNN_PyTorch/blob/master/model/common.py:

原图为:

 

中间有个问题,就是逆向重构的时候发现并没有成功,得到的结果是:

 

于是对操作的数据进行了一番输出:

#coding:utf-8
import torch.nn as nn
import torch

def dwt_init(x):
    print('-------------- origin ---------------')
    print(x[:,0,:,:])
    x01 = x[:, :, 0::2, :] / 2
    x02 = x[:, :, 1::2, :] / 2
    x1 = x01[:, :, :, 0::2]
    x2 = x02[:, :, :, 0::2]
    x3 = x01[:, :, :, 1::2]
    x4 = x02[:, :, :, 1::2]
    x_LL = x1 + x2 + x3 + x4
    x_HL = -x1 - x2 + x3 + x4
    x_LH = -x1 + x2 - x3 + x4
    x_HH = x1 - x2 - x3 + x4
    print('---------------- LL ------------------')
    print(x_LL[:,0,:,:])
    print()
    print('---------------- HL ------------------')
    print(x_HL[:,0,:,:])
    print()
    print('---------------- LH ------------------')
    print(x_LH[:,0,:,:])
    print()
    print('---------------- HH ------------------')
    print(x_HH[:,0,:,:])
    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)


# 使用哈尔 haar 小波变换来实现二维逆向离散小波
def iwt_init(x):
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    # print([in_batch, in_channel, in_height, in_width]) #[1, 12, 56, 56]
    out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r**2)), r * in_height, r * in_width
    # print(out_batch, out_channel, out_height, out_width) #1 3 112 112
    x1 = x[:, 0:out_channel, :, :] / 2
    print('-------------- enter iwt ---------------')
    print(x1[:,0,:,:])
    x2 = x[:, out_channel:out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
    # print(x1.shape) #torch.Size([1, 3, 56, 56])
    # print(x2.shape) #torch.Size([1, 3, 56, 56])
    # print(x3.shape) #torch.Size([1, 3, 56, 56])
    # print(x4.shape) #torch.Size([1, 3, 56, 56])
    # h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
    h = torch.zeros([out_batch, out_channel, out_height, out_width]).float()

    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
    print('-------------- back ---------------')
    print(h[:,0,:,:])
    return h


# 二维离散小波
class DWT(nn.Module):
    def __init__(self):
        super(DWT, self).__init__()
        self.requires_grad = False  # 信号处理,非卷积运算,不需要进行梯度求导

    def forward(self, x):
        return dwt_init(x)


# 逆向二维离散小波
class IWT(nn.Module):
    def __init__(self):
        super(IWT, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return iwt_init(x)

if __name__ == '__main__':
    import os, cv2, torchvision
    from PIL import Image
    import numpy as np
    from torchvision import transforms as trans

    img = Image.open('./1.jpg')
    transform = trans.Compose([
        trans.ToTensor()
    ])
    img = transform(img).unsqueeze(0)
    dwt = DWT()
    change_img_tensor = dwt(img)
    # print(change_img_tensor.shape) #torch.Size([1, 12, 56, 56])
    print('-------------- after dwt ---------------')
    print(change_img_tensor[:,0,:,:])

    for i in range(change_img_tensor.size(1)//3):
        torchvision.utils.save_image(change_img_tensor[:,i*3:i*3+3:,:], os.path.join('./', 'change_{}.jpg'.format(i)))
         
    iwt = IWT()
    back_img_tensor = iwt(change_img_tensor)
    print(back_img_tensor.shape)

    torchvision.utils.save_image(back_img_tensor, 'back.jpg')

返回:

(deeplearning) bogon:learning user$ python delete.py 
-------------- origin ---------------
tensor([[[0.9020, 0.9882, 0.9216,  ..., 0.7176, 0.7843, 0.8431],
         [0.8941, 0.9608, 0.9255,  ..., 0.7490, 0.7569, 0.7490],
         [0.8980, 0.9333, 0.8863,  ..., 0.6941, 0.7333, 0.7608],
         ...,
         [0.1373, 0.1373, 0.1451,  ..., 0.9529, 0.9686, 0.9804],
         [0.1451, 0.1451, 0.1490,  ..., 0.9294, 0.9373, 0.9569],
         [0.1373, 0.1412, 0.1451,  ..., 0.9137, 0.9020, 0.9176]]])
---------------- LL ------------------
tensor([[[1.8725, 1.8569, 1.9314,  ..., 1.4176, 1.4510, 1.5667],
         [1.8294, 1.7588, 1.5118,  ..., 1.4490, 1.4314, 1.5137],
         [1.9039, 1.6059, 1.0412,  ..., 1.2765, 1.4588, 1.5039],
         ...,
         [0.3078, 0.3216, 0.3490,  ..., 1.8784, 1.8333, 1.7627],
         [0.2647, 0.3059, 0.3784,  ..., 1.7510, 1.8725, 1.9137],
         [0.2843, 0.3020, 0.3922,  ..., 1.8941, 1.8549, 1.8569]]])

---------------- HL ------------------
tensor([[[ 0.0765,  0.0098,  0.0059,  ...,  0.0098,  0.0157,  0.0255],
         [ 0.0294, -0.0294, -0.0922,  ..., -0.0098,  0.0078,  0.0235],
         [-0.0412, -0.1588, -0.0725,  ...,  0.0569,  0.0314, -0.0059],
         ...,
         [-0.0137,  0.0196,  0.0039,  ...,  0.0275, -0.0412,  0.0098],
         [-0.0020,  0.0196,  0.0176,  ...,  0.0216,  0.0216,  0.0039],
         [ 0.0020,  0.0078,  0.0314,  ...,  0.0039, -0.0118,  0.0176]]])

---------------- LH ------------------
tensor([[[-0.0176,  0.0098, -0.0176,  ..., -0.0137,  0.0392, -0.0608],
         [-0.0020, -0.0098, -0.1510,  ...,  0.1078,  0.0745,  0.0196],
         [ 0.0176, -0.1392, -0.0882,  ..., -0.0059,  0.0588,  0.0725],
         ...,
         [-0.0255, -0.0078,  0.0118,  ..., -0.0431,  0.0216, -0.0098],
         [ 0.0098,  0.0000,  0.0020,  ...,  0.0020, -0.0020,  0.0353],
         [-0.0059, -0.0039,  0.0039,  ...,  0.0549,  0.0039, -0.0373]]])

---------------- HH ------------------
tensor([[[-0.0098,  0.0059, -0.0216,  ...,  0.0216, -0.0078, -0.0333],
         [-0.0059, -0.0255, -0.0294,  ...,  0.0137, -0.0235, -0.0039],
         [-0.0373, -0.0373,  0.0608,  ..., -0.0020,  0.0196, -0.0098],
         ...,
         [ 0.0059,  0.0039,  0.0039,  ...,  0.0118,  0.0137, -0.0137],
         [ 0.0020, -0.0039,  0.0020,  ..., -0.0098,  0.0137,  0.0078],
         [ 0.0020,  0.0000,  0.0039,  ...,  0.0000, -0.0196, -0.0020]]])
-------------- after dwt ---------------
tensor([[[1.8725, 1.8569, 1.9314,  ..., 1.4176, 1.4510, 1.5667],
         [1.8294, 1.7588, 1.5118,  ..., 1.4490, 1.4314, 1.5137],
         [1.9039, 1.6059, 1.0412,  ..., 1.2765, 1.4588, 1.5039],
         ...,
         [0.3078, 0.3216, 0.3490,  ..., 1.8784, 1.8333, 1.7627],
         [0.2647, 0.3059, 0.3784,  ..., 1.7510, 1.8725, 1.9137],
         [0.2843, 0.3020, 0.3922,  ..., 1.8941, 1.8549, 1.8569]]])
-------------- enter iwt ---------------
tensor([[[127.5000, 127.5000, 127.5000,  ..., 127.5000, 127.5000, 127.5000],
         [127.5000, 127.5000, 127.5000,  ..., 127.5000, 127.5000, 127.5000],
         [127.5000, 127.5000, 127.5000,  ..., 127.5000, 127.5000, 127.5000],
         ...,
         [ 39.5000,  41.2500,  44.7500,  ..., 127.5000, 127.5000, 127.5000],
         [ 34.0000,  39.2500,  48.5000,  ..., 127.5000, 127.5000, 127.5000],
         [ 36.5000,  38.7500,  50.2500,  ..., 127.5000, 127.5000, 127.5000]]])
-------------- back ---------------
tensor([[[117.5000, 137.5000, 125.5000,  ..., 124.5000, 124.0000, 131.0000],
         [117.5000, 137.5000, 126.5000,  ..., 135.0000, 124.0000, 131.0000],
         [123.5000, 131.5000, 127.5000,  ..., 119.0000, 121.5000, 128.0000],
         ...,
         [ 35.0000,  36.0000,  36.7500,  ..., 132.5000, 130.2500, 134.2500],
         [ 36.5000,  36.5000,  37.7500,  ..., 126.7500, 125.0000, 130.0000],
         [ 35.5000,  37.5000,  37.2500,  ..., 128.2500, 125.0000, 130.0000]]])
torch.Size([1, 3, 112, 112])

发现输入iwt的结果变化了,突然想起来torchvision.utils.save_image函数是会对数据进行处理的

解决办法就是调整下顺序即可

 

重新运行一遍:

#coding:utf-8
import torch.nn as nn
import torch

def dwt_init(x):
    x01 = x[:, :, 0::2, :] / 2
    x02 = x[:, :, 1::2, :] / 2
    x1 = x01[:, :, :, 0::2]
    x2 = x02[:, :, :, 0::2]
    x3 = x01[:, :, :, 1::2]
    x4 = x02[:, :, :, 1::2]
    x_LL = x1 + x2 + x3 + x4
    x_HL = -x1 - x2 + x3 + x4
    x_LH = -x1 + x2 - x3 + x4
    x_HH = x1 - x2 - x3 + x4

    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)


# 使用哈尔 haar 小波变换来实现二维逆向离散小波
def iwt_init(x):
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    # print([in_batch, in_channel, in_height, in_width]) #[1, 12, 56, 56]
    out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r**2)), r * in_height, r * in_width
    # print(out_batch, out_channel, out_height, out_width) #1 3 112 112
    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel:out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
    # print(x1.shape) #torch.Size([1, 3, 56, 56])
    # print(x2.shape) #torch.Size([1, 3, 56, 56])
    # print(x3.shape) #torch.Size([1, 3, 56, 56])
    # print(x4.shape) #torch.Size([1, 3, 56, 56])
    # h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
    h = torch.zeros([out_batch, out_channel, out_height, out_width]).float()

    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
    return h


# 二维离散小波
class DWT(nn.Module):
    def __init__(self):
        super(DWT, self).__init__()
        self.requires_grad = False  # 信号处理,非卷积运算,不需要进行梯度求导

    def forward(self, x):
        return dwt_init(x)


# 逆向二维离散小波
class IWT(nn.Module):
    def __init__(self):
        super(IWT, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return iwt_init(x)

if __name__ == '__main__':
    import os, cv2, torchvision
    from PIL import Image
    import numpy as np
    from torchvision import transforms as trans
    # img = cv2.imread('./1.jpg')
    # print(img.shape)
    # img = Image.fromarray(img.astype(np.uint8))
    img = Image.open('./1.jpg')
    transform = trans.Compose([
        trans.ToTensor()
    ])
    img = transform(img).unsqueeze(0)
    dwt = DWT()
    change_img_tensor = dwt(img)
    iwt = IWT()
    back_img_tensor = iwt(change_img_tensor)
    print(back_img_tensor.shape)
    # print(change_img_tensor.shape) #torch.Size([1, 12, 56, 56])
    
    
    #合并成一张4格的图
    h = torch.zeros([4,3,change_img_tensor.size(2),change_img_tensor.size(2)]).float()

    
    for i in range(change_img_tensor.size(1)//3):
        h[i,:,:,:] = change_img_tensor[:,i*3:i*3+3:,:]
        #分别保存为一个图片
        torchvision.utils.save_image(change_img_tensor[:,i*3:i*3+3:,:], os.path.join('./', 'change_{}.jpg'.format(i)))
    
            
    change_img_grid = torchvision.utils.make_grid(h, 2) #一行2张图片    
    torchvision.utils.save_image(change_img_grid, 'change_img_grid.jpg')

    torchvision.utils.save_image(back_img_tensor, 'back.jpg')

小波变换后的结果为:

 

重构的图为:

 

2)对代码进行解释

1》dwt

def dwt_init(x):
    x01 = x[:, :, 0::2, :] / 2
    x02 = x[:, :, 1::2, :] / 2
    x1 = x01[:, :, :, 0::2]
    x2 = x02[:, :, :, 0::2]
    x3 = x01[:, :, :, 1::2]
    x4 = x02[:, :, :, 1::2]
    x_LL = x1 + x2 + x3 + x4
    x_HL = -x1 - x2 + x3 + x4
    x_LH = -x1 + x2 - x3 + x4
    x_HH = x1 - x2 - x3 + x4

    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)

首先:

    x01 = x[:, :, 0::2, :] / 2
    x02 = x[:, :, 1::2, :] / 2

将矩阵分为偶数行和奇数行,并将所有值都除以2,这样后面只要进行求和和求差即可,因为已经求均值了

然后下面的:

    x1 = x01[:, :, :, 0::2]
    x2 = x02[:, :, :, 0::2]
    x3 = x01[:, :, :, 1::2]
    x4 = x02[:, :, :, 1::2]

就是分为偶数列和奇数列,假设矩阵为6*6大小,那么就将该矩阵分成了4个3*3大小的x1、x2、x3和x4,如下图所示:

 

 

那么接下来在进行的计算就是进行行、列的和、差变换了:

    x_LL = x1 + x2 + x3 + x4
    x_HL = -x1 - x2 + x3 + x4
    x_LH = -x1 + x2 - x3 + x4
    x_HH = x1 - x2 - x3 + x4

再用一张图说明:

 

 

 

2》iwt

def iwt_init(x):
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    # print([in_batch, in_channel, in_height, in_width]) #[1, 12, 56, 56]
    out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r**2)), r * in_height, r * in_width
    # print(out_batch, out_channel, out_height, out_width) #1 3 112 112
    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel:out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
    # print(x1.shape) #torch.Size([1, 3, 56, 56])
    # print(x2.shape) #torch.Size([1, 3, 56, 56])
    # print(x3.shape) #torch.Size([1, 3, 56, 56])
    # print(x4.shape) #torch.Size([1, 3, 56, 56])
    # h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
    h = torch.zeros([out_batch, out_channel, out_height, out_width]).float()

    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
    return h

在这里的行x1=x_LL/2,  x2=x_HL/2,  x3=x_LH/2,  x4=x_HH/2

所以我们想重构,其实就是从这些值中恢复dwt中的x1,x2,x3,x4,分别放到h对应的位置变为原来的矩阵,如x1对应的是h[:, :, 0::2, 0::2],如下图所示:

这就是重构的方法

 

过程中遇到的一点问题pytorch图像处理的问题

posted @ 2020-03-19 15:45  慢行厚积  阅读(7209)  评论(1编辑  收藏  举报