论文解读(MAML)《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》

Note:[ wechat:Y466551 | 可加勿骚扰,付费咨询 ]

论文信息

论文标题:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
论文作者:Chelsea Finn、Pieter Abbeel、Sergey Levine
论文来源:2017 
论文地址:download 
论文代码:download
视屏讲解:click

1-摘要

  我们提出了一种与模型无关的元学习算法,在这个意义上,它与任何经过梯度下降训练的模型兼容,并适用于各种不同的学习问题,包括分类、回归和强化学习。元学习的目标是在各种学习任务上训练一个模型,这样它就可以只使用少量的训练样本来解决新的学习任务。在我们的方法中,模型的参数被明确地训练,这样少量的梯度步长和来自新任务的少量训练数据将在该任务上产生良好的泛化性能。实际上,我们的方法训练的模型易于微调。我们证明了这种方法在两个低镜头图像分类基准上取得了最先进的性能,在少镜头回归上产生了良好的结果,并加速了使用神经网络策略对策略梯度强化学习的微调。

2-方法

  

  代码:

def maml_train(model, support_images, support_labels, query_images, query_labels, inner_step, args, optimizer, is_train=True):
    meta_loss = []
    meta_acc = []
    for support_image, support_label, query_image, query_label in zip(support_images, support_labels, query_images, query_labels):
        fast_weights = collections.OrderedDict(model.named_parameters())
        for _ in range(inner_step):  #inner_step = 1
            # Update weight
            support_logit = model.functional_forward(support_image, fast_weights)
            support_loss = nn.CrossEntropyLoss().cuda()(support_logit, support_label)
            grads = torch.autograd.grad(support_loss, fast_weights.values(), create_graph=True)
            fast_weights = collections.OrderedDict((name, param - args.inner_lr * grad) for ((name, param), grad) in zip(fast_weights.items(), grads))

        # Use trained weight to get query loss
        query_logit = model.functional_forward(query_image, fast_weights)
        query_prediction = torch.max(query_logit, dim=1)[1]
        query_loss = nn.CrossEntropyLoss().cuda()(query_logit, query_label)
        query_acc = torch.eq(query_label, query_prediction).sum() / len(query_label)
        meta_loss.append(query_loss)
        meta_acc.append(query_acc.data.cpu().numpy())

    # Zero the gradient
    optimizer.zero_grad()
    meta_loss = torch.stack(meta_loss).mean()
    meta_acc = np.mean(meta_acc)

    if is_train:
        meta_loss.backward()
        optimizer.step()

    return meta_loss, meta_acc

 

posted @ 2024-04-27 21:11  多发Paper哈  阅读(40)  评论(0编辑  收藏  举报
Live2D