Pytorch数据集读入——Dataset类,实现数据集打乱Shuffle

在进行相关平台的练习过程中,由于要自己导入数据集,而导入方法在市面上五花八门,各种库都可以应用,在这个过程中我准备尝试torchvision的库dataset
torchvision.datasets.ImageFolder
简单应用起来非常简单,用torchvision.datasets.ImageFolder实现图片的导入,在随后训练过程中用Datalodar处理后可按批次取出训练集

class ImageFolder(root, transform=None, target_transform=None, loader=default_loader, is_valid_file=None)
ImageFolder有这么几个参数,其中root指的是数据所在的文件夹,其中该文件夹的存储方式应为
root/labels/xxx.jpg
即根据自身分类标签存储在对应标签名的文件夹内
ImageFolder在读入的过程中会自行加好标签,最后形成一对对的数据
另外比较常用的就是transform,表示对于传入图片的预处理,如剪裁,颜色选择等等
比如

transform_t = transforms.Compose([
    transforms.Resize([64, 64]), 
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor()]
    )

具体参数可以上网查看
在之后用DataLodar处理后虽然的确有Shuffle的参数,但是却只是在一个小批次内进行打乱,原本是按照类别存储的,这样的话会导致很严重的过拟合,为了避免这个,我决定常识改写一下Dataset的类(主要是看起来Dataset看起来改写比较顺手...ImageFolder还没有看源码并没要对此下手)
但是Dataset需要读入一个个的训练数据的位置,怎么办呢?我就先写了一个小脚本,生成一个txt文件来存储所有数据的名称(相对路径),同时在这一步就进行打乱操作【一眼看下去甚至会发现init的classnum参数完全没用上(捂脸

import os
import numpy as np
'''
self.target     顺序存储数据集
self.DataFile   存储根目录
self.s          存储所有数据
self.label      存储所有标签及其对应的值
'''
class create_list():
    def __init__(self,root,classnum=2):
        self.target=open("./Data.txt",'w')
        self.DataFile=root
        self.s=[]
        self.label={}
        self.datanum=0
    
    def create(self):
        files=os.listdir(self.DataFile)
        for labels in files:
            tempdata=os.listdir(self.DataFile+"/"+labels)
            self.label[labels]=len(self.label)
            for img in tempdata:
                self.datanum+=1
                self.target.write(self.DataFile+"/"+labels+"/"+img+" "+labels+"\n")
                self.s.append([self.DataFile+"/"+labels+"/"+img,labels])
    
    def detail(self):
        #查看数据数量以及标签对应
        print(self.datanum)
        print(self.label)
    
    def get_all(self):
        #查看所有数据
        print(self.s)

    def get_root(self):
        #获得根目录
        return self.DataFile

    def shuffle(self):
        #获得打乱的存储txt
        shuffle_file=open("./Shuffle_Data.txt",'w')
        temp=self.s
        np.random.shuffle(temp)
        for i in temp:
            shuffle_file.write(i[0]+" "+str(i[1])+"\n")
        return self.DataFile+"/Shuffle_Data.txt"

    def label_id(self,label):
        #获得该标签对应的值
        return self.label[label]

数据集的存储方式上的要求跟之前的ImageFolder一样
最终会生成一个这样的txt文件
image
数据集来源于某x光胸片判断...
而Shuffle操作就是为了生成打乱后的txt文件,我写的比较简单粗暴...先将就看吧,生成后大概就是这个样子
image
至少真正的做到打乱数据了
完成这个以后,就可以用此来帮助DataLodar了
接下来的代码或许比较辣眼睛...但是事实证明是有用的,但是可能Python技巧不太熟练所以就会显得很生涩...
我重现的Dataset类:

from PIL import Image
import torch

class cDataset(torch.utils.data.Dataset):
    def __init__(self, datatxt, root="", transform=None, target_transform=None, LabelDic=None):
        super(cDataset,self).__init__()
        files = open(root + "/" + datatxt, 'r')
        self.img=[]
        for i in files:
            i = i.rstrip()
            temp = i.split()
            if LabelDic!=None:
                self.img.append((temp[0],LabelDic[temp[1]]))
            else:
                self.img.append((temp[0],temp[0]))
            
        self.transform = transform
        self.target_transform = target_transform
    
    def __getitem__(self, index):
        files, label = self.img[index]
        img = Image.open(files).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img,label
    
    def __len__(self):
        return len(self.img)

其实直接看就能大概看明白,主要也就是要实现类里面的几个方法

class cDataset(torch.utils.data.Dataset):
    def __init__():
    def __getitem__(self, index):
    def __len__(self):

其中getitm类似一次次的取出数据,len就是返回数据集数目
其中init的参数我做了稍许调整,由于我之前的txt内标签是字符串,而为了能让对应生成的tag是所要求的,可以传入一个字典,如:
LabelDic={"NORMAL":0,"PNEUMONIA":1}
这样就可以在之后转化为数字的标签,onehot或者怎么怎么样了,,,

posted @ 2019-11-07 21:12  LOSKI  阅读(7395)  评论(0编辑  收藏  举报