机器学习常见的sampling策略 附PyTorch实现
初始工作
定义一个模拟的长尾数据集
import torch
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
# Fix every RNG the demo relies on (NumPy, Python's `random`, PyTorch) so the
# printed batches are reproducible. Order matches: NumPy, random, torch.
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    _seed_fn(0)
class LongTailDataset(Dataset):
    """Synthetic long-tailed classification dataset.

    Class ``k`` (0-based) holds ``max_samples_per_class // (k + 1)`` samples, so
    class sizes fall off like ``1 / (k + 1)``. Labels are stored contiguously:
    all class-0 samples first, then class 1, and so on.
    """

    def __init__(self, num_classes=25, max_samples_per_class=100):
        self.num_classes = num_classes
        self.max_samples_per_class = max_samples_per_class
        # Rank r = k + 1 runs from 1..num_classes; the head class keeps the full budget.
        counts = [max_samples_per_class // rank for rank in range(1, num_classes + 1)]
        self.samples_per_class = counts
        self.total_samples = sum(counts)
        # One constant-label block per class, concatenated in class order.
        label_blocks = [
            torch.full((count,), label, dtype=torch.long)
            for label, count in enumerate(counts)
        ]
        self.targets = torch.cat(label_blocks)

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        """Return ``(idx, label)``; the index itself stands in for the input data."""
        label = self.targets[idx]
        return idx, label
# Create the long-tailed demo dataset and print a short summary of it.
batch_size = 64
dataset = LongTailDataset()
# len(dataset) is the full sample count (the sum of samples_per_class).
print(f'The total number of samples: {len(dataset)}')
print(f'The number of samples per class: {dataset.samples_per_class}')
# Peek at the middle sample to show the (index, label) pair that __getitem__ returns.
print(f'The {len(dataset) // 2} th samples of the dataset: {dataset[len(dataset) // 2]}')
Output:
The total number of samples: 187
The number of samples per class: [100, 50, 33, 25, 20, 16, 14, 12, 11, 10, 9, 8, 7, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4]
The 187 th samples of the dataset: (187, tensor(3))
定义一个测试sample一个batch的函数
def test_loader_in_one_batch(test_dataloader: DataLoader, inf: str):
print(inf)
for (_, target) in test_dataloader:
cls_idx, cls_counts = np.unique(target.numpy(), return_counts=True)
cls_idx = [int(i) for i in cls_idx]
cls_counts = [int(i) for i in cls_counts]
print(f'Class indices: {cls_idx}')
print(f'Class counts: {cls_counts}')
break # just show one batch
print('-' * 20)
采样介绍
每个类的采样概率可抽象为:\(p_j=\frac{n_j^q}{\sum_{i=1}^Cn_i^q}\),
- \(p_j\)表示从j类采样数据的概率;
- \(C\)表示类别数量;\(n_j\)表示j类样本数;
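- \(q\in[0,1]\):\(q=1\)对应实例平衡采样,\(q=0\)对应类平衡采样(\(q=\frac{1}{2}\)时称为平方根采样,square-root sampling),下文分别介绍\(q=1\)和\(q=0\)这两种情况。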
均匀采样
\(q=1\),实例平衡采样(Instance-balanced sampling)(也称uniform sampling),最常见的数据采样方式,每个训练样本被选择的概率相等,均为\(\frac{1}{N}\),其中\(N=\sum_{i=1}^Cn_i\)为样本总数。因此j类被采样的概率与该类的样本数\(n_j\)成正比,即\(p^{\mathbf{IB}}_j=\frac{n_j}{\sum_{i=1}^Cn_i}\)。
# Default loader: shuffle=True gives every sample the same chance of being drawn
# (instance-balanced sampling).
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader_in_one_batch(dataloader, inf='Instance-balanced sampling(Default):')
Output:
Instance-balanced sampling(Default):
Class indices: [0, 1, 2, 3, 4, 5, 6, 10, 11, 13, 15, 17, 18, 19, 20, 21, 23]
Class counts: [13, 10, 4, 4, 4, 6, 5, 3, 1, 4, 3, 1, 2, 1, 1, 1, 1]
--------------------
类平衡采样
实例平衡采样在不平衡的数据集中往往表现不佳,类平衡采样(Class-balanced sampling)让所有的类有相同的被采样概率(\(q=0\)):\(p^{\mathbf{CB}}_j=\frac{1}{C}\)。采样可分为两个阶段:1. 从所有类别中均匀地选择一个类;2. 在该类的实例中均匀地采样一个实例。
这里具体实现使用很多论文都在使用的 Class Aware Sampler:按打乱后的类别顺序循环轮询各类,再从每个类自己的(循环、打乱的)样本列表中取样。少数类因此被循环过采样,每个类被采样的次数相等,batch内各类别的样本数也近似相等。
import random
from torch.utils.data.sampler import Sampler
import numpy as np
class RandomCycleIter:
    """Endless iterator over a copy of ``data`` that reshuffles at the start of every pass.

    With ``test_mode=True`` the shuffle is skipped and items cycle in their
    original order. An empty ``data`` raises IndexError on the first ``next``.
    """

    def __init__(self, data, test_mode=False):
        self.data_list = list(data)
        self.length = len(self.data_list)
        # Start one before the end so the very first __next__ wraps to 0 and shuffles.
        self.i = self.length - 1
        self.test_mode = test_mode

    def __iter__(self):
        return self

    def __next__(self):
        position = self.i + 1
        if position == self.length:
            # A full pass is complete: restart, reshuffling in place (uses `random`).
            position = 0
            if not self.test_mode:
                random.shuffle(self.data_list)
        self.i = position
        return self.data_list[position]
def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1):
    """Yield ``n`` sample indices, drawn class by class.

    Repeatedly takes the next class position from ``cls_iter`` and pulls
    ``num_samples_cls`` consecutive indices from that class's iterator in
    ``data_iter_list``, yielding them one at a time. A group is drawn only when
    its first element is needed, so the final group may be cut short once
    ``n`` indices have been produced.
    """
    emitted = 0
    while emitted < n:
        class_samples = data_iter_list[next(cls_iter)]
        # ``[it] * k`` holds k references to the SAME iterator (not k copies), so
        # zip pulls k consecutive items from it and returns them as one tuple.
        group = next(zip(*[class_samples] * num_samples_cls))
        for sample_idx in group:
            yield sample_idx
            emitted += 1
            if emitted >= n:
                return
class ClassAwareSampler(Sampler):
    """Sampler that visits classes in turn so every class is drawn equally often.

    Classes are visited in a shuffled round-robin order (reshuffled after each
    full pass over the classes). Each visit yields ``num_samples_cls``
    consecutive indices from that class's own shuffled cycle, so minority
    classes are oversampled.

    Args:
        targets: 1-D sequence of integer class labels, one per dataset sample
            (list, NumPy array or CPU tensor). Labels need not be contiguous.
        num_samples_cls: number of indices drawn from a class per visit.
    """

    def __init__(self, targets, num_samples_cls=1):
        super().__init__()
        labels = np.asarray(targets)
        # Group sample indices by class, iterating the sorted distinct labels.
        # Indexing a list by the raw label would raise IndexError for
        # non-contiguous labels such as {0, 2}; position-based grouping does not.
        cls_data_list = [np.flatnonzero(labels == c).tolist() for c in np.unique(labels)]
        num_classes = len(cls_data_list)
        # Endless shuffled cycle over class positions 0..num_classes-1.
        self.class_iter = RandomCycleIter(range(num_classes))
        # One endless shuffled cycle per class, yielding that class's sample indices.
        self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list]
        # Samples per epoch = size of the largest class * number of classes.
        self.num_samples = max(len(x) for x in cls_data_list) * num_classes
        self.num_samples_cls = num_samples_cls

    def __iter__(self):
        # A fresh generator per epoch. The class/sample cyclers are stored on self,
        # so each epoch continues from where the previous one stopped.
        return class_aware_sample_generator(self.class_iter, self.data_iter_list,
                                            self.num_samples, self.num_samples_cls)

    def __len__(self):
        return self.num_samples
# Class-aware sampling: classes are visited in a shuffled round-robin,
# one sample per class visit (default num_samples_cls=1).
class_aware_sampler = ClassAwareSampler(dataset.targets)
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=class_aware_sampler)
test_loader_in_one_batch(dataloader, inf='Class-aware sampling:')
Output:
Class-aware sampling:
Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
Class counts: [2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 2, 3, 3, 2, 2, 3, 3, 3, 3, 3, 2, 2, 2, 3, 3]
--------------------
类平衡采样的另一种写法(通过调整采样器的类的权重)
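最早是把每个类的权重设为样本数的倒数:\(w_j=\frac{1}{n_j}\)(类内每个样本都使用该权重,因此每个类被采样的总概率相同)。[3]提出effective number \(E_{n_j}=\frac{1-\beta^{n_j}}{1-\beta}\)(\(\beta\in[0,1)\)为超参数,下文代码取\(\beta=\frac{N-1}{N}\),N为样本总数),将每个类的权重调整为其倒数:\(w_j=\frac{1}{E_{n_j}}=\frac{1-\beta}{1-\beta^{n_j}}\),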
并用这个权重调整损失。[4]把这个权重用于采样权重,这里用PyTorch提供的WeightedRandomSampler
实现:第一个参数表示每个样本(不是类)的权重,第二个参数表示采样的样本数,第三个参数表示是否有放回采样。
from torch.utils.data.sampler import WeightedRandomSampler
def imbalance_sampler(targets, mode='inverse', beta=None):
    """Build a ``WeightedRandomSampler`` that re-balances classes via per-sample weights.

    Every sample receives the weight ``w_j`` of its class ``j``, so class ``j``'s
    total sampling mass is ``n_j * w_j``.

    Args:
        targets: 1-D array-like of non-negative integer labels (NumPy array,
            list, or CPU tensor).
        mode: ``'inverse'`` -> ``w_j = 1 / n_j`` (classes equally likely in
            expectation); ``'effective'`` -> ``w_j = (1 - beta) / (1 - beta ** n_j)``,
            the inverse effective number of samples (Cui et al., 2019).
        beta: hyper-parameter for ``'effective'`` mode; defaults to
            ``(N - 1) / N`` with ``N = len(targets)``. Ignored for ``'inverse'``.

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(targets)`` indices with replacement.

    Raises:
        ValueError: if ``mode`` is not ``'inverse'`` or ``'effective'``.
    """
    targets = np.asarray(targets)
    cls_counts = np.bincount(targets)
    # Labels absent from `targets` have count 0 and divide by zero. Their weights
    # are never looked up by cls_weights[targets], so the warning is harmless.
    with np.errstate(divide='ignore'):
        if mode == 'inverse':
            cls_weights = 1. / cls_counts
        elif mode == 'effective':
            if beta is None:
                beta = (len(targets) - 1) / len(targets)
            cls_weights = (1.0 - beta) / (1.0 - np.power(beta, cls_counts))
        else:
            # Raise explicitly: an `assert` would be stripped under `python -O`.
            raise ValueError(f"Unknown mode {mode!r}; expected 'inverse' or 'effective'.")
    # Per-sample weights: each sample inherits its class weight.
    return WeightedRandomSampler(cls_weights[targets], len(targets), replacement=True)
# Compare both class-weighting schemes on one batch each.
modes = ['inverse', 'effective']
for mode in modes:
    header = f'{mode.capitalize()}:'
    sampler = imbalance_sampler(dataset.targets.numpy(), mode)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    test_loader_in_one_batch(dataloader, inf=header)
Output:
Inverse:
Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24]
Class counts: [1, 3, 1, 2, 5, 4, 2, 1, 3, 3, 1, 3, 3, 3, 6, 3, 3, 1, 3, 5, 1, 3, 1, 3]
--------------------
Effective:
Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 23, 24]
Class counts: [2, 3, 2, 2, 5, 1, 1, 2, 3, 3, 7, 2, 1, 3, 3, 4, 1, 4, 3, 2, 1, 3, 6]
--------------------
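实际中,提到类平衡采样(Class-Balanced Re-Sampling)时,以上两种实现方式都有可能,注意上下文描述和参考文献的引用。两者并不完全等价:Class Aware Sampler按类别循环取样,严格保证每个类被采样的次数相同;WeightedRandomSampler是有放回的随机采样,只在期望意义上平衡。对比上面的输出可以看到,后者单个batch内各类数量波动更大,甚至会缺少某些类。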
混合采样策略
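最早的混合采样是在 \(0\le epoch\le t_0\)时采用Instance-balanced采样,\(t_0< epoch\le T\)时采用Class-balanced采样,这需要设置合适的超参数\(t_0\)。在[1]中,作者提出了soft版本的混合采样策略:Progressively-balanced sampling。随着epoch的增加,每个类的采样概率(权重)\(p_j\)从实例平衡逐渐过渡到类平衡:\(p^{\mathbf{PB}}_j(t)=\left(1-\frac{t}{T}\right)p^{\mathbf{IB}}_j+\frac{t}{T}p^{\mathbf{CB}}_j\),
其中t表示当前epoch,T表示总epoch数。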
运行环境
# Name Version Build Channel
pytorch 2.3.1 py3.12_cuda12.1_cudnn8_0 pytorch
参考文献
- Kang, Bingyi, et al. "Decoupling Representation and Classifier for Long-Tailed Recognition." International Conference on Learning Representations. 2020.
- torch.utils.data.WeightedRandomSampler
- Cui, Yin, et al. "Class-balanced loss based on effective number of samples." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
- Cao, Kaidi, et al. "Learning imbalanced datasets with label-distribution-aware margin loss." Advances in neural information processing systems 32 (2019).
- Shi, Jiang-Xin, et al. "How Re-sampling Helps for Long-Tail Learning?." Advances in Neural Information Processing Systems 36 (2023).