image-classification-dataset
import torchtext
from torchvision import transforms
from torch.utils import data
from d2l import torch as d2l
import torchvision
# Convert each PIL image to a float tensor (torchvision scales pixels to [0, 1]).
trans = transforms.ToTensor()
fashion_mnist_train = torchvision.datasets.FashionMNIST(
    "../data", train=True, transform=trans, download=True
)
# The test split must be requested with train=False; passing train=True here
# would silently load the training images a second time and any evaluation
# on "fashion_mnist_test" would be measured on training data.
fashion_mnist_test = torchvision.datasets.FashionMNIST(
    "../data", train=False, transform=trans, download=True
)
def get_fashion_mnist_label(label):
    """Translate Fashion-MNIST class indices into their text names.

    ``label`` is an iterable of integer class ids; one name per id is
    returned as a list, preserving the input order.
    """
    names = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
             'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return list(map(names.__getitem__, label))
batch_size = 50
# Training batches are reshuffled every epoch; test batches keep dataset order.
# num_workers=0 loads data in the main process.
dataloader_train = data.DataLoader(
    fashion_mnist_train, shuffle=True, batch_size=batch_size, num_workers=0
)
dataloader_test = data.DataLoader(
    fashion_mnist_test, shuffle=False, batch_size=batch_size, num_workers=0
)
# Scratch check of list.insert(0, x): it prepends x to the list. This is the
# same call get_fashion_mnist_dataloader uses to put Resize ahead of ToTensor.
a = [1,2]
a.insert(0,5)
print(a)  # prints [5, 1, 2]
[5, 1, 2]
def get_fashion_mnist_dataloader(batch_size, resize=None, *, root="../data",
                                 num_workers=0, download=True):
    """Build Fashion-MNIST train and test DataLoaders.

    Args:
        batch_size: Number of samples per batch for both loaders.
        resize: Optional size forwarded to ``transforms.Resize``. When falsy
            (``None``, ``0``, empty tuple), images keep their native size.
        root: Directory where the dataset is stored / downloaded to.
        num_workers: Worker processes per DataLoader (0 loads in the main
            process).
        download: Whether torchvision may download the data if it is missing.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles; the
        test loader preserves dataset order.
    """
    steps = [transforms.ToTensor()]
    if resize:
        # Resize is applied to the raw image, before conversion to a tensor.
        steps.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(steps)

    def _make_loader(train):
        # Build the dataset for one split and wrap it; only the training
        # split is shuffled.
        dataset = torchvision.datasets.FashionMNIST(
            root, train=train, transform=trans, download=download
        )
        return data.DataLoader(
            dataset, batch_size=batch_size, shuffle=train,
            num_workers=num_workers,
        )

    return _make_loader(True), _make_loader(False)
batch_size = 50
train_dataloader, test_dataloader = get_fashion_mnist_dataloader(
    batch_size, resize=(60, 60)
)
# Peek at the first batch of each loader to confirm the resized shapes.
for loader in (train_dataloader, test_dataloader):
    for X, Y in loader:
        print(X.shape, Y.shape)
        break
torch.Size([50, 1, 60, 60]) torch.Size([50])
torch.Size([50, 1, 60, 60]) torch.Size([50])
重点函数
- transforms.ToTensor() 不可忘记括号(ToTensor 是一个类,需要先实例化)
- 结合 transforms.Compose() 使用:传入一个变换列表,按列表顺序依次处理数据,得到组合后的 trans
- 图片转换相关函数在 torchvision.transforms 中
- 常用 dataset 在 torchvision.datasets 中
- 读取常用数据集流程
  - 用 torchvision.transforms 封装好 trans
  - 用 torchvision.datasets 读取相应数据集
    - 配置参数:
      - root:数据存放位置(本文中为 "../data")
      - train:bool,True 读取训练集,False 读取测试集
      - transform:对每张图片应用的变换
      - download:本地不存在时是否下载数据集
  - 用 torch.utils.data.DataLoader 将 dataset 封装为可迭代对象
    - 配置参数:
      - dataset:对应的 dataset
      - batch_size:批量大小
      - shuffle:是否打乱数据顺序
个人作品,
如有错误,请指出;
如要转载,请注明出处。
三克油。。