机器学习常见的sampling策略 附PyTorch实现
初始工作
定义一个模拟的长尾数据集
import torch
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader

np.random.seed(0)
random.seed(0)
torch.manual_seed(0)


class LongTailDataset(Dataset):
    """Synthetic long-tailed dataset whose class sizes decay as 1 / (class index + 1).

    Class ``i`` receives ``max_samples_per_class // (i + 1)`` samples, so class 0 is
    the head class and the last classes form the tail. Targets are laid out class by
    class: all class-0 samples first, then class 1, and so on.

    Args:
        num_classes: Number of classes.
        max_samples_per_class: Number of samples of the largest (first) class.
    """

    def __init__(self, num_classes=25, max_samples_per_class=100):
        self.num_classes = num_classes
        self.max_samples_per_class = max_samples_per_class
        # Number of samples per class is inversely proportional to (class index + 1).
        self.samples_per_class = [self.max_samples_per_class // (i + 1) for i in range(self.num_classes)]
        self.total_samples = sum(self.samples_per_class)
        # Label of every sample, concatenated class by class.
        self.targets = torch.cat(
            [torch.full((samples,), i, dtype=torch.long) for i, samples in enumerate(self.samples_per_class)])

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        # For simplicity, the sample "data" is just its own index.
        return idx, self.targets[idx]


# Create dataset
batch_size = 64
dataset = LongTailDataset()
# The total is len(dataset); dividing by 2 here reported only half of the samples.
print(f'The total number of samples: {len(dataset)}')
print(f'The number of samples per class: {dataset.samples_per_class}')
print(f'The {len(dataset) // 2} th samples of the dataset: {dataset[len(dataset) // 2]}')
Output:
The total number of samples: 187 The number of samples per class: [100, 50, 33, 25, 20, 16, 14, 12, 11, 10, 9, 8, 7, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4] The 187 th samples of the dataset: (187, tensor(3))
定义一个测试sample一个batch的函数
def test_loader_in_one_batch(test_dataloader: DataLoader, inf: str): print(inf) for (_, target) in test_dataloader: cls_idx, cls_counts = np.unique(target.numpy(), return_counts=True) cls_idx = [int(i) for i in cls_idx] cls_counts = [int(i) for i in cls_counts] print(f'Class indices: {cls_idx}') print(f'Class counts: {cls_counts}') break # just show one batch print('-' * 20)
采样介绍
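每个类的采样概率可抽象为:
$$p_j = \frac{n_j^{q}}{\sum_{i=1}^{C} n_i^{q}}$$
其中 $p_j$ 表示从 j 类采样数据的概率; $C$ 表示类别数量; $n_j$ 表示 j 类样本数; $q \in [0, 1]$ 决定采样策略: $q=1$ 为实例平衡采样, $q=0$ 为类平衡采样, $q=\frac{1}{2}$ 为平方根采样。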
实例平衡采样(均匀采样, $q=1$)
# Default DataLoader with shuffle=True: each sample is drawn once per epoch in random
# order, so a batch's class mix follows the long-tailed dataset distribution.
dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=batch_size,
)
test_loader_in_one_batch(
    dataloader,
    inf='Instance-balanced sampling(Default):',
)
Output:
Instance-balanced sampling(Default): Class indices: [0, 1, 2, 3, 4, 5, 6, 10, 11, 13, 15, 17, 18, 19, 20, 21, 23] Class counts: [13, 10, 4, 4, 4, 6, 5, 3, 1, 4, 3, 1, 2, 1, 1, 1, 1] --------------------
类平衡采样
实例平衡采样在不平衡的数据集中往往表现不佳,类平衡采样(Class-balanced sampling)让所有的类有相同的被采样概率($q=0$, 即 $p_j = \frac{1}{C}$)。
这里具体实现使用很多论文都在使用的 Class Aware Sampler,通过按类别轮流、类内循环的过采样,使得每个类别被采样的次数大致相等(从下面的输出可见,一个batch内每个类别有2~3个样本)。
import random
from torch.utils.data.sampler import Sampler
import numpy as np


class RandomCycleIter:
    """Endless iterator over ``data`` that reshuffles at the start of every pass.

    Args:
        data: Finite iterable to cycle over; it is copied into a list.
        test_mode: If True, never shuffle, so items repeat in their original order.

    Note: an empty ``data`` raises IndexError on the first ``next()``.
    """

    def __init__(self, data, test_mode=False):
        self.data_list = list(data)
        self.length = len(self.data_list)
        # Start on the last position so the first __next__ wraps to 0 (and shuffles).
        self.i = self.length - 1
        self.test_mode = test_mode

    def __iter__(self):
        return self

    def __next__(self):
        self.i += 1
        if self.i == self.length:
            # Completed a full pass: restart and reshuffle (unless in test mode).
            self.i = 0
            if not self.test_mode:
                random.shuffle(self.data_list)
        return self.data_list[self.i]


def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1):
    """Yield ``n`` sample indices, taking ``num_samples_cls`` consecutive indices per class visit.

    Args:
        cls_iter: Iterator yielding class positions (indices into ``data_iter_list``).
        data_iter_list: One iterator per class, yielding that class's sample indices.
        n: Total number of indices to yield.
        num_samples_cls: Number of indices drawn from a class before moving to the next class.
    """
    i = 0
    j = 0
    while i < n:
        if j >= num_samples_cls:
            j = 0
        if j == 0:
            # Pick the next class, then pull `num_samples_cls` successive indices from
            # that class's iterator. This is exactly what `next(zip(*[it] * k))` does:
            # the same iterator is repeated k times, so zip advances it k times.
            class_samples = data_iter_list[next(cls_iter)]
            temp_tuple = tuple(next(class_samples) for _ in range(num_samples_cls))
        yield temp_tuple[j]
        i += 1
        j += 1


class ClassAwareSampler(Sampler):
    """Class-balanced sampler with cyclic over-sampling.

    Classes are visited round-robin (reshuffled after each round) and
    ``num_samples_cls`` indices are drawn from every visited class, cycling through
    the class's samples. One epoch has ``max class size * number of classes`` indices,
    so every class contributes roughly as many indices as the largest class has samples.

    Args:
        targets: Label of every dataset item (sequence, NumPy array or CPU tensor).
            Labels do not need to be contiguous (e.g. ``[0, 2, 5]`` works).
        num_samples_cls: Indices drawn per class visit; must be >= 1.

    Raises:
        ValueError: If ``num_samples_cls < 1`` or ``targets`` is empty.
    """

    def __init__(self, targets, num_samples_cls=1):
        super().__init__()
        if num_samples_cls < 1:
            raise ValueError(f'num_samples_cls must be >= 1, got {num_samples_cls}')
        labels = np.asarray(targets)
        classes = np.unique(labels)
        if classes.size == 0:
            raise ValueError('targets must contain at least one label')
        num_classes = len(classes)
        # Map every label value to a dense position 0..num_classes-1; indexing by the raw
        # label would overflow cls_data_list when labels are not contiguous.
        label_to_pos = {label: pos for pos, label in enumerate(classes.tolist())}
        # Endless iterator over class positions, reshuffled after each full pass.
        self.class_iter = RandomCycleIter(range(num_classes))
        # One list of sample indices per class.
        cls_data_list = [[] for _ in range(num_classes)]
        for i, label in enumerate(labels.tolist()):
            cls_data_list[label_to_pos[label]].append(i)
        # Each class gets its own endless, reshuffling iterator of sample indices.
        self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list]
        # Epoch length = size of the largest class * number of classes.
        self.num_samples = max(len(x) for x in cls_data_list) * num_classes
        self.num_samples_cls = num_samples_cls

    def __iter__(self):
        return class_aware_sample_generator(self.class_iter, self.data_iter_list,
                                            self.num_samples, self.num_samples_cls)

    def __len__(self):
        return self.num_samples


dataloader = DataLoader(dataset, batch_size=batch_size, sampler=ClassAwareSampler(dataset.targets))
test_loader_in_one_batch(dataloader, inf='Class-aware sampling:')
Output:
Class-aware sampling: Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] Class counts: [2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 2, 3, 3, 2, 2, 3, 3, 3, 3, 3, 2, 2, 2, 3, 3] --------------------
类平衡采样的另一种写法(通过调整采样器的类的权重)
最早是把每个类的权重(采样概率)设为样本数的倒数:
$$w_j = \frac{1}{n_j}$$
并用这个权重调整损失。[4]把这个权重用于采样权重,这里用PyTorch提供的WeightedRandomSampler
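实现:第一个参数表示每个样本(不是类)的权重,第二个参数表示采样的样本数,第三个参数表示是否有放回采样。代码中的 `effective` 模式则采用[3]提出的有效样本数(effective number of samples)的倒数作为类权重:
$$w_j = \frac{1-\beta}{1-\beta^{n_j}}$$
其中取 $\beta = \frac{N-1}{N}$, $N$ 为样本总数。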
from torch.utils.data.sampler import WeightedRandomSampler


def imbalance_sampler(targets, mode='inverse', beta=None):
    """Build a WeightedRandomSampler that re-balances classes through per-sample weights.

    Every sample gets the weight of its class ``c`` (``n_c`` = samples of class ``c``):

    * ``'inverse'``:   ``w_c = 1 / n_c`` (inverse class frequency).
    * ``'effective'``: ``w_c = (1 - beta) / (1 - beta ** n_c)`` (inverse effective
      number of samples, Cui et al. 2019). ``beta`` defaults to ``(N - 1) / N`` with
      ``N = len(targets)``.

    Args:
        targets: Non-negative integer class label of every sample (array-like or CPU tensor).
        mode: ``'inverse'`` or ``'effective'``.
        beta: Optional beta for ``'effective'`` mode, in ``[0, 1)``; ignored for ``'inverse'``.

    Returns:
        A WeightedRandomSampler drawing ``len(targets)`` indices with replacement.

    Raises:
        ValueError: If ``mode`` is unknown or ``beta`` is outside ``[0, 1)``.
    """
    targets = np.asarray(targets)
    cls_counts = np.bincount(targets)
    # Labels absent from `targets` have count 0 and receive an infinite weight; that
    # weight is never indexed below, so the divide-by-zero warning is just noise.
    with np.errstate(divide='ignore'):
        if mode == 'inverse':
            cls_weights = 1. / cls_counts
        elif mode == 'effective':
            if beta is None:
                beta = (len(targets) - 1) / len(targets)
            elif not 0 <= beta < 1:
                raise ValueError(f'beta must be in [0, 1), got {beta}')
            cls_weights = (1.0 - beta) / (1.0 - np.power(beta, cls_counts))
        else:
            raise ValueError(f"mode must be 'inverse' or 'effective', got {mode!r}")
    # Per-sample weight = weight of that sample's class.
    return WeightedRandomSampler(cls_weights[targets], len(targets), replacement=True)


modes = ['inverse', 'effective']
for mode in modes:
    sampler = imbalance_sampler(dataset.targets.numpy(), mode)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    test_loader_in_one_batch(dataloader, inf=f'{mode.capitalize()}:')
Output:
Inverse: Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24] Class counts: [1, 3, 1, 2, 5, 4, 2, 1, 3, 3, 1, 3, 3, 3, 6, 3, 3, 1, 3, 5, 1, 3, 1, 3] -------------------- Effective: Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 23, 24] Class counts: [2, 3, 2, 2, 5, 1, 1, 2, 3, 3, 7, 2, 1, 3, 3, 4, 1, 4, 3, 2, 1, 3, 6] --------------------
实际中,提到类平衡采样Class-Balanced Re-Sampling,两种实现方式都有可能,注意上下文描述和参考文献的引用。
混合采样策略
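最早的混合采样是在[1]中提出的渐进平衡采样(Progressively-balanced sampling),其采样概率随训练进程从实例平衡逐渐过渡到类平衡:
$$p_j^{PB}(t) = \left(1 - \frac{t}{T}\right) p_j^{IB} + \frac{t}{T}\, p_j^{CB}$$
其中 t 表示当前epoch,T 表示总epoch数。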
运行环境
# Name Version Build Channel pytorch 2.3.1 py3.12_cuda12.1_cudnn8_0 pytorch
参考文献
- Kang, Bingyi, et al. "Decoupling Representation and Classifier for Long-Tailed Recognition." International Conference on Learning Representations. 2020.
- torch.utils.data.WeightedRandomSampler
- Cui, Yin, et al. "Class-balanced loss based on effective number of samples." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
- Cao, Kaidi, et al. "Learning imbalanced datasets with label-distribution-aware margin loss." Advances in neural information processing systems 32 (2019).
- Shi, Jiang-Xin, et al. "How Re-sampling Helps for Long-Tail Learning?." Advances in Neural Information Processing Systems 36 (2023).
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人