PyTorch训练模型常用操作

One-hot编码

将标签转换为one-hot编码形式

def to_categorical(y, num_classes):
    """One-hot encode an integer label tensor.

    Args:
        y: tensor of class indices; each value must lie in
            ``[0, num_classes)``. Any shape is accepted.
        num_classes: size of the one-hot dimension.

    Returns:
        A float tensor of shape ``y.shape + (num_classes,)`` located on the
        same device as ``y``.
    """
    # Build the identity matrix directly on y's device: this avoids the
    # CPU/NumPy round trip and, unlike a bare ``.cuda()``, keeps the result on
    # the GPU that y actually lives on rather than the default one.
    identity = torch.eye(num_classes, device=y.device)
    # Row i of the identity matrix is the one-hot vector for class i.
    return identity[y.long()]
  • 示例
>>> y = np.array([1,2,3])
>>> y
array([1, 2, 3])
>>> torch.eye(4)[y,]
tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

>>> y = np.array([[1, 2, 2], [1, 2, 3]])
>>> y
array([[1, 2, 2],
       [1, 2, 3]])
>>> torch.eye(4)[y,]
tensor([[[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 1., 0.]],

        [[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]])
>>> torch.eye(4)[y]
tensor([1., 1., 0.])

分别初始化

def weights_init(m):
    """Initialise Conv2d / Linear layers: Xavier-normal weights, zero biases.

    Layers are matched by class name (any class whose name contains
    ``'Conv2d'`` or ``'Linear'``); every other module is left untouched.

    Args:
        m: the module to initialise in place.
    """
    classname = m.__class__.__name__
    if 'Conv2d' in classname or 'Linear' in classname:
        torch.nn.init.xavier_normal_(m.weight.data)
        # Layers built with bias=False have ``bias is None``; skip them
        # instead of failing on ``None.data``.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)

# Run weights_init on classifier via its apply() method
# (assumes classifier is a torch.nn.Module, whose apply() also visits every submodule — confirm).
classifier = classifier.apply(weights_init)

检查是否存在checkpoint，以决定是否接着上次的训练继续

# Resume from the best checkpoint when one can be loaded; otherwise start at epoch 0.
try:
    checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
    start_epoch = checkpoint['epoch']
    classifier.load_state_dict(checkpoint['model_state_dict'])
    log_string('Use pretrain model')
except FileNotFoundError:
    log_string('No existing model, starting training from scratch...')
    start_epoch = 0
except Exception as exc:
    # Still best-effort (training starts from scratch), but unlike a bare
    # ``except:`` this does not swallow KeyboardInterrupt/SystemExit, and the
    # reason a present-but-unusable checkpoint was rejected is logged.
    log_string('Failed to load checkpoint (%r), starting training from scratch...' % (exc,))
    start_epoch = 0

根据迭代次数调整学习率和BN动量


def bn_momentum_adjust(m, momentum):
    """Set ``m.momentum`` to ``momentum`` when ``m`` is a BatchNorm1d/BatchNorm2d layer.

    Any other module is left unchanged.
    """
    if not isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
        return
    m.momentum = momentum

# Step-decay the learning rate every args.step_size epochs, never going below LEARNING_RATE_CLIP.
decay_steps = epoch // args.step_size
lr = max(args.learning_rate * args.lr_decay ** decay_steps, LEARNING_RATE_CLIP)
log_string('Learning rate:%f' % lr)
for group in optimizer.param_groups:
    group['lr'] = lr

# Decay the BatchNorm momentum on its own step schedule, floored at 0.01.
momentum = max(MOMENTUM_ORIGINAL * MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP), 0.01)
print('BN momentum updated to: %f' % momentum)

# Push the new momentum into the classifier, then switch it to training mode.
classifier = classifier.apply(lambda module: bn_momentum_adjust(module, momentum)).train()

批量数据维度不一致

自定义torch.utils.data.DataLoader(dataset, collate_fn=my_collate_fn)中的collate_fn

def my_collate_fn(batch_data):
    """
    Collate samples of different lengths into fixed-size batch tensors.

    Every sample is randomly sub-sampled (without replacement) down to the
    length of the shortest sample in the batch so that all samples can be
    stacked. The returned batch is ordered by ascending original sample
    length; the caller's ``batch_data`` list is not reordered.

    :param batch_data: list of ``(data, cls, label)`` samples. ``data`` is
        indexed as ``data[idx, :]`` and ``label`` as ``label[idx]``
        (presumably NumPy arrays whose first axis has the same length —
        confirm against the dataset).
    :return: tuple ``(data_tensor, cls_tensor, label_tensor)``, each of
        dtype ``torch.float32``
    """
    # sorted() gives the same shortest-first order the batch is returned in,
    # without mutating the list owned by the caller.
    ordered = sorted(batch_data, key=lambda sample: len(sample[0]))
    min_len = len(ordered[0][0])

    data_list = []
    cls_list = []
    label_list = []
    for sample in ordered:
        data, cls, label = sample[0], sample[1], sample[2]

        # Use the same random row subset for data and label so they stay aligned.
        choice = np.random.choice(len(data), min_len, replace=False)
        data_list.append(data[choice, :])
        cls_list.append(cls)
        label_list.append(label[choice])

    # Stack into a single ndarray first: converting a Python list of ndarrays
    # element by element is much slower.
    data_tensor = torch.tensor(np.stack(data_list), dtype=torch.float32)
    cls_tensor = torch.tensor(cls_list, dtype=torch.float32)
    label_tensor = torch.tensor(np.stack(label_list), dtype=torch.float32)
    return data_tensor, cls_tensor, label_tensor

分割标签分配不同权值

# Count the labels per class using histogram bin edges 0, 1, ..., N_Class.
class_counts, _ = np.histogram(labels, range(N_Class + 1))
labelweights = (np.zeros(N_Class) + class_counts).astype(np.float32)

# Turn counts into relative frequencies.
labelweights /= labelweights.sum()

# Weight each class by the cube root of (highest frequency / its own frequency).
# NOTE: a class with zero samples has frequency 0, so its weight becomes inf.
labelweights = np.power(labelweights.max() / labelweights, 1 / 3.0)
print(labelweights)

class get_loss(nn.Module):
    """Loss module that returns ``F.nll_loss`` with optional per-class weights."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, trans_feat, weight=None):
        """Return ``F.nll_loss(pred, target, weight=weight)``.

        ``trans_feat`` is accepted as part of the signature but is not used.
        """
        return F.nll_loss(pred, target, weight=weight)

posted @ 2021-10-21 10:05  半夜打老虎  阅读(467)  评论(0编辑  收藏  举报