# fpl
import os

# Allow duplicate OpenMP runtimes to coexist. Set before the numeric
# libraries below are imported so the flag is already in place whenever an
# OpenMP runtime initialises.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

from math import sqrt

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.utils as vutils
from PIL import Image
from torch import nn
from torchvision import transforms

# Split one image into low- and high-frequency components with a centred
# square mask in the (shifted) 2-D Fourier domain, then save both parts.
INPUT_PATH = 'img/low/1.png'
OUTPUT_DIR = 'img/FFTOutput'
# Half the side length of the square low-pass window, in frequency bins.
LOW_PASS_HALF_WIDTH = 30

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor()
])

# Read the input image. Converting to RGB keeps palette/alpha PNGs from
# producing channel layouts that cannot be written back out as JPEG.
with Image.open(INPUT_PATH) as src_image:
    img1 = transform(src_image.convert('RGB')).unsqueeze(0)

# Forward FFT over the spatial axes only; shift the zero frequency to the
# centre of those same two axes (not the batch/channel axes).
img1_fft = torch.fft.fft2(img1, dim=(-2, -1))
img1_fft_shift = torch.fft.fftshift(img1_fft, dim=(-2, -1))

# Image size
b, c, rows, cols = img1.shape

# Build the low-pass mask (1 inside the centred square, 0 elsewhere) and its
# complement as torch tensors so the product with the spectrum stays in torch.
crow, ccol = rows // 2, cols // 2
mask_low = torch.zeros((rows, cols), dtype=img1.dtype)
mask_low[max(crow - LOW_PASS_HALF_WIDTH, 0):crow + LOW_PASS_HALF_WIDTH,
         max(ccol - LOW_PASS_HALF_WIDTH, 0):ccol + LOW_PASS_HALF_WIDTH] = 1
mask_high = 1 - mask_low

# Apply the masks (broadcast over batch and channel)
fshift_low = img1_fft_shift * mask_low
fshift_high = img1_fft_shift * mask_high

# Inverse FFT: undo the shift, transform back, and keep the magnitude
ishift_low = torch.fft.ifftshift(fshift_low, dim=(-2, -1))
ishift_high = torch.fft.ifftshift(fshift_high, dim=(-2, -1))
img_back_low = torch.abs(torch.fft.ifft2(ishift_low, dim=(-2, -1)))
img_back_high = torch.abs(torch.fft.ifft2(ishift_high, dim=(-2, -1)))

# Save the results, creating the output directory if it is missing
os.makedirs(OUTPUT_DIR, exist_ok=True)
vutils.save_image(img_back_low, os.path.join(OUTPUT_DIR, 'img_LF.jpg'), normalize=True)
vutils.save_image(img_back_high, os.path.join(OUTPUT_DIR, 'img_HF.jpg'), normalize=True)
####################
# Loss Functions
####################
class focal_pixel_learning(torch.nn.Module):
    """Weighted L1 loss between ``x`` and the residual ``hr - upscale(lr)``.

    ``lr`` is bicubically upscaled to the spatial size of ``hr``. Two
    per-pixel weight maps are built from min-max normalised error maps,
    ``exp(normalised * alpha) * gamma``:

    * ``sp``: from ``|hr - upscale(lr)|``
    * ``lp``: from ``|hr - upscale(lr) - x|``

    Both maps are detached, so gradients flow only through the L1 term
    ``|x - (hr - upscale(lr))|``.
    """

    def __init__(self):
        super().__init__()
        # Imported locally so the class does not depend on these names being
        # imported at module level.
        import functools
        from torch.nn import functional as F

        self.alpha_sp, self.gamma_sp = 1, 0.5
        self.alpha_lp, self.gamma_lp = 1, 1
        self.upscale_func = functools.partial(
            F.interpolate, mode='bicubic', align_corners=False)

    @staticmethod
    def weig_func(x, y, z):
        """Return ``exp(minmax(x) * y) * z``.

        Args:
            x: Floating-point error map to normalise.
            y: Exponent scale (alpha).
            z: Multiplicative factor (gamma).

        Returns:
            Tensor shaped like ``x``.
        """
        lowest = x.min()
        # A constant map has zero range; dividing by it would give 0/0 = NaN.
        # Clamping to the smallest positive normal float leaves every
        # non-degenerate case unchanged and maps a constant input to
        # exp(0) * z.
        value_range = (x.max() - lowest).clamp_min(torch.finfo(x.dtype).tiny)
        return torch.exp((x - lowest) / value_range * y) * z

    def forward(self, x, hr, lr):
        """Compute the weighted L1 loss.

        Args:
            x: Tensor compared against ``hr - upscale(lr)``; presumably the
                network's predicted residual (confirm against the caller).
            hr: Target tensor; its trailing dimensions (from index 2 on) set
                the upscale size.
            lr: Tensor to upscale; bicubic interpolation requires it to be
                4-D.

        Returns:
            Scalar tensor: mean of ``w_sp * w_lp * |x - (hr - upscale(lr))|``.
        """
        f_BI_x = self.upscale_func(lr, size=hr.size()[2:])
        y_hat = hr - f_BI_x

        y_sp = torch.abs(y_hat)
        w_y_sp = self.weig_func(y_sp, self.alpha_sp, self.gamma_sp).detach()

        # |x - y_hat| is both the input of the second weight map and the
        # L1 term itself, so compute it once.
        y_lp = torch.abs(x - y_hat)
        w_y_lp = self.weig_func(y_lp, self.alpha_lp, self.gamma_lp).detach()

        loss = torch.mean(w_y_sp * w_y_lp * y_lp)
        return loss