PyTorch 加载数据：DataLoader 的两种迭代方式

 1 from torch.utils.data import DataLoader
 2 from torchvision import datasets
 3 from PIL import Image as img
 4 
 5 dataPath = './data/imgs/'
 6 
 7 dataset = datasets.ImageFolder('./data/', loader=img.open)
 8 dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
 9 
10 # 方式一
11 for epoch in range(100):
12     for i, (img, _)in enumerate(dataloader):
13         # do training
14 
# Approach 2: wrap the DataLoader in an endless generator and pull one batch
# per training step with next().


def data_gen(data_loader):
    """Yield image batches from ``data_loader`` forever, restarting it after each pass.

    Args:
        data_loader: An iterable (e.g. a ``DataLoader``) whose items are
            ``(images, labels)`` pairs. It is re-iterated on every pass, so a
            ``DataLoader`` built with ``shuffle=True`` reshuffles each pass.

    Yields:
        The first element (``images``) of each item. Labels are discarded.

    The generator stops, rather than spinning forever, if a full pass over
    ``data_loader`` produces nothing (i.e. the loader is empty).
    """
    while True:
        produced = False
        # Iterate the loader directly. enumerate() would yield (index, batch)
        # pairs, binding the batch index to `images` and the batch to `_`.
        for images, _ in data_loader:
            produced = True
            yield images
        if not produced:
            # An empty loader would otherwise make `while True` loop forever.
            return
21 
gen_img = data_gen(dataloader)

# Pull one batch per training step. data_gen restarts the DataLoader whenever a
# pass ends, so the number of steps is not tied to the dataset size.
# The loop variable is `step`, not `iter`, so the builtin iter() stays usable.
for step in range(100):
    imgs = next(gen_img)
    # do training

 

posted @ 2020-03-17 20:44  Junzhao  阅读(420)  评论(0)  编辑  收藏  举报