Pytorch笔记|小土堆|P6-7|Dataset类、TensorDataset类
Dataset类
作用:模型的数据集接口
__init__将对象实例化,创建对象时obj = class(..., ...)
会立即被调用,需要提供(输入)类中使用到的变量。
__getitem__通过img, label = obj[idx]
获取(返回)每一个数据和label
__len__通过len(obj)
获取(返回)数据量
练习代码如下:
from torch.utils.data import Dataset
import os
from PIL import Image
# 数据集下载地址:https://download.pytorch.org/tutorial/hymenoptera_data.zip
class Mydata(Dataset):
def __init__(self, root_dir, label_dir):
# self为全局变量
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path) # 获取path路径下所有子文件的名称
def __getitem__(self, idx):
img_name = self.img_path[idx] # idx索引
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = r'D:\ai-learning\pytorch\hymenoptera_data\train' # 改为自己的路径
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = Mydata(root_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
img, label = train_dataset[0]
img.show()
len(train_dataset)
一个可以参考的通用数据集接口:https://blog.csdn.net/leviopku/article/details/99958182
关于Dataset类和TensorDataset类的区别:https://blog.csdn.net/qq_43611080/article/details/113575167
Dataset需要用户自己实现__getitem__和__len__两个方法。常用于创建自定义数据集,适合处理复杂的数据预处理和加载逻辑
TensorDataset专门处理Tensor数据,已实现__getitem__和__len__两个方法,可直接使用。TensorDataset封装了两个张量:data_tensor 和 label_tensor,并且可以通过索引访问每个样本。适合将多个现有的张量数据封装成数据集,方便快速使用
from torch.utils.data import TensorDataset
# 定义一些张量数据
data_tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
label_tensor = torch.tensor([0, 1, 0])
# 创建TensorDataset对象
dataset = TensorDataset(data_tensor, label_tensor)
# 访问数据
print(len(dataset)) # 输出: 3
print(dataset[0]) # 输出: (tensor([1, 2]), tensor(0))
参考:
[1]https://www.cnblogs.com/seansheep/p/16163159.html
[2]https://www.bilibili.com/video/BV1hE411t7RN?p=7&vd_source=fa1d778abbb911d02be7ac36f2b2e32a
禁转载搬运,谢谢观看