SPL(Self-Paced Learning for Latent Variable Models)代码复现-pytorch版
目录
-
起因:最近在看CL(curriculum learning)相关的文章,然后发现了SPL学习策略,简单来说就是让model学习的数据从
简单
到容易
. 看SPL相关的文章必然跳不过这篇文章:Self-Paced Learning for Latent Variable Models -
困难:这篇文章写于2010年,基于pytorch实现的代码很难找,在我找了两天的情况下,终于有所发现
正题:理论部分
引入公式1:基础的公式
- 该公式优化的参数为:
w
\[\mathbf{w}_{t+1}=\underset{\mathbf{w} \in \mathbb{R}^{d}}{\operatorname{argmin}}\left(r(\mathbf{w})+\sum_{i=1}^{n} f\left(\mathbf{x}_{i}, \mathbf{y}_{i} ; \mathbf{w}\right)\right)
\]
- 括号中第一项
r(w)
函数是regularization function
,下面为了方便不对该部分展开(对后续结果无影响) - 括号中第二项$$\sum_{i=1}^{n} f()$$部分就是我们常写的部分,用
w
和x
做运算得到p_red
,并和y
用损失函数计算loss,然后参数更新...
引入公式2:基础的公式+SPL
- 引入
v
参数控制学习,v
可以取0
,1
. 表示数据难
,易
- 该公式优化的参数为:
w
和v
\[\left(\mathbf{w}_{t+1}, \mathbf{v}_{t+1}\right)=\underset{\mathbf{w} \in \mathbb{R}^{d}, \mathbf{v} \in\{0,1\}^{n}}{\operatorname{argmin}}\left(r(\mathbf{w})+\sum_{i=1}^{n} v_{i} f\left(\mathbf{x}_{i}, \mathbf{y}_{i} ; \mathbf{w}\right)-\frac{1}{K} \sum_{i=1}^{n} v_{i}\right)
\]
- 括号中第三项$$-\frac{1}{K} \sum_{i=1}^{n} v_{i}$$就是正则项,它也有许多变体
见参考[0]
,我找到的资料中正则项修改成了hard
版的,即$$-\lambda \sum_{i=1}^{n} v_{i}$$
引入公式3:基础的公式+SPL(变体 hard
)
- 基于公式2,并修改其中的正则项
- 该公式优化的参数为:
w
和v
\[L=r(w)+\sum_{i=1}^{n} v_{i} f\left(x_{i}, y_{i}, w\right)-\lambda \sum_{i=1}^{n} v_{I}
\]
- 括号中第一项
r(w)
函数是regularization function
,下面为了方便不对该部分展开(对后续结果无影响) - 括号中第二项$$\sum_{i=1}^{n}v_{i} f()$$部分新增了
v_i
,即v
. 如果取0
表示数据难
,总体为0,即该部分对后续无影响,也就达到了难
的数据不学习。如果取1
,则相反 - 括号中第三项$$-\lambda \sum_{i=1}^{n} v_{i}$$ 其中$$\lambda$$用来和loss的值对比,来确定数据的简单与否,代码解释如下
def spl_loss(super_loss, lambda_a):
# 如果 模型的loss < lambda --> v=1,表示该数据集简单
# 否则 --> v=0,表示该数据集难
v = super_loss < lambda_a
return v.int()
且需要随着epoch的增加而增加,实现随着epoch的增加选择的数据集越多,代码解释如下
def increase_threshold(lambda_a, growing_factor):
lambda_a *= growing_factor
return lambda_a
正题:部分代码架构 完整可运行代码 见参考[1]
SPL-LOSS 部分
import torch
from torch import Tensor
import torch.nn as nn
class SPLLoss(nn.NLLLoss):
def __init__(self, *args, n_samples=0, **kwargs):
super(SPLLoss, self).__init__(*args, **kwargs)
self.threshold = 0.1
self.growing_factor = 1.35
self.v = torch.zeros(n_samples).int()
def forward(self, input: Tensor, target: Tensor, index: Tensor) -> Tensor:
super_loss = nn.functional.nll_loss(input, target, reduction="none")
v = self.spl_loss(super_loss)
self.v[index] = v
return (super_loss * v).mean()
# 通过增加threshold 来增加每次训练的大小
def increase_threshold(self):
self.threshold *= self.growing_factor
def spl_loss(self, super_loss):
# 如果 模型的loss < threshold --> v=1,表示该数据集简单
# 否则 --> v=0,表示该数据集难
v = super_loss < self.threshold
return v.int()
train部分
def train():
model = Model(2, 2)
dataloader = get_dataloader()
criterion = SPLLoss(n_samples=len(dataloader.dataset))
optimizer = optim.Adam(model.parameters())
for epoch in range(10):
for index, data, target in tqdm.tqdm(dataloader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target, index)
loss.backward()
optimizer.step()
criterion.increase_threshold()
plot(dataloader.dataset, model, criterion)
animation = camera.animate()
animation.save("plot.gif")
参考
[0]:束俊, 孟德宇, 徐宗本. 元自步学习. 中国科学: 信息科学, 2020, 50: 781–793, doi: 10.1360/SSI-2020-0005 Shu J, Meng D Y, Xu Z B. Meta self-paced learning (in Chinese). Sci Sin Inform, 2020, 50: 781–793, doi: 10.1360/SSI-2020-0005
[1]: GitHub