PyTorch训练模型常用操作
One-hot编码
将标签转换为one-hot编码形式
def to_categorical(y, num_classes):
    """One-hot encode a tensor of integer class indices.

    Args:
        y: torch.Tensor of class indices in ``[0, num_classes)``; any shape.
            Values are cast to int64 before being used as indices.
        num_classes: Total number of classes (length of each one-hot vector).

    Returns:
        A float32 tensor of shape ``y.shape + (num_classes,)`` placed on the
        same device as ``y``.
    """
    # Index on the CPU so an out-of-range label raises a plain IndexError.
    indices = y.detach().cpu().long()
    one_hot = torch.eye(num_classes)[indices]
    # Follow y to whichever device it lives on. A bare ``.cuda()`` would send
    # the result to the current default GPU, which is wrong when y sits on a
    # different one.
    return one_hot.to(y.device)
- 示例
>>> y = np.array([1,2,3])
>>> y
array([1, 2, 3])
>>> torch.eye(4)[y,]
tensor([[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]])
>>> y = np.array([[1, 2, 2], [1, 2, 3]])
>>> y
array([[1, 2, 2],
[1, 2, 3]])
>>> torch.eye(4)[y,]
tensor([[[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 1., 0.]],
[[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]]])
>>> # 注意：不加逗号时，y 会被当作 (行索引, 列索引) 去取单个元素，而不是取 one-hot 行（具体行为可能随 PyTorch 版本而变）
>>> torch.eye(4)[y]
tensor([1., 1., 0.])
分别初始化
def weights_init(m):
    """Initialise a Conv2d or Linear layer in place.

    A module whose class name contains ``'Conv2d'`` or ``'Linear'`` gets
    Xavier-normal weights and, if it has a bias, a zero bias. Any other
    module is left untouched.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv2d' not in classname and 'Linear' not in classname:
        return
    torch.nn.init.xavier_normal_(m.weight.data)
    # Layers built with bias=False have ``bias is None``; skip them instead of
    # failing with AttributeError on ``None.data``.
    bias = getattr(m, 'bias', None)
    if bias is not None:
        torch.nn.init.constant_(bias.data, 0.0)
# Run weights_init over the model through its apply() method.
# NOTE(review): presumably classifier is a torch.nn.Module, whose apply()
# calls the function on every submodule and returns the module - confirm.
classifier = classifier.apply(weights_init)
checkpoint检查是否接着训练
# Resume from the best saved checkpoint when it can be loaded; otherwise start
# training from epoch 0.
try:
    checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
    start_epoch = checkpoint['epoch']
    classifier.load_state_dict(checkpoint['model_state_dict'])
    log_string('Use pretrain model')
except Exception:
    # Any ordinary failure (missing file, missing key, mismatched state dict)
    # falls back to a fresh start. Unlike a bare ``except:``, this lets
    # KeyboardInterrupt and SystemExit propagate.
    log_string('No existing model, starting training from scratch...')
    start_epoch = 0
根据epoch调整学习率和BN动量
def bn_momentum_adjust(m, momentum):
    """Set ``m.momentum`` when ``m`` is a BatchNorm1d or BatchNorm2d layer.

    Modules of any other type are left unchanged.

    Args:
        m: The module to inspect.
        momentum: The value assigned to ``m.momentum``.
    """
    batchnorm_types = (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)
    if not isinstance(m, batchnorm_types):
        return
    m.momentum = momentum
# Step-decay the learning rate once every args.step_size epochs, never letting
# it drop below LEARNING_RATE_CLIP, then write it into every parameter group.
lr = args.learning_rate * (args.lr_decay ** (epoch // args.step_size))
lr = max(lr, LEARNING_RATE_CLIP)
log_string('Learning rate:%f' % lr)
for param_group in optimizer.param_groups:
    param_group['lr'] = lr

# Decay the BatchNorm momentum on its own step schedule, with a floor of 0.01.
momentum = max(
    MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP)),
    0.01,
)
print('BN momentum updated to: %f' % momentum)

# Apply the new momentum through bn_momentum_adjust, then switch the model to
# training mode.
classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum)).train()
批量数据维度不一致
自定义torch.utils.data.DataLoader(dataset, collate_fn=collate_fn)
中的collate_fn
def my_collate_fn(batch_data):
    """Collate variable-length samples into fixed-size batch tensors.

    Each sample is randomly subsampled (without replacement) down to the
    length of the shortest sample in the batch so that all samples can be
    stacked. The same random indices are applied to ``data`` and ``label``,
    which keeps them aligned.

    Args:
        batch_data: list of ``(data, cls, label)`` tuples. ``data`` is indexed
            as ``data[idx, :]`` and ``label`` as ``label[idx]`` with an index
            array, so both are expected to be NumPy arrays whose first axis
            has the same length.

    Returns:
        tuple ``(data_tensor, cls_tensor, label_tensor)`` of float32 tensors,
        with samples ordered by ascending original length.
    """
    # sorted() keeps the ascending-length output order without reordering the
    # caller's list in place.
    samples = sorted(batch_data, key=lambda sample: len(sample[0]))
    min_len = len(samples[0][0])

    data_list = []
    cls_list = []
    label_list = []
    for sample in samples:
        data, cls, label = sample[0], sample[1], sample[2]
        choice = np.random.choice(len(data), min_len, replace=False)
        data_list.append(data[choice, :])
        cls_list.append(cls)
        label_list.append(label[choice])

    # Stack into one ndarray first: building a tensor straight from a Python
    # list of ndarrays takes torch's slow element-by-element path.
    data_tensor = torch.tensor(np.stack(data_list), dtype=torch.float32)
    cls_tensor = torch.tensor(cls_list, dtype=torch.float32)
    # NOTE(review): labels are returned as float32; integer-target losses need
    # the caller to cast them - confirm this is intended.
    label_tensor = torch.tensor(np.stack(label_list), dtype=torch.float32)
    return data_tensor, cls_tensor, label_tensor
分割标签分配不同权值
# Per-class weights from label frequencies: count the labels falling in each
# of the N_Class integer bins, normalise the counts to frequencies, then take
# the cube root of (largest frequency / class frequency).
# NOTE(review): a class with zero occurrences has frequency 0, so the division
# below yields an infinite weight for it.
tmp, _ = np.histogram(labels, range(N_Class + 1))
labelweights = tmp.astype(np.float32)
labelweights = labelweights / np.sum(labelweights)
max_frequency = np.amax(labelweights)
labelweights = np.power(max_frequency / labelweights, 1 / 3.0)
print(labelweights)
class get_loss(nn.Module):
    """Loss module that wraps ``F.nll_loss`` with optional per-class weights."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, trans_feat, weight=None):
        """Return ``F.nll_loss(pred, target, weight=weight)``.

        ``trans_feat`` is accepted but does not take part in the computation.
        """
        return F.nll_loss(pred, target, weight=weight)