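Loading datasets in PyTorch with Dataset
I. Introduction to the Dataset base class
PyTorch provides torch.utils.data.Dataset as the base class for datasets. Subclassing it is a quick way to implement data loading.
The source code of torch.utils.data.Dataset is as follows: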
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
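The __add__ method in the source above means that two datasets can be joined with the + operator. A minimal sketch, assuming two small TensorDataset objects whose contents are made up purely for illustration:

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Two toy datasets (illustrative data, not from the article)
first = TensorDataset(torch.arange(3))        # samples 0, 1, 2
second = TensorDataset(torch.arange(10, 12))  # samples 10, 11

# Dataset.__add__ returns a ConcatDataset wrapping both operands
combined = first + second

print(type(combined).__name__)  # ConcatDataset
print(len(combined))            # 5
print(combined[3])              # first sample of `second`
```

Indices 0 to 2 are served by the first dataset and indices 3 to 4 by the second, so the combined object can be indexed like a single dataset.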
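A custom dataset class must inherit from Dataset and implement the following methods:
1. __getitem__: returns a sample for a given index, so that dataset[i] fetches the i-th sample.
2. __len__: returns the number of samples, so that len(dataset) works. The base class does not define it, but the default options of DataLoader expect it.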
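As a minimal sketch of such a subclass, the toy class below (SquaresDataset is a hypothetical name, and its data is generated on the fly rather than read from a file) implements both __getitem__ and __len__:

```python
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __getitem__(self, index):
        # Called for dataset[index]
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index

    def __len__(self):
        # Called for len(dataset)
        return self.size


dataset = SquaresDataset(5)
print(dataset[3])    # (3, 9)
print(len(dataset))  # 5
```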
II. torch.utils.data.DataLoader
DataLoader(dataset=my_dataset,batch_size=2,shuffle=True)
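1. dataset: a dataset instance defined beforehand.
2. batch_size: the number of samples per batch; common values include 128 and 256.
3. shuffle: a bool indicating whether the data is reshuffled at the start of every pass over the dataset.
4. num_workers: the number of worker subprocesses used to load data (0, the default, loads in the main process).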
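The parameters above can be tried out on a small in-memory dataset. The sketch below assumes a TensorDataset holding the values 0 to 9 (illustrative data only) and turns shuffling off so that the batches are predictable:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten toy samples holding the values 0..9 (illustrative data)
dataset = TensorDataset(torch.arange(10))

# shuffle=False keeps the original order; num_workers=0 loads in the main process
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

batches = [batch[0].tolist() for batch in loader]
print(batches)      # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(len(loader))  # 3
```

With shuffle=True the same ten values would come out in a different order on each pass. When num_workers is greater than 0 on Windows, the code that iterates over the loader should sit under an if __name__ == '__main__': guard, because worker processes re-import the main module.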
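III. Data loading example
Download the SMS Spam Collection, a dataset of legitimate (ham) and spam SMS messages. Download address:
http://archive.ics.uci.edu/dataset/228/sms+spam+collection
Code example: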
import torch
from torch.utils.data import Dataset, DataLoader

data_path = r"D:\coding\learning\python\pytorchtest\data\SMSSpamCollection"


# Dataset class for the SMS spam collection
class MyDataset(Dataset):
    def __init__(self):
        # Read all lines up front and close the file afterwards
        with open(data_path, encoding='utf-8') as f:
            self.lines = f.readlines()

    def __getitem__(self, index):
        # Return the sample at the given index
        cur_line = self.lines[index].strip()
        label = cur_line[:4].strip()    # message type: the first 4 characters
        content = cur_line[4:].strip()  # message text
        return label, content

    def __len__(self):
        # Total number of samples
        return len(self.lines)


if __name__ == '__main__':
    my_dataset = MyDataset()
    data_loader = DataLoader(dataset=my_dataset, batch_size=2, shuffle=True)

    print(my_dataset[0])     # sample at index 0
    print(len(my_dataset))   # number of samples
    for i in data_loader:
        print(i)             # print each batch
    print(len(my_dataset))
    print(len(data_loader))  # math.ceil(len(my_dataset) / batch_size), rounded up
Result:
Each print(i) outputs two samples at a time because batch_size=2, and print(len(data_loader)) outputs 2787, which equals math.ceil(len(my_dataset)/batch_size).
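To see the batch structure and the rounding without downloading the data file, the sketch below swaps in a hypothetical in-memory dataset (PairDataset) with three made-up (label, content) pairs. It is only an illustration of the default batching behaviour, not the article's own code:

```python
import math

from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    """Hypothetical in-memory stand-in for MyDataset with made-up samples."""

    def __init__(self, samples):
        self.samples = samples

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


# Made-up (label, content) pairs, not taken from the real data file
samples = [
    ("ham", "see you at lunch"),
    ("spam", "you won a prize"),
    ("ham", "call me later"),
]
batch_size = 2
loader = DataLoader(PairDataset(samples), batch_size=batch_size, shuffle=False)

for labels, contents in loader:
    # Strings are not stacked into tensors; each field is grouped per batch
    print(list(labels), list(contents))

# The last batch is smaller, so the batch count is rounded up
print(len(loader), math.ceil(len(samples) / batch_size))  # 2 2
```

Three samples with batch_size=2 give two batches, the second holding a single sample. Passing drop_last=True to DataLoader would discard that incomplete batch, and the batch count would then be rounded down instead.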