PyTorch data loading with Dataset

I. The Dataset base class

PyTorch provides torch.utils.data.Dataset as the base class for datasets. Subclassing it is a quick way to implement data loading.

The source code of torch.utils.data.Dataset is as follows:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
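The `__add__` method above means two datasets can be joined with `+`, which returns a `ConcatDataset`. A minimal sketch, assuming a hypothetical toy class `RangeDataset`:

```python
from torch.utils.data import ConcatDataset, Dataset


class RangeDataset(Dataset):
    """Hypothetical toy dataset returning the integers start, ..., stop - 1."""

    def __init__(self, start, stop):
        self.values = list(range(start, stop))

    def __getitem__(self, index):
        return self.values[index]

    def __len__(self):
        return len(self.values)


first = RangeDataset(0, 3)     # 0, 1, 2
second = RangeDataset(10, 12)  # 10, 11

combined = first + second  # Dataset.__add__ wraps both in a ConcatDataset
print(type(combined).__name__)                    # ConcatDataset
print(len(combined))                              # 5
print([combined[i] for i in range(len(combined))])  # [0, 1, 2, 10, 11]
```

ConcatDataset computes its indices from the length of each part, so both datasets need a `__len__`.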

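Our custom dataset class subclasses Dataset and implements the following methods:

1. __getitem__: fetches a sample by index, e.g. dataset[i] returns the i-th sample

2. __len__: returns the number of samples. The base class does not require it, but DataLoader's default samplers call len(dataset), so in practice it is needed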
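A minimal sketch of a map-style dataset implementing both methods, assuming a hypothetical in-memory list of (label, text) pairs instead of a file:

```python
from torch.utils.data import Dataset


class PairDataset(Dataset):
    """Hypothetical in-memory dataset of (label, text) pairs."""

    def __init__(self, samples):
        self.samples = samples  # a list of (label, text) tuples

    def __getitem__(self, index):
        # dataset[i] returns the i-th sample
        return self.samples[index]

    def __len__(self):
        # len(dataset) returns the number of samples
        return len(self.samples)


dataset = PairDataset([("ham", "See you at noon"), ("spam", "You won a prize")])
print(dataset[1])    # ('spam', 'You won a prize')
print(len(dataset))  # 2
```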

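II. torch.utils.data.DataLoader

A DataLoader wraps a dataset and yields its samples in batches, for example:

DataLoader(dataset=my_dataset, batch_size=2, shuffle=True)

Its commonly used parameters are: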

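1. dataset: the Dataset instance defined beforehand

2. batch_size: the number of samples per batch; common values are 128, 256, etc.

3. shuffle: bool; whether to reshuffle the data at the start of every epoch (default False)

4. num_workers: the number of subprocesses used for loading data; the default 0 loads data in the main process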
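A small sketch of these four parameters in action, assuming a hypothetical toy dataset `SquaresDataset` whose i-th sample is the pair (i, i * i):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __getitem__(self, index):
        return index, index * index

    def __len__(self):
        return 10


loader = DataLoader(
    dataset=SquaresDataset(),
    batch_size=4,   # up to 4 samples per batch
    shuffle=True,   # draw a fresh random order every epoch
    num_workers=0,  # 0 means loading happens in the main process
)

batch_sizes = []
seen = []
for xs, ys in loader:
    # The default collate_fn turns each field of the batch into a tensor
    print(xs.tolist(), ys.tolist())
    batch_sizes.append(len(xs))
    seen.extend(xs.tolist())

print(batch_sizes)                   # [4, 4, 2]
print(sorted(seen))                  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(isinstance(xs, torch.Tensor))  # True
```

The order of samples changes from run to run because of shuffle=True, but every sample appears exactly once per epoch and only the last batch is smaller.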

III. Data loading example

Download the SMS Spam Collection, an English dataset of legitimate ("ham") and spam text messages, from:

http://archive.ics.uci.edu/dataset/228/sms+spam+collection
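Each line of the SMSSpamCollection file holds one message: the label (ham or spam), a tab character, then the message text. A minimal sketch of splitting such a line, using a hypothetical helper `parse_line` and made-up sample lines rather than the real file:

```python
def parse_line(line):
    """Split one SMSSpamCollection-style line into (label, text).

    Assumes the layout label<TAB>message, where label is "ham" or "spam".
    """
    # Split only on the first tab, in case the message itself contains tabs
    label, text = line.rstrip("\n").split("\t", 1)
    return label, text.strip()


sample_lines = [  # hypothetical lines in the same layout as the real file
    "ham\tSee you at the station at 6\n",
    "spam\tWINNER! Claim your prize now\n",
]
for line in sample_lines:
    print(parse_line(line))
# ('ham', 'See you at the station at 6')
# ('spam', 'WINNER! Claim your prize now')
```

Splitting on the tab does not depend on the label being a particular number of characters long.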

Code example:

from torch.utils.data import Dataset, DataLoader

# Path to the extracted SMSSpamCollection file; change it to your own location
data_path = r"D:\coding\learning\python\pytorchtest\data\SMSSpamCollection"

# Custom dataset class
class MyDataset(Dataset):
    def __init__(self):
        # Read all lines once; the with block closes the file afterwards
        with open(data_path, encoding='utf-8') as f:
            self.lines = f.readlines()

    def __getitem__(self, index):
        # Fetch the sample at the given index
        cur_line = self.lines[index].strip()
        label = cur_line[:4].strip()    # message type: the first 4 characters ("spam", or "ham" plus the tab)
        content = cur_line[4:].strip()  # message text
        return label, content

    def __len__(self):
        # Total number of samples in the dataset
        return len(self.lines)


if __name__ == '__main__':
    # Build the dataset and loader once, inside the main guard,
    # so that importing this module has no side effects
    my_dataset = MyDataset()
    data_loader = DataLoader(dataset=my_dataset, batch_size=2, shuffle=True)

    print(my_dataset[0])    # the sample at index 0
    print(len(my_dataset))  # number of samples
    for i in data_loader:
        print(i)  # print each batch

    print(len(data_loader))  # math.ceil(len(my_dataset) / batch_size), i.e. rounded up

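When you run it, each print(i) outputs two samples at a time, because batch_size=2. With the default collate function, a batch of (label, content) tuples comes out as a list of two tuples, [labels, contents]. print(len(data_loader)) prints 2787, which equals math.ceil(len(my_dataset)/batch_size).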
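The rounding up can be checked on a dataset whose size is not a multiple of the batch size. A sketch with a hypothetical 7-sample dataset `CountDataset`:

```python
import math

from torch.utils.data import DataLoader, Dataset


class CountDataset(Dataset):
    """Hypothetical toy dataset whose samples are the integers 0..n-1."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        return index

    def __len__(self):
        return self.n


dataset = CountDataset(7)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

print(len(loader))                       # 4
print(math.ceil(len(dataset) / 2))       # 4
print([len(batch) for batch in loader])  # [2, 2, 2, 1]
```

Seven samples in batches of two give three full batches plus one final batch holding the leftover sample, hence four batches rather than 7 / 2 = 3.5.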


posted @ 2024-02-04 10:55  ziff123