SRCNN数据预处理

复制代码
# 判断某个文件是否是图像
# enswith判断是否以指定的.png,.jpg,.jpeg结尾的字符串
# 可以根据情况扩充图像类型,加入.bmp、.tif等
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

# 读取图像转为YCbCr模式,得到Y通道
def load_img(filepath):
    img = Image.open(filepath).convert('YCbCr')
    y, _, _ = img.split()
    return y

# 裁剪大小,宽高一致为300
# 如果想训练自己的数据集,请根据情况修改裁剪大小
CROP_SIZE = 300

# 封装数据集,适配后面的torch.utils.data.DataLoader中的dataset,定义成类似形式
# 类参数为图像文件夹路径和放大倍数
# __len__(self) 定义当被len()函数调用时的行为(返回容器中元素的个数)
#__getitem__(self) 定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。
#__iter__(self) 定义当迭代容器中的元素的行为
# 返回输入图像和标签,传入DataLoader的dataset参数
class DatasetFromFolder(Dataset):
    def __init__(self, image_dir, zoom_factor):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)] # 图像路径列表
        crop_size = CROP_SIZE - (CROP_SIZE % zoom_factor) # 处理放大倍数,防止用户瞎设置,本例只能设置为2,3,4,大小不变
        # 数据集变换
        # 还有一些其他的变换操作,如归一化等,遇到一个积累一个
        self.input_transform = transforms.Compose([transforms.CenterCrop(crop_size), # 从图片中心裁剪成300*300
                                                   transforms.Resize(
                                                       crop_size // zoom_factor),    # Resize, 输入应该是缩放倍数后的图像,因为先缩小后放大
                                                   transforms.Resize(
                                                       crop_size, interpolation=Image.BICUBIC), # 双三次插值
                                                   transforms.ToTensor()]) # 图像转成tensor
        # label标签,超分不是分类问题,定义成一样的就行
        self.target_transform = transforms.Compose(
            [transforms.CenterCrop(crop_size), transforms.ToTensor()])

    def __getitem__(self, index):
        input = load_img(self.image_filenames[index]) # 输入是图像的Y通道,即亮度通道
        target = input.copy()
        input = self.input_transform(input)
        target = self.target_transform(target)
        return input, target

    def __len__(self):
        return len(self.image_filenames) # 图像个数
复制代码

 

posted @   视觉书虫  阅读(10)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 按钮权限的设计及实现
· 25岁的心里话
历史上的今天:
2023-01-08 在Chrome中使用Feedbro订阅RSS信息流
2023-01-08 查询局域网内IP地址
点击右上角即可分享
微信分享提示