Note:[ wechat:Y466551 | 可加勿骚扰,付费咨询 ]

论文信息

论文标题:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
论文作者:Chelsea Finn、Pieter Abbeel、Sergey Levine
论文来源:2017 
论文地址:download 
论文代码:download
视屏讲解:click

1-摘要

  我们提出了一种与模型无关的元学习算法,在这个意义上,它与任何经过梯度下降训练的模型兼容,并适用于各种不同的学习问题,包括分类、回归和强化学习。元学习的目标是在各种学习任务上训练一个模型,这样它就可以只使用少量的训练样本来解决新的学习任务。在我们的方法中,模型的参数被明确地训练,这样少量的梯度步长和来自新任务的少量训练数据将在该任务上产生良好的泛化性能。实际上,我们的方法训练的模型易于微调。我们证明了这种方法在两个低镜头图像分类基准上取得了最先进的性能,在少镜头回归上产生了良好的结果,并加速了使用神经网络策略对策略梯度强化学习的微调。

2-方法

  

  代码:

复制代码
def maml_train(model, support_images, support_labels, query_images, query_labels, inner_step, args, optimizer, is_train=True):
    meta_loss = []
    meta_acc = []
    for support_image, support_label, query_image, query_label in zip(support_images, support_labels, query_images, query_labels):
        fast_weights = collections.OrderedDict(model.named_parameters())
        for _ in range(inner_step):  #inner_step = 1
            # Update weight
            support_logit = model.functional_forward(support_image, fast_weights)
            support_loss = nn.CrossEntropyLoss().cuda()(support_logit, support_label)
            grads = torch.autograd.grad(support_loss, fast_weights.values(), create_graph=True)
            fast_weights = collections.OrderedDict((name, param - args.inner_lr * grad) for ((name, param), grad) in zip(fast_weights.items(), grads))

        # Use trained weight to get query loss
        query_logit = model.functional_forward(query_image, fast_weights)
        query_prediction = torch.max(query_logit, dim=1)[1]
        query_loss = nn.CrossEntropyLoss().cuda()(query_logit, query_label)
        query_acc = torch.eq(query_label, query_prediction).sum() / len(query_label)
        meta_loss.append(query_loss)
        meta_acc.append(query_acc.data.cpu().numpy())

    # Zero the gradient
    optimizer.zero_grad()
    meta_loss = torch.stack(meta_loss).mean()
    meta_acc = np.mean(meta_acc)

    if is_train:
        meta_loss.backward()
        optimizer.step()

    return meta_loss, meta_acc
复制代码

 

posted @   别关注我了,私信我吧  阅读(180)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
历史上的今天:
2022-04-27 论文解读(SUBG-CON)《Sub-graph Contrast for Scalable Self-Supervised Graph Representation Learning》
Live2D
点击右上角即可分享
微信分享提示