用于 PyTorch 的 H5Dataset 接口（类比 TensorDataset 接口）
PyTorch（早期版本）的 TensorDataset 接口：
class TensorDataset(Dataset):
    """Pair a data tensor with a target tensor, indexed together along dim 0.

    Sample ``i`` is ``(data_tensor[i], target_tensor[i])``; the dataset length
    is the size of the first dimension of ``data_tensor``.

    Arguments:
        data_tensor (Tensor): sample data, one sample per index of dim 0.
        target_tensor (Tensor): sample targets (labels); must have the same
            size along dim 0 as ``data_tensor``.
    """

    def __init__(self, data_tensor, target_tensor):
        # Both tensors must describe the same number of samples.
        assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __len__(self):
        return self.data_tensor.size(0)

    def __getitem__(self, index):
        sample = self.data_tensor[index]
        label = self.target_tensor[index]
        return sample, label
用于 HDF5 的 H5Dataset 接口（只依赖 `.shape` 和下标索引，而不是 `Tensor.size()`，因此可以直接包装 h5py 的 Dataset）：
class H5Dataset(Dataset):
    """Dataset pairing two indexable arrays, e.g. ``h5py.Dataset`` objects.

    Sample ``i`` is ``(data_tensor[i], target_tensor[i])``. Only ``.shape[0]``
    and ``[]`` indexing are used, so ``h5py.Dataset``, ``numpy.ndarray`` and
    ``torch.Tensor`` inputs all work. Each access indexes the inputs, so for
    ``h5py.Dataset`` inputs samples are fetched on demand rather than up front.

    Arguments:
        data_tensor: array-like holding the sample data, indexed along axis 0.
        target_tensor: array-like holding the sample targets (labels); must
            have the same length along axis 0 as ``data_tensor``.

    Raises:
        ValueError: if the two inputs differ in length along axis 0.
    """

    def __init__(self, data_tensor, target_tensor):
        n_data = data_tensor.shape[0]
        n_target = target_tensor.shape[0]
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if n_data != n_target:
            raise ValueError(
                "data_tensor and target_tensor must have the same length along "
                "axis 0, got {} and {}".format(n_data, n_target)
            )
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.shape[0]
对应的 DataLoader（只需把 TensorDataset 换成 H5Dataset）：
def load_data(path="./dataset/CAVE.h5"):
    """Build the training and testing DataLoaders from an HDF5 file.

    The file must contain ``train`` and ``test`` groups, each holding an
    ``RGB`` dataset (passed as the data) and an ``MS`` dataset (passed as the
    targets). Each RGB/MS pair is wrapped in an ``H5Dataset``.

    Batch sizes and the worker count come from the global ``opt`` object
    (``opt.batchSize``, ``opt.testBatchSize``, ``opt.threads``).

    Arguments:
        path: location of the HDF5 file; defaults to ``"./dataset/CAVE.h5"``.

    Returns:
        tuple: ``(training_data_loader, testing_data_loader)``. The training
        loader shuffles; the testing loader does not.
    """
    # The file is deliberately not closed here: the H5Dataset objects keep
    # h5py dataset handles into it and index them every time a sample is fetched.
    # NOTE(review): h5py objects cannot be pickled, so with opt.threads > 0 this
    # likely fails under the "spawn" multiprocessing start method, and h5py
    # discourages sharing one open file across forked workers. If workers
    # misbehave, open the file lazily inside each worker instead — confirm.
    f = h5py.File(path, "r")
    train_set = H5Dataset(f["train"]["RGB"], f["train"]["MS"])
    test_set = H5Dataset(f["test"]["RGB"], f["test"]["MS"])
    training_data_loader = DataLoader(
        dataset=train_set,
        num_workers=opt.threads,
        batch_size=opt.batchSize,
        pin_memory=True,
        shuffle=True,
    )
    testing_data_loader = DataLoader(
        dataset=test_set,
        num_workers=opt.threads,
        batch_size=opt.testBatchSize,
        pin_memory=True,
        shuffle=False,
    )
    return training_data_loader, testing_data_loader