Pytorch数据读取框架
训练一个模型需要有一个数据库,一个网络,一个优化函数。数据读取是训练的第一步,以下是pytorch数据输入框架。
1)实例化一个数据库
假设我们已经定义了一个FaceLandmarksDataset数据库,此数据库将在以下建立。
import FaceLandmarksDataset face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/', transform=transforms.Compose([ Rescale(256), RandomCrop(224), ToTensor()]) )
或者使用torchvision.datasets里封装的数据集(MNIST、Fashion-MNIST、KMNIST、EMNIST、COCO、LSUN、ImageFolder、DatasetFolder、Imagenet-12、CIFAR、STL10、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes)
import torchvision.datasets imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
2)创建一个数据加载器
import torch.utils.data.DataLoader imagenet_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=4, shuffle=True, num_workers=4) #or facelandmark_loader = torch.utils.data.DataLoader(face_dataset, batch_size=4, shuffle=True, num_workers=4)
可见,数据加载器是通用的,只有数据库实例不一样,其它的都参数都一样,参数值可以根据任务需要自己调。
3)使用数据库
数据加载器可迭代的,我们可以使用数据库:
for item in facelandmark_loader: images,labels = item
do_somethi
当然, 我们也可以直接对数据库实例face_dataset进行下标操作,但这样只能够每次获取一条数据。
sample = face_dataset[index]
手与大脑的距离决定了理想与现实的相似度