pytorch数据集加载Dataset

一、Dataset基类介绍

在torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,可以快速实现对数据的加载

torch.utils.data.Dataset的源码如下:

复制代码
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
复制代码

我们需在自定义的数据集类中继承Dataset类,同时还需要实现以下方法:

1、__getitem__,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据

二、torch.utils.data.Dataloader

 DataLoader(dataset=my_dataset,batch_size=2,shuffle=True)

 

1、dataset:提前定义的dataset实例

2、batch_size:传入数据的batch的大小,常用有128,256等等

3、shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据

4、num_workers:加载数据的线程数

三、数据加载案例

下载国外正常短信和骚扰短信数据集,数据下载地址:

http://archive.ics.uci.edu/dataset/228/sms+spam+collection     

代码示例:

 

复制代码
import torch
from torch.utils.data import Dataset,DataLoader

data_path = r"D:\coding\learning\python\pytorchtest\data\SMSSpamCollection"

#完成数据集类
class MyDataset(Dataset):
    def __init__(self):
        self.lines = open(data_path,encoding='utf-8').readlines()

    def __getitem__(self, index):
        #获取索引对应位置的一条数据
        cur_line = self.lines[index].strip()
        label = cur_line[:4].strip()   #取短信内容类型,前4个字符
        content = cur_line[4:].strip()  #短信内容
        return label,content

    def __len__(self):
        #返回数据集总量
        return  len(self.lines)
my_dataset = MyDataset()
data_loader = DataLoader(dataset=my_dataset,batch_size=2,shuffle=True)

if __name__ == '__main__':
    my_dataset = MyDataset()
    print(my_dataset[0])   #取第0个数据
    print(len(my_dataset))  #数据数量
    for i in data_loader:
        print(i)  #循环输出

    print(len(my_dataset))
    print(len(data_loader))  #math.ceil(len(my_dataset)/batch_size) 向上取整
复制代码

 运行结果:

可以看到print(i)一次取出是两条,因为batch_size=2,print(len(data_loader))输出的是2787=math.ceil(len(my_dataset)/batch_size)

 

posted @   ziff123  阅读(48)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 25岁的心里话
· 按钮权限的设计及实现
点击右上角即可分享
微信分享提示