基于 rar 压缩文件自定义 PyTorch 数据集
# coding:utf-8
import io
import os

import numpy as np
import rarfile as rar
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader


class myDataset(Dataset):
    """Custom dataset backed by a ``.rar`` archive of image files.

    Every non-directory member of the archive is one sample. A sample is
    decoded with ``PIL.Image``, converted to a NumPy array, divided by 255
    and returned as a ``torch.float32`` tensor. The array layout produced by
    ``np.array(img)`` is kept as-is (no transpose to channel-first).
    """

    def __init__(self, inrar):
        """Store the archive path; the archive itself is not opened here.

        Args:
            inrar: Path to the ``.rar`` archive (Windows-style compressed
                file) that holds the image data.
        """
        self.inrar = inrar
        # Filled on first use so that constructing the dataset does no I/O.
        # Only plain strings are cached (never an open archive handle), which
        # keeps the dataset picklable for DataLoader worker processes.
        self._members = None

    def _member_names(self):
        """Return the cached names of all non-directory archive members."""
        if self._members is None:
            with rar.RarFile(self.inrar) as archive:
                # Directory entries cannot be decoded as images, so they must
                # not be counted as samples.
                self._members = [
                    info.filename
                    for info in archive.infolist()
                    if not info.is_dir()
                ]
        return self._members

    def __len__(self):
        """Return the number of file members (samples) in the archive."""
        return len(self._member_names())

    def __getitem__(self, item):
        """Load one image from the archive as a float32 tensor.

        Args:
            item: Index into the list of file members; follows normal list
                indexing rules.

        Returns:
            ``torch.Tensor`` of dtype ``float32`` holding the pixel values
            divided by 255. For ``RGBA`` images the alpha channel is dropped.

        Raises:
            IndexError: If ``item`` is outside the range of samples.
        """
        name = self._member_names()[item]

        # Read the member straight into memory instead of extracting it to the
        # working directory: no temp file to clean up, no collisions between
        # DataLoader workers, and the archive is closed even if reading fails.
        with rar.RarFile(self.inrar) as archive:
            payload = archive.read(name)

        with Image.open(io.BytesIO(payload)) as img:
            pixels = np.array(img)
            if img.mode == "RGBA":
                # Keep only the three colour channels.
                pixels = pixels[:, :, :3]

        return torch.tensor(pixels / 255, dtype=torch.float32)