One Approach to Learning with Noisy Labels: Noise Transition Matrix Estimation

Basic principle

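The clean-label posterior of a sample, \(P(\mathbf{Y}|X=\mathbf{x})\), can be obtained from the noisy-label posterior \(P(\bar{\mathbf{Y}}|X = \mathbf{x})\) and the noise transition matrix \(T(\mathbf{x})\) through the relation:

\[P(\bar{\mathbf{Y}}|X=\mathbf{x})=T(\mathbf{x})^{\top}P(\mathbf{Y}|X=\mathbf{x}) \]

where \(T_{ij}(\mathbf{x}) = P(\bar{Y} = j|Y = i,X = \mathbf{x})\). The transpose is needed because the rows of \(T\) are indexed by the clean label \(i\) and the columns by the noisy label \(j\), so \(P(\bar{Y} = j|X=\mathbf{x}) = \sum_i T_{ij}(\mathbf{x})P(Y = i|X=\mathbf{x})\).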
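As a small worked example of this relation, the sketch below uses a hypothetical 3-class problem. The matrix `T` and the clean posterior `p_clean` are made-up numbers, not taken from any dataset, and the helper `noisy_posterior` is introduced only for illustration.

```python
# Hypothetical 3-class example; all numbers are illustrative assumptions.
# T[i][j] = P(noisy label = j | clean label = i), so every row sums to 1.
T = [
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.0, 0.3, 0.7],
]
p_clean = [0.6, 0.3, 0.1]  # assumed clean posterior P(Y | X = x)


def noisy_posterior(T, p_clean):
    """Return P(noisy = j | x) = sum_i T[i][j] * P(clean = i | x)."""
    num_classes = len(p_clean)
    return [
        sum(T[i][j] * p_clean[i] for i in range(num_classes))
        for j in range(num_classes)
    ]


p_noisy = noisy_posterior(T, p_clean)
print(p_noisy)
```

Because every row of `T` sums to 1, the result is again a probability distribution over the classes.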

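In general, the transition matrix \(T\) is not identifiable, and it is hard to learn without additional assumptions. As a result, transition-matrix estimation is used relatively rarely in practice for noisy-label problems, and this post covers only its simplest form.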
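A minimal sketch of what non-identifiability means here: two different pairs of transition matrix and clean posterior can produce exactly the same noisy posterior, so noisy labels alone cannot tell them apart. The 2-class numbers and the helper `noisy_posterior` are illustrative assumptions.

```python
# Illustrative 2-class example (all numbers are assumptions).
# Rows of T are indexed by the clean label: T[i][j] = P(noisy = j | clean = i).

def noisy_posterior(T, p_clean):
    """P(noisy = j | x) = sum_i T[i][j] * P(clean = i | x)."""
    n = len(p_clean)
    return [sum(T[i][j] * p_clean[i] for i in range(n)) for j in range(n)]


# Explanation A: labels are noise-free, so the clean posterior equals the noisy one.
T_a = [[1.0, 0.0], [0.0, 1.0]]
p_a = [0.7, 0.3]

# Explanation B: 20% symmetric label flipping with a sharper clean posterior.
T_b = [[0.8, 0.2], [0.2, 0.8]]
p_b = [5 / 6, 1 / 6]

q_a = noisy_posterior(T_a, p_a)
q_b = noisy_posterior(T_b, p_b)
print(q_a, q_b)  # the two explanations yield the same noisy posterior
```

Both explanations fit the observed noisy labels equally well, which is why extra assumptions or regularizers are needed to pick one.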

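At the code level

The transition matrix is represented as a \(C\times C\) matrix, where \(C\) is the number of classes. Its parameters are updated as the model trains.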

import torch
import torch.nn as nn
import torch.nn.functional as F


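class TransitionMatrix(nn.Module):
    """Learnable C x C noise transition matrix whose rows sum to 1."""

    def __init__(self, num_classes, device='cpu'):
        super().__init__()
        # Initial off-diagonal logit; sigmoid(init) is small, so the
        # diagonal dominates every row at the start of training.
        if num_classes == 10:
            init = -2
        else:
            init = -4.5

        w = torch.ones([num_classes, num_classes]) * init
        # Assigning an nn.Parameter registers it as a learnable parameter.
        self.w = nn.Parameter(w)

        # Constant helpers are stored as non-trainable buffers so that they
        # follow the module across devices without entering the state dict.
        self.register_buffer("identity", torch.eye(num_classes), persistent=False)
        self.register_buffer(
            "coeff",
            torch.ones([num_classes, num_classes]) - torch.eye(num_classes),
            persistent=False,
        )

        # Move the parameter and the buffers to the target device together.
        self.to(device)

    def forward(self):
        sig = torch.sigmoid(self.w)
        # Diagonal fixed at 1, off-diagonal entries in (0, 1).
        T = self.identity + sig * self.coeff
        # L1-normalize each row so that it is a probability distribution.
        T = F.normalize(T, p=1, dim=1)
        return T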
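
To make the parametrization concrete, the sketch below rebuilds the same formula as a standalone function and inspects the matrix at initialization. The helper name `build_transition` is hypothetical, and the class counts 10 and 100 are chosen only to exercise the two `init` branches.

```python
import torch
import torch.nn.functional as F


def build_transition(w):
    """Hypothetical standalone version of the forward pass above.

    Diagonal entries are fixed to 1 before normalization, off-diagonal
    entries are sigmoid(w), and each row is L1-normalized to sum to 1.
    """
    num_classes = w.shape[0]
    identity = torch.eye(num_classes)
    off_diagonal = torch.ones(num_classes, num_classes) - identity
    T = identity + torch.sigmoid(w) * off_diagonal
    return F.normalize(T, p=1, dim=1)


# Same initial values as in the class: -2 for 10 classes, -4.5 otherwise.
T10 = build_transition(torch.full((10, 10), -2.0))
T100 = build_transition(torch.full((100, 100), -4.5))

print(T10.sum(dim=1))                        # row sums
print(T10[0, 0].item(), T10[0, 1].item())    # diagonal vs. off-diagonal entry
print(T100[0, 0].item(), T100[0, 1].item())
```

With these initial values the diagonal starts at roughly 0.48 in both cases, so each row puts about half of its mass on the clean label and spreads the rest evenly over the other classes.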

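During training, the noise transition matrix is applied to the model output before the loss is computed. Note that the model output here is a class probability distribution, that is, the logits after softmax.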

...
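# The transition matrix is learned jointly with the model.
transition_matrix = TransitionMatrix(num_classes=num_classes, device=device)
for epoch in range(EPOCH):
    transition_matrix.train()
    ...
    for index, (batch_x, batch_y) in loop:
        ...
        # Estimated clean posterior; model must output softmax probabilities.
        clean = model(batch_x)
        # Current estimate of the transition matrix, shape (C, C).
        t_hat = transition_matrix()
        # Predicted noisy posterior: clean probabilities (rows) times T.
        y_tilde = torch.mm(clean, t_hat)
        # Regularizer on log|det T|. Rows of t_hat sum to 1, so log|det T| <= 0
        # and abs() turns it into a non-negative penalty.
        vol_loss = torch.abs(t_hat.slogdet().logabsdet)
        # Fit the predicted noisy posterior to the observed noisy labels.
        ce_loss = loss_func_ce(y_tilde.log(), batch_y.long())
        loss = ce_loss + opt.lam * vol_loss
        ...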
...
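
For reference, one complete optimization step could look like the sketch below. Everything the loop above elides is an assumption made for illustration: the toy linear classifier, the random batch, the SGD optimizer and its learning rate, the value of `lam` (standing in for `opt.lam`), the use of `F.cross_entropy` in place of `loss_func_ce`, and the condensed `TinyTransition` module, which repeats the parametrization of `TransitionMatrix`. The point it illustrates is that the transition matrix's parameters have to be handed to an optimizer along with the model's.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, num_features, batch_size = 3, 5, 8
lam = 1e-4  # assumed weight of the determinant penalty (opt.lam in the loop)


class TinyTransition(nn.Module):
    """Condensed stand-in for TransitionMatrix, used only in this sketch."""

    def __init__(self, num_classes, init=-2.0):
        super().__init__()
        self.w = nn.Parameter(torch.full((num_classes, num_classes), init))
        self.register_buffer("identity", torch.eye(num_classes))

    def forward(self):
        T = self.identity + torch.sigmoid(self.w) * (1.0 - self.identity)
        return F.normalize(T, p=1, dim=1)


# Toy classifier that outputs probabilities, as the text requires.
model = nn.Sequential(nn.Linear(num_features, num_classes), nn.Softmax(dim=1))
transition_matrix = TinyTransition(num_classes)

# Both parameter sets are optimized jointly (optimizer choice is an assumption).
optimizer = torch.optim.SGD(
    list(model.parameters()) + list(transition_matrix.parameters()), lr=0.1
)

# Random stand-in batch with noisy integer labels.
batch_x = torch.randn(batch_size, num_features)
batch_y = torch.randint(0, num_classes, (batch_size,))

clean = model(batch_x)              # estimated clean posterior, shape (B, C)
t_hat = transition_matrix()         # current estimate of T, shape (C, C)
y_tilde = torch.mm(clean, t_hat)    # predicted noisy posterior, shape (B, C)
vol_loss = torch.abs(torch.linalg.slogdet(t_hat).logabsdet)
# y_tilde is already normalized, so cross-entropy on its log equals the
# negative log-likelihood of the noisy labels.
ce_loss = F.cross_entropy(y_tilde.log(), batch_y)
loss = ce_loss + lam * vol_loss

w_before = transition_matrix.w.detach().clone()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(loss.item())
print((transition_matrix.w.detach() - w_before).abs().max().item())  # size of the update to w
```

If the optimizer were built from `model.parameters()` alone, `w` would keep its initial value and the transition matrix would never be learned.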

Dependencies:

torch                     2.4.1

References

  1. noise-transition-matrix
posted @ 2024-10-24 16:25  October-