数据处理工具

PyTorch数据处理工具箱

torch.utils.data

  • Dataset
    抽象类,其他数据集类定义时需继承自该类,并覆写两个方法:getitem__和__len
  • DataLoader
    定义一个新的迭代器,实现批量batch读取,打乱shuffle数据和并行加速等功能
  • random_split
    将数据集随机拆分成给定长度的非重叠的新数据集
  • *sample
    多种采样函数

torch.utils.data.Dataset抽象类,自定义数据集需继承这个类,并实现两个函数__len__和__getitem__
torch.utils.data.DataLoader定义数据集迭代器,实现batch读取

class TestDataset(data.Dataset):
    def __init__(self):
        self.Data= np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])
        self.Label= np.asarray([0,1,0,1,2])
    def __getitem__(self, index):
        txt= torch.from_numpy(self.Data[index]) # 将numpy转换成Tensor
        label= torch.tensor(self.Label[index])
        return txt,label
    def __len__(self):
        return len(self.Data)
# 使用DataLoader对数据集Dataset进行批量batch处理,同时进行shuffle和并行加速等操作
DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None, # 样本抽取
    batch_sampler=None,
    num_workers=0, # 多进程加载
    collate_fn=<function default_collate at 0x7f108ee01620>, # 多个样本拼接成一个batch的拼接方式
    pin_memory=False, # 是否将数据保存在pin memory区,加速加载到GPU
    drop_last=False, # 将多出的不足一个batch_size的数据丢弃
    timeout=0,
    worker_init_fn=None,
)

test= TestDataset()
test_loader=data.DataLoader(test, batch_size=2, shuffle=False, num_workers=2)
for i,traindata in enumerate(test_loader):
    data,label= traindata
# 可以像使用迭代器一样使用test_loader,不过由于他不是迭代器,使用iter()将其转换成迭代器并用next()遍历
iter(test_loader)
next(dataiter)

一般使用data.Dataset处理同一目录下的数据,若数据在不同目录下(不同目录表示不同类别),此时可用torchvision处理数据

torchvision

  • datasets
    继承自torch.utils.data.Dataset,提供常用数据集Mnist,Cifar10/100,ImageNet和COCO
  • models
    提供经典的网络结构和模型pretrained=True,如AlexNet,VGG,ResNet,Inception系列
  • transforms
    常用的数据预处理操作,主要对Tensor和PIL Image类型操作,当预处理有多个函数时,可用transforms.Compose将其组合
  • utils
    含有两个函数:make_grid将多张图像拼接在一个网格中,save_img将Tensor保存成图像

transforms 对 PIL.Image 常见操作

  1. Scale/Resize 调整尺寸,保持长宽比不变
  2. CenterCrop,RandomCrop,RandomSizeCrop 裁减图像
  3. Pad 填充
  4. ToTensor 将取值范围为[0,255]的PIL.Image或形状为(H,W,C)的ndarray
    转换成(C,H,W),取值范围为[0,1.0]的torch.FloatTensor
  5. RandomHorizontalFlip 图像随机水平翻转,翻转概率为0.5
  6. RandomVerticalFlip 图像随机垂直翻转
  7. ColorJitter 修改图像亮度,对比度和饱和度

transforms对 Tensor 常见操作

  1. Normalize 标准化
  2. ToPILImage 将Tensor转换成PIL.Image

transforms.Lambda()使用自定义lambda表达式,如每个像素加10:transforms.Lambda(lambda x:x.add(10))

当对数据集进行多个操作时,可通过Compose()将这些操作拼接,类似于nn.Sequential

transforms.Compose({
    # 将给定的PIL.Image进行中心切割,size可以是tuple或Integer
	transforms.CenterCrop(10),
	transforms.RandomCrop(20,padding=0),
	transforms.ToTensor(),
	transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))
})

函数ImageFoldertorchvision.datasets中成员
当文件依据标签存储在不同文件夹下时,可以使用其直接构造出Dataset,ImageFolder会将文件夹名自动转换成序列

my_trans=transforms.Compose({
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(224),
	transforms.ToTensor()
})
train_data= torchvision.datasets.ImageFolder("",transforms=my_trans)
train_loader= torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True)
for i_batch,img in enumerate(train_loader):
    if i_batch==0:
	    print(img[1])
		fig= plt.figure()
		grid= torchvision.utils.make_grid(img[0])
		plt.imshow(grid.numpy().transpose((1,2,0)))
		plt.show()
		utils.save_image(grid,'te.png')
posted @   sgqmax  阅读(22)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· SQL Server 2025 AI相关能力初探
· 单线程的Redis速度为什么快?
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
点击右上角即可分享
微信分享提示