机器学习常见的sampling策略 附PyTorch实现

初始工作

定义一个模拟的长尾数据集

import torch
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader


np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
class LongTailDataset(Dataset):
    def __init__(self, num_classes=25, max_samples_per_class=100):
        self.num_classes = num_classes
        self.max_samples_per_class = max_samples_per_class

        # Generate number of samples for each class inversely proportional to class index
        self.samples_per_class = [self.max_samples_per_class // (i + 1) for i in range(self.num_classes)]
        self.total_samples = sum(self.samples_per_class)

        # Generate targets for the dataset
        self.targets = torch.cat(
            [torch.full((samples,), i, dtype=torch.long) for i, samples in enumerate(self.samples_per_class)])

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        # For simplicity, just return the index as the data
        return idx, self.targets[idx]


# Create dataset
batch_size = 64
dataset = LongTailDataset()
print(f'The total number of samples: {len(dataset) // 2}')
print(f'The number of samples per class: {dataset.samples_per_class}')
print(f'The {len(dataset) // 2} th samples of the dataset: {dataset[len(dataset) // 2]}')

Output:

The total number of samples: 187
The number of samples per class: [100, 50, 33, 25, 20, 16, 14, 12, 11, 10, 9, 8, 7, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4]
The 187 th samples of the dataset: (187, tensor(3))

定义一个测试sample一个batch的函数

def test_loader_in_one_batch(test_dataloader: DataLoader, inf: str):
    print(inf)
    for (_, target) in test_dataloader:
        cls_idx, cls_counts = np.unique(target.numpy(), return_counts=True)
        cls_idx = [int(i) for i in cls_idx]
        cls_counts = [int(i) for i in cls_counts]
        print(f'Class indices: {cls_idx}')
        print(f'Class counts: {cls_counts}')
        break  # just show one batch
    print('-' * 20)

采样介绍

每个类的采样概率可抽象为:\(p_j=\frac{n_j^q}{\sum_{i=1}^Cn_i^q}\)

  • \(p_j\)表示从j类采样数据的概率;
  • \(C\)表示类别数量;\(n_j\)表示j类样本数;
  • \(q\in\{1,0\}\)

均匀采样

\(q=1\),实例平衡采样(Instance-balanced sampling)(也称uniform sampling),最常见的数据采样方式,每个训练样本被选择的概率相等均为\(\frac{1}{N}\)。对j类的采样,按数据集中j类的基数\(n_j\)进行采样,即\(p^{\mathbf{IB}}_j=\frac{n_j}{\sum_{i=1}^Cn_i}\)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
test_loader_in_one_batch(dataloader, inf='Instance-balanced sampling(Default):')

Output:

Instance-balanced sampling(Default):
Class indices: [0, 1, 2, 3, 4, 5, 6, 10, 11, 13, 15, 17, 18, 19, 20, 21, 23]
Class counts: [13, 10, 4, 4, 4, 6, 5, 3, 1, 4, 3, 1, 2, 1, 1, 1, 1]
--------------------

类平衡采样

实例平衡采样在不平衡的数据集中往往表现不佳,类平衡采样(Class-balanced sampling)让所有的类有相同的被采样概率(\(q=0\)):\(p^{\mathbf{CB}}_j=\frac{1}{C}\)。采样可分为两个阶段:1. 从类集中统一选择一个类;2. 对该类中的实例进行统一采样。

这里具体实现使用很多论文都在使用的 Class Aware Sampler,通过循环过采样,使得batch内每个类别的样本数相等。

import random
from torch.utils.data.sampler import Sampler
import numpy as np


class RandomCycleIter:

    def __init__(self, data, test_mode=False):
        self.data_list = list(data)
        self.length = len(self.data_list)
        self.i = self.length - 1
        self.test_mode = test_mode

    def __iter__(self):
        return self

    def __next__(self):
        self.i += 1

        if self.i == self.length:
            self.i = 0
            if not self.test_mode:
                random.shuffle(self.data_list)

        return self.data_list[self.i]


def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1):
    i = 0
    j = 0
    while i < n:
        if j >= num_samples_cls:
            j = 0

        if j == 0:
            temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]] * num_samples_cls))
            # next(cls_iter) 会返回一个类别的index,
            # data_iter_list[next(cls_iter)]会返回list,list内包括该类的所有样本的index
            # 用*解包上面的list,然后内部每个元素重复 num_samples_cls 次,然后用zip打包,再用next取出
            yield temp_tuple[j]
        else:
            yield temp_tuple[j]

        i += 1
        j += 1


class ClassAwareSampler(Sampler):

    def __init__(self, targets, num_samples_cls=1):
        super().__init__()
        num_classes = len(np.unique(targets))
        self.class_iter = RandomCycleIter(range(num_classes))  # 返回一个循环迭代器,迭代器每次返回一个类的index
        cls_data_list = [list() for _ in range(num_classes)]  # N个类,每个类对应一个list
        for i, label in enumerate(targets):
            cls_data_list[label].append(i)  # 将每个样本的index按照类别放入对应的list
        self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list]  # 每个类用循环迭代器包装,返回类内sample的index
        self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list)  # 总样本数 = 最大样本数的类的样本数 * 类别数
        self.num_samples_cls = num_samples_cls

    def __iter__(self):
        return class_aware_sample_generator(self.class_iter, self.data_iter_list,
                                            self.num_samples, self.num_samples_cls)

    def __len__(self):
        return self.num_samples

    
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=ClassAwareSampler(dataset.targets))
test_loader_in_one_batch(dataloader, inf='Class-aware sampling:')

Output:

Class-aware sampling:
Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
Class counts: [2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 2, 3, 3, 2, 2, 3, 3, 3, 3, 3, 2, 2, 2, 3, 3]
--------------------

类平衡采样的另一种写法(通过调整采样器的类的权重)

最早是把每个类的权重(采样概率)设为样本数倒数:\(p_j=\frac{1}{n_j}\)。[3]提出effective number,对每个类的权重(effective number)调整为:

\[E_n=(1-\beta^n)/(1-\beta),\ \mathrm{where~}\beta=(N-1)/N. \]

并用这个权重调整损失。[4]把这个权重用于采样权重,这里用PyTorch提供的WeightedRandomSampler实现:第一个参数表示每个样本(不是类)的权重,第二个参数表示采样的样本数,第三个参数表示是否有放回采样。

from torch.utils.data.sampler import WeightedRandomSampler

def imbalance_sampler(targets, mode='inverse'):
    cls_counts = np.bincount(targets)
    cls_weights = None
    if mode == 'inverse':
        cls_weights = 1. / cls_counts
    elif mode == 'effective':
        beta = (len(targets) - 1) / len(targets)
        cls_weights = (1.0 - beta) / (1.0 - np.power(beta, cls_counts))
    assert cls_weights is not None
    return WeightedRandomSampler(cls_weights[targets], len(targets), replacement=True)

modes = ['inverse', 'effective']
for mode in modes:
    sampler = imbalance_sampler(dataset.targets.numpy(), mode)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    test_loader_in_one_batch(dataloader, inf=f'{mode.capitalize()}:')

Output:

Inverse:
Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24]
Class counts: [1, 3, 1, 2, 5, 4, 2, 1, 3, 3, 1, 3, 3, 3, 6, 3, 3, 1, 3, 5, 1, 3, 1, 3]
--------------------
Effective:
Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 23, 24]
Class counts: [2, 3, 2, 2, 5, 1, 1, 2, 3, 3, 7, 2, 1, 3, 3, 4, 1, 4, 3, 2, 1, 3, 6]
--------------------

实际中,提到类平衡采样Class-Balanced Re-Sampling,两种实现方式都有可能,注意上下文描述和参考文献的引用。

混合采样策略

最早的混合采样是在 \(0\le epoch\le t\)时采用Instance-balanced采样,\(t\le epoch\le T\)时采用Class-balanced采样,这需要设置合适的超参数t。在[1]中,作者提出了soft版本的混合采样策略:Progressively-balanced sampling。随着epoch的增加每个类的采样概率(权重)\(p_j\)也发生变化:

\[p_j^{\mathbf{PB}}(t)=(1-\frac tT)p_j^{\mathbf{IB}}+\frac tTp_j^{\mathbf{CB}} \]

t表示当前epoch,T表示总epoch数。

运行环境

# Name                    Version                   Build  Channel
pytorch                   2.3.1           py3.12_cuda12.1_cudnn8_0    pytorch

参考文献

  1. Kang, Bingyi, et al. "Decoupling Representation and Classifier for Long-Tailed Recognition." International Conference on Learning Representations. 2019.
  2. torch.utils.data.WeightedRandomSampler
  3. Cui, Yin, et al. "Class-balanced loss based on effective number of samples." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
  4. Cao, Kaidi, et al. "Learning imbalanced datasets with label-distribution-aware margin loss." Advances in neural information processing systems 32 (2019).
  5. Shi, Jiang-Xin, et al. "How Re-sampling Helps for Long-Tail Learning?." Advances in Neural Information Processing Systems 36 (2023).
posted @ 2024-04-09 21:07  October-  阅读(426)  评论(0编辑  收藏  举报