SPL(Self-Paced Learning for Latent Variable Models)代码复现-pytorch版

  1. 起因:最近在看CL(curriculum learning)相关的文章,然后发现了SPL学习策略,简单来说就是让model学习的数据从简单容易. 看SPL相关的文章必然跳不过这篇文章:Self-Paced Learning for Latent Variable Models

  2. 困难:这篇文章写于2010年,基于pytorch实现的代码很难找,在我找了两天的情况下,终于有所发现

正题:理论部分

引入公式1:基础的公式

  1. 该公式优化的参数为:w

\[\mathbf{w}_{t+1}=\underset{\mathbf{w} \in \mathbb{R}^{d}}{\operatorname{argmin}}\left(r(\mathbf{w})+\sum_{i=1}^{n} f\left(\mathbf{x}_{i}, \mathbf{y}_{i} ; \mathbf{w}\right)\right) \]

  1. 括号中第一项r(w)函数是regularization function,下面为了方便不对该部分展开(对后续结果无影响)
  2. 括号中第二项$$\sum_{i=1}^{n} f()$$部分就是我们常写的部分,用wx做运算得到p_red,并和y用损失函数计算loss,然后参数更新...

引入公式2:基础的公式+SPL

  1. 引入v参数控制学习,v可以取0,1. 表示数据,
  2. 该公式优化的参数为:wv

\[\left(\mathbf{w}_{t+1}, \mathbf{v}_{t+1}\right)=\underset{\mathbf{w} \in \mathbb{R}^{d}, \mathbf{v} \in\{0,1\}^{n}}{\operatorname{argmin}}\left(r(\mathbf{w})+\sum_{i=1}^{n} v_{i} f\left(\mathbf{x}_{i}, \mathbf{y}_{i} ; \mathbf{w}\right)-\frac{1}{K} \sum_{i=1}^{n} v_{i}\right) \]

  1. 括号中第三项$$-\frac{1}{K} \sum_{i=1}^{n} v_{i}$$就是正则项,它也有许多变体见参考[0],我找到的资料中正则项修改成了hard版的,即$$-\lambda \sum_{i=1}^{n} v_{i}$$

引入公式3:基础的公式+SPL(变体 hard)

  1. 基于公式2,并修改其中的正则项
  2. 该公式优化的参数为:wv

\[L=r(w)+\sum_{i=1}^{n} v_{i} f\left(x_{i}, y_{i}, w\right)-\lambda \sum_{i=1}^{n} v_{I} \]

  1. 括号中第一项r(w)函数是regularization function,下面为了方便不对该部分展开(对后续结果无影响)
  2. 括号中第二项$$\sum_{i=1}^{n}v_{i} f()$$部分新增了v_i,即v. 如果取0表示数据,总体为0,即该部分对后续无影响,也就达到了的数据不学习。如果取1,则相反
  3. 括号中第三项$$-\lambda \sum_{i=1}^{n} v_{i}$$ 其中$$\lambda$$用来和loss的值对比,来确定数据的简单与否,代码解释如下
    def spl_loss(super_loss, lambda_a):
        # 如果 模型的loss < lambda --> v=1,表示该数据集简单
        # 否则                       --> v=0,表示该数据集难
        v = super_loss < lambda_a
        return v.int()

且需要随着epoch的增加而增加,实现随着epoch的增加选择的数据集越多,代码解释如下

def increase_threshold(lambda_a, growing_factor):
       lambda_a *= growing_factor
       return lambda_a

正题:部分代码架构 完整可运行代码 见参考[1]

SPL-LOSS 部分

import torch
from torch import Tensor
import torch.nn as nn


class SPLLoss(nn.NLLLoss):
    def __init__(self, *args, n_samples=0, **kwargs):
        super(SPLLoss, self).__init__(*args, **kwargs)
        self.threshold = 0.1
        self.growing_factor = 1.35
        self.v = torch.zeros(n_samples).int()

    def forward(self, input: Tensor, target: Tensor, index: Tensor) -> Tensor:
        super_loss = nn.functional.nll_loss(input, target, reduction="none")
        v = self.spl_loss(super_loss)
        self.v[index] = v
        return (super_loss * v).mean()
    # 通过增加threshold 来增加每次训练的大小
    def increase_threshold(self):
        self.threshold *= self.growing_factor

    def spl_loss(self, super_loss):
        # 如果 模型的loss < threshold --> v=1,表示该数据集简单
        # 否则                       --> v=0,表示该数据集难
        v = super_loss < self.threshold
        return v.int()

train部分

def train():
    model = Model(2, 2)
    dataloader = get_dataloader()
    criterion = SPLLoss(n_samples=len(dataloader.dataset))
    optimizer = optim.Adam(model.parameters())

    for epoch in range(10):
        for index, data, target in tqdm.tqdm(dataloader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target, index)
            loss.backward()
            optimizer.step()
        criterion.increase_threshold()
        plot(dataloader.dataset, model, criterion)

    animation = camera.animate()
    animation.save("plot.gif")

参考
[0]:束俊, 孟德宇, 徐宗本. 元自步学习. 中国科学: 信息科学, 2020, 50: 781–793, doi: 10.1360/SSI-2020-0005 Shu J, Meng D Y, Xu Z B. Meta self-paced learning (in Chinese). Sci Sin Inform, 2020, 50: 781–793, doi: 10.1360/SSI-2020-0005
[1]: GitHub

posted @ 2021-12-03 11:11  Adam_lxd  阅读(1899)  评论(0编辑  收藏  举报