image-classification-dataset

import torchtext
from torchvision import transforms
from torch.utils import data
from d2l import torch as d2l
import torchvision

# ToTensor converts each PIL image into a float tensor scaled to [0.0, 1.0].
trans = transforms.ToTensor()

# Training split of Fashion-MNIST, downloaded into ../data on first use.
fashion_mnist_train = torchvision.datasets.FashionMNIST("../data",
                                                        train=True,
                                                        transform=trans,
                                                        download=True)
# Held-out test split: train=False selects it. (This previously passed
# train=True, which silently loaded the training split a second time and
# made any "test" evaluation meaningless.)
fashion_mnist_test = torchvision.datasets.FashionMNIST("../data",
                                                       train=False,
                                                       transform=trans,
                                                       download=True)

def get_fashion_mnist_label(label):
    """Translate Fashion-MNIST class indices into human-readable names.

    Args:
        label: an iterable of integer class indices in 0..9 (presumably a
            list of ints or a 1-D label tensor from a DataLoader batch --
            TODO confirm callers).

    Returns:
        A list with one class-name string per input index, in input order.
    """
    class_names = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    # Each index is looked up in the fixed name table; out-of-range indices
    # raise IndexError just like direct list indexing.
    return list(map(class_names.__getitem__, label))

# Number of samples per mini-batch, shared by both loaders below.
batch_size = 50

# Training loader: shuffle=True reorders samples each epoch;
# num_workers=0 loads batches in the main process.
dataloader_train = data.DataLoader(dataset=fashion_mnist_train,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=0)

# Test loader: fixed order (shuffle=False) so evaluation is repeatable.
dataloader_test = data.DataLoader(dataset=fashion_mnist_test,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=0)
# Quick check of list.insert(0, x): it places x at the front of the list.
# get_fashion_mnist_dataloader relies on this to put Resize ahead of
# ToTensor in its transform list.
a = [1, 2]
a.insert(0, 5)
print(a)  # prints: [5, 1, 2]
def get_fashion_mnist_dataloader(batch_size, resize=None, root="../data",
                                 num_workers=0):
    """Build Fashion-MNIST train and test DataLoaders.

    Args:
        batch_size: number of samples per batch, used by both loaders.
        resize: optional size passed to ``transforms.Resize``. When falsy
            (the default ``None``) no resizing is applied.
        root: directory the dataset is read from / downloaded into.
            Defaults to ``"../data"``, the previously hard-coded location.
        num_workers: number of loader worker processes for each DataLoader.
            Defaults to 0 (load in the main process), the previous behavior.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The train loader shuffles
        its samples; the test loader keeps a fixed order.
    """
    steps = [transforms.ToTensor()]
    if resize:
        # Insert at the front so Resize runs before ToTensor.
        steps.insert(0, transforms.Resize(resize))
    pipeline = transforms.Compose(steps)

    def _make_loader(train):
        # Build one split's dataset and wrap it; only the train split shuffles.
        dataset = torchvision.datasets.FashionMNIST(root,
                                                    train=train,
                                                    transform=pipeline,
                                                    download=True)
        return data.DataLoader(dataset,
                               batch_size=batch_size,
                               shuffle=train,
                               num_workers=num_workers)

    return _make_loader(True), _make_loader(False)
batch_size = 50
train_dataloader, test_dataloader = get_fashion_mnist_dataloader(batch_size,
                                                                 resize=(60, 60))

# Peek at the first batch of each loader to confirm the resized shapes.
for loader in (train_dataloader, test_dataloader):
    for X, Y in loader:
        print(X.shape, Y.shape)
        break
# Output recorded from the original notebook run:
# torch.Size([50, 1, 60, 60]) torch.Size([50])
# torch.Size([50, 1, 60, 60]) torch.Size([50])

重点函数

  • transforms.ToTensor() 不可忘记括号
  • 可用 transforms.Compose() 把多个变换组合成一个 trans:传入一个变换列表,数据会按列表顺序依次处理(例如先 Resize 再 ToTensor)
  • 图片转换相关函数 在 torchvision.transforms
  • 常用数据集在 torchvision.datasets 模块里
  • 读取常用数据集流程
    • 应用 torchvision.transforms 封装好trans
    • 应用 torchvision.datasets 读取出相应数据集
      • 配置参数:
      • root:数据集存放(下载)目录,例如 "../data",
      • train:bool,True 读取训练集,False 读取测试集,
      • transform:对每张图片应用的变换(如上面组合好的 trans)。
      • download:是否下载数据集
    • 应用 torch.utils.data.DataLoader 将 dataset 封装为一个可按批次迭代的对象
      • 配置参数:
      • dataset:对应dataset
      • batch_size:批量大小
      • shuffle:是否打乱(训练集通常为 True,测试集为 False)
      • num_workers:读取数据的子进程数,0 表示在主进程中读取

posted @ 2024-06-23 10:49  Mr小明同学  阅读(9)  评论(0编辑  收藏  举报