PyTorch 数据加载

一、方法一
数据组织形式(train、val 下再按类别分子文件夹,每个子文件夹名即类别名,供 torchvision 的 ImageFolder 读取):
dataset_name
----train
----val

import os

import torch
from torchvision import datasets, models, transforms

# Data augmentation and normalization for training; only resizing/cropping and
# normalization for validation. The mean/std values are the ImageNet statistics
# that torchvision's pretrained models expect.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'hymenoptera_data'
# ImageFolder expects data_dir/<split>/<class_name>/<images>; each image's class
# index comes from the name of its class sub-folder.
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
# NOTE: num_workers > 0 starts worker processes; on Windows the code that
# iterates these loaders must run under an `if __name__ == '__main__':` guard.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase followed by a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over the mini-batches of this phase.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Clear the gradients accumulated by the previous step.
                optimizer.zero_grad()

                # Forward pass; autograd history is only recorded while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass + parameter update only in the training phase.
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics. Multiplying by the batch size assumes the criterion
                # returns a per-batch mean (reduction='mean') -- confirm.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            # Since PyTorch 1.1 the LR scheduler must be stepped *after*
            # optimizer.step(); stepping it at the start of the epoch skips the
            # first value of the learning-rate schedule.
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Keep a deep copy of the weights with the best validation accuracy.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

二、方法二

自定义根路径 + txt 文件中写入的图片路径

txt 每行的内容:前面是图片路径,后面是类别 label,两者以空格分隔。

 

生成 txt 文件的代码:

# -*-coding:utf-8-*-
"""
    @Project: googlenet_classification
    @File   : create_labels_files.py
    @Author : panjq
    @E-mail : pan_jinquan@163.com
    @Date   : 2018-08-11 10:15:28
"""

import os
import os.path


def write_txt(content, filename, mode='w'):
    """Write rows of values to a text file, one row per line.

    Every value is converted with ``str`` and the values of a row are
    separated by a single space; each non-empty row ends with a newline.
    Empty rows produce no output at all (not even a blank line).

    :param content: iterable of rows; each row is a sequence of values
    :param filename: path of the text file to write
    :param mode: open mode, 'w' (overwrite) or 'a' (append)
    :return: None
    """
    with open(filename, mode) as out:
        for row in content:
            fields = [str(value) for value in row]
            if fields:
                out.write(" ".join(fields) + "\n")


# Class label for each image folder name: a file is labelled by the name of
# the folder that directly contains it.
DEFAULT_LABEL_MAP = {
    '010101': 0,
    '010102': 1,
    '010103': 2,
    '010105': 3,
    '010106': 4,
    '010107': 5,
    '010201': 6,
    '010202': 7,
    '030000': 8,
}


def get_files_list(dir, label_map=None):
    """List every file under *dir* (recursively) together with its class label.

    Each entry is ``[os.path.join(<folder name>, <file name>), label]``.
    ``<folder name>`` is the name of the folder that directly contains the
    file, and it also selects the label through *label_map*. Files whose
    containing folder is not a key of *label_map* are skipped; stray files
    directly inside *dir* are an example. Without this check such files would
    get an undefined label, or silently inherit the label of the previously
    visited folder.

    :param dir: root directory to walk
    :param label_map: optional mapping folder name -> integer label;
        defaults to DEFAULT_LABEL_MAP
    :return: list of [relative_path, label] pairs
    """
    if label_map is None:
        label_map = DEFAULT_LABEL_MAP

    files_list = []
    # os.walk yields (directory path, sub-directory names, file names) for dir
    # and every directory below it.
    for parent, _dirnames, filenames in os.walk(dir):
        # basename() understands both '/' and '\\' on Windows, unlike
        # parent.split(os.sep).
        folder = os.path.basename(parent)
        label = label_map.get(folder)
        if label is None:
            continue
        for filename in filenames:
            files_list.append([os.path.join(folder, filename), label])
    return files_list


if __name__ == '__main__':
    # Root of the dataset. The image folders and the generated label lists
    # all live directly under it, so every path is derived from this one value.
    data_root = r'F:\WU_work\guandao\data\guandao20190904_10'

    train_dir = os.path.join(data_root, 'train')
    train_txt = os.path.join(data_root, 'train.txt')
    train_data = get_files_list(train_dir)
    write_txt(train_data, train_txt, mode='w')

    val_dir = os.path.join(data_root, 'validation')
    val_txt = os.path.join(data_root, 'val.txt')
    val_data = get_files_list(val_dir)
    write_txt(val_data, val_txt, mode='w')

 

    # Build the MyDataset instances. img_path is meant as a prefix prepended to
# each image path read from the txt file (MyDataset is defined elsewhere --
# confirm). It is the training or validation root, e.g.
# F:\WU_work\guandao\data\guandao20190904_10\train
train_data = MyDataset(img_path='', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path='', txt_path=valid_txt_path, transform=validTransform)

数据加载

# -------------------------------------------- step 1/5 : 加载数据 -------------------------------------------
    train_txt_path = './Data/train.txt'
    valid_txt_path = './Data/valid.txt'
    # 数据预处理设置
    normMean = [0.4948052, 0.48568845, 0.44682974]
    normStd = [0.24580306, 0.24236229, 0.2603115]
    normTransform = transforms.Normalize(normMean, normStd)
    trainTransform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=4),
        transforms.ToTensor(),
        normTransform
    ])
 
    validTransform = transforms.Compose([
        transforms.ToTensor(),
        normTransform
    ])
 
    # 构建MyDataset实例 img_path是一种可在txt图片路径前面加入的一种机制
    train_data = MyDataset(img_path = '', txt_path=train_txt_path, transform=trainTransform)
    valid_data = MyDataset(img_path = '', txt_path=valid_txt_path, transform=validTransform)
 
    # 构建DataLoder
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=16)
train_loader 是一个可迭代对象(每次 for 循环都会重新生成迭代器),每次迭代返回一个 batch 的图片及其对应的 label。

posted @ 2019-09-05 10:50  X18301096  阅读(689)  评论(0编辑  收藏  举报