Pytorch数据加载与使用

前言

在训练的时候通常使用Dataset来处理数据集。

Dataset的作用

提供一个方式获取数据内容和标签(label)。

实战

from torch.utils.data import Dataset

from PIL import Image
import os

class get_data(Dataset):

    def __init__(self,root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.img_dir = os.path.join(root_dir,label_dir)
        self.img_list = os.listdir(self.img_dir)


    def __getitem__(self, indx):

        img_path = os.path.join(self.img_dir,self.img_list[indx])
        img_label = self.label_dir
        img_data = Image.open(img_path)
        return img_data,img_label

    def __len__(self):
        return len(self.img_list)

root_dir = "C:\\Users\\Traveler\\Pictures"
label_dir = "Screenshots"

test = get_data(root_dir,label_dir)

img , label = test[1]

# img.show()
print(label)
print(len(test))

此代码定义了一个fet_data类,继承了Dataset,主要提供两个方法,获取数据(getitem)和获取大小(len)。
然而这两个方法使用的是内置的类,当达到一定条件时自动触发,比如__getitem__当需要获取数据时自动触发这个方法。

getitem

返回两个数据,一个是data,一个是label,实现原理就是主要看这两个函数,
os.listdir()是获取一个路径下的文件名(包含后缀)列表。类似于[‘1.txt’,'2.jpg']
Image.open()是打开图片文件的,打开一个图片后会赋值很多属性:如下图
image
使用img.show()就可以打开,img.size()就可以获取大小。

另外os的其他函数也挺重要,比如os.path.join()就是拼接路径,这个的好处是防止Linux和Windows之间的路径不匹配问题。

posted @ 2024-06-17 23:05  云岛夜川川  阅读(2)  评论(0编辑  收藏  举报