基于 RAR 压缩文件自定义 PyTorch 数据集

# coding:utf-8
import io
import os

import numpy as np
import rarfile as rar
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class myDataset(Dataset):
    """Map-style PyTorch dataset serving the images stored in a RAR archive.

    Every non-directory member of the archive is one sample. Members are read
    straight into memory, so nothing is extracted to disk.
    """

    def __init__(self, inrar):
        """
        Args:
            inrar: Path to a ``.rar`` archive whose members are image files.
        """
        self.inrar = inrar
        # Member list, read lazily on first use and then reused so that
        # __len__ and __getitem__ always agree on the index -> member mapping
        # (and the archive directory is not re-scanned on every access).
        self._names = None

    def _member_names(self):
        """Return (and cache) the archive's file members, skipping directories."""
        if self._names is None:
            with rar.RarFile(self.inrar) as archive:
                # namelist() also reports directory entries. They are not
                # images and would make __getitem__ fail, so drop them.
                self._names = [
                    info.filename
                    for info in archive.infolist()
                    if not info.is_dir()
                ]
        return self._names

    def __len__(self) -> int:
        """Return the number of file (non-directory) members in the archive."""
        return len(self._member_names())

    def __getitem__(self, item: int) -> torch.Tensor:
        """Decode the ``item``-th image and return it as a float32 tensor.

        Pixel values are divided by 255 (this assumes 8-bit channels; NOTE:
        16-bit images would not end up in [0, 1]). RGBA images lose their
        alpha channel, and palette ("P") images are expanded to RGB. All other
        modes keep the array layout numpy produces for that mode, e.g.
        (H, W, 3) for RGB and (H, W) for greyscale.

        Args:
            item: Index into the archive's file members (negative indices
                follow normal list semantics).

        Returns:
            The image as a ``torch.float32`` tensor.

        Raises:
            IndexError: If ``item`` is out of range.
        """
        fname = self._member_names()[item]
        # A short-lived handle per call; the with-block closes it even if
        # reading fails.
        with rar.RarFile(self.inrar) as archive:
            data = archive.read(fname)
        # Decode from memory instead of extracting into the working
        # directory. Extraction could clobber (and then delete) an existing
        # file of the same name, left extracted sub-directories behind, and
        # leaked the temp file whenever decoding raised.
        with Image.open(io.BytesIO(data)) as img:
            if img.mode == "RGBA":
                # Keep the RGB channels, discard alpha.
                pixels = np.array(img)[:, :, :3]
            elif img.mode == "P":
                # Palette images store colour-table indices, not intensities;
                # expand them to real RGB values before scaling.
                pixels = np.array(img.convert("RGB"))
            else:
                pixels = np.array(img)
        return torch.tensor(pixels / 255, dtype=torch.float32)

  

posted @ 2022-02-23 11:05  ddzhen  阅读(62)  评论(0)  编辑  收藏  举报