# fpl: focal pixel learning

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from math import sqrt
import os
import cv2
import torchvision.utils as vutils
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torchvision.utils as vutils


# Demo: split an image into low- and high-frequency parts using a square mask
# in the centred 2-D Fourier spectrum, then save both parts as images.

# Load the image and resize/crop it to 224x224.
# convert('RGB') normalizes palette/RGBA/grayscale PNGs to 3 channels, so the
# results can be written as JPEG (JPEG cannot store an alpha channel).
# The context manager closes the underlying file once the pixels are loaded.
with Image.open('img/low/1.png') as _src_img:
    img1 = _src_img.convert('RGB')
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor()
])
img1 = transform(img1).unsqueeze(0)  # add a batch dim -> (1, C, H, W)

# FFT over the two spatial dims only, then move the zero frequency to the
# centre of those dims. Without dim=, fftshift would also roll the batch and
# channel dims.
img1_fft = torch.fft.fft2(img1, dim=(-2, -1))
img1_fft_shift = torch.fft.fftshift(img1_fft, dim=(-2, -1))

# Image size
b, c, rows, cols = img1.shape

# Low-pass mask: a (2*radius) x (2*radius) square around the spectrum centre;
# the high-pass mask is its complement. Start indices are clamped at 0 so a
# small image does not wrap around through negative slicing.
radius = 30
crow, ccol = rows // 2, cols // 2
mask_low = torch.zeros((rows, cols), dtype=img1.dtype)
mask_low[max(crow - radius, 0):crow + radius,
         max(ccol - radius, 0):ccol + radius] = 1
mask_high = 1 - mask_low

# Apply the masks (broadcast over the batch and channel dims).
fshift_low = img1_fft_shift * mask_low
fshift_high = img1_fft_shift * mask_high

# Undo the shift on the same dims, inverse FFT, and take the magnitude.
ishift_low = torch.fft.ifftshift(fshift_low, dim=(-2, -1))
ishift_high = torch.fft.ifftshift(fshift_high, dim=(-2, -1))
img_back_low = torch.abs(torch.fft.ifft2(ishift_low, dim=(-2, -1)))
img_back_high = torch.abs(torch.fft.ifft2(ishift_high, dim=(-2, -1)))

# Save both parts. normalize=True rescales each tensor to [0, 1] using its
# own min/max. Create the output directory first so saving cannot fail on a
# missing folder.
os.makedirs('img/FFTOutput', exist_ok=True)
vutils.save_image(img_back_low, 'img/FFTOutput/img_LF.jpg', normalize=True)
vutils.save_image(img_back_high, 'img/FFTOutput/img_HF.jpg', normalize=True)

 

####################
# Loss Functions
####################

class focal_pixel_learning(torch.nn.Module):
    """Pixel-weighted L1 loss on the residual ``hr - bicubic_upsample(lr)``.

    ``lr`` is resized bicubically to the spatial size of ``hr``, and ``x`` is
    compared against the residual ``hr - upsampled_lr`` with an element-wise
    L1 term. Each element of that term is scaled by two weights, both
    detached from the autograd graph:

    * the ``sp`` weight grows with ``|hr - upsampled_lr|``, i.e. with how far
      the bicubic upsample is from ``hr``;
    * the ``lp`` weight grows with ``|hr - upsampled_lr - x|``, i.e. with how
      far ``x`` still is from the residual.

    Each weight is ``gamma * exp(alpha * e_norm)``, where ``e_norm`` is the
    error min-max normalized over the entire tensor (all batch items and
    channels together).

    Args:
        alpha_sp: exponent scale of the ``sp`` weight (default 1).
        gamma_sp: multiplier of the ``sp`` weight (default 0.5).
        alpha_lp: exponent scale of the ``lp`` weight (default 1).
        gamma_lp: multiplier of the ``lp`` weight (default 1).
    """

    def __init__(self, alpha_sp=1, gamma_sp=0.5, alpha_lp=1, gamma_lp=1):
        super().__init__()
        # Local import keeps this class self-contained.
        import functools

        self.alpha_sp, self.gamma_sp = alpha_sp, gamma_sp
        self.alpha_lp, self.gamma_lp = alpha_lp, gamma_lp
        # Bicubic resize; the target size is supplied per call in forward().
        self.upscale_func = functools.partial(
            torch.nn.functional.interpolate, mode='bicubic', align_corners=False)

    @staticmethod
    def weig_func(x, y, z):
        """Return ``z * exp(y * (x - x.min()) / (x.max() - x.min()))``.

        ``min``/``max`` are taken over every element of ``x``. When ``x`` is
        constant, the normalized term is 0 (so the weight is ``z``) rather
        than the NaN that 0/0 would produce.
        """
        x_min = x.min()
        span = x.max() - x_min
        # The numerator is all zeros exactly when span == 0, so dividing by 1
        # there yields 0 instead of NaN; for span > 0 the result is unchanged.
        span = torch.where(span > 0, span, torch.ones_like(span))
        return torch.exp((x - x_min) / span * y) * z

    def forward(self, x, hr, lr):
        """Compute the weighted residual L1 loss.

        Args:
            x: prediction of the residual ``hr - upsampled_lr``; must
               broadcast against ``hr`` (presumably the same shape — confirm).
            hr: target tensor; its dims from index 2 onward give the size
               ``lr`` is resized to.
            lr: tensor resized bicubically to ``hr``'s spatial size (bicubic
               ``interpolate`` requires 4-D input, i.e. (N, C, H, W)).

        Returns:
            A 0-d tensor: the mean of ``w_sp * w_lp * |x - (hr - upsampled_lr)|``.
        """
        f_BI_x = self.upscale_func(lr, size=hr.size()[2:])

        # Weight from the error of plain bicubic upsampling (independent of x).
        y_sp = torch.abs(hr - f_BI_x)
        w_y_sp = self.weig_func(y_sp, self.alpha_sp, self.gamma_sp).detach()

        # Weight from the remaining error of x; detached so the weighting
        # itself receives no gradient.
        y_lp = torch.abs(hr - f_BI_x - x)
        w_y_lp = self.weig_func(y_lp, self.alpha_lp, self.gamma_lp).detach()

        y_hat = hr - f_BI_x
        loss = torch.mean(w_y_sp * w_y_lp * torch.abs(x - y_hat))

        return loss

 

# Posted 2023-07-24 09:18 by helloWorldhelloWorld (views: 59, comments: 0)