2023-2-18 Oversampling Implementation
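The classes are unevenly distributed, which leads to overfitting. Oversampling, which draws samples from the minority classes more often during training, is an effective way to mitigate this.
Solution
1. Install torchsampler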
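As far as I can tell from the torchsampler source, ImbalancedDatasetSampler gives every sample a weight of 1 / (number of samples in its class) and then draws indices with replacement, so each class ends up being picked with roughly equal probability. A minimal pure-Python sketch of that idea (the function names and the 9:1 label list are hypothetical, not part of torchsampler):

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each sample by 1 / (size of its class), so every class has total weight 1."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def oversample_indices(labels, num_samples=None, seed=0):
    """Draw sample indices with replacement, proportionally to their weights."""
    rng = random.Random(seed)
    weights = inverse_frequency_weights(labels)
    k = len(labels) if num_samples is None else num_samples
    return rng.choices(range(len(labels)), weights=weights, k=k)


labels = [0] * 900 + [1] * 100  # hypothetical 9:1 imbalance
drawn = oversample_indices(labels)
print(Counter(labels[i] for i in drawn))  # roughly equal counts for class 0 and class 1
```

With 1000 draws split roughly 500/500, each of the 100 minority samples is repeated about 5 times per epoch, while many majority samples are skipped.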
pip install torchsampler
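2. Implement get_labels in the dataset
def get_labels(self):
    # Some processing to collect the labels:
    # one class label per sample, in the same order as the indices used by __getitem__
    return labels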
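A minimal sketch of a Dataset that implements get_labels, using a hypothetical toy dataset (ToyImbalancedDataset, 90 samples of class 0 and 10 of class 1) in place of AffectNet:

```python
import torch
from torch.utils.data import Dataset


class ToyImbalancedDataset(Dataset):
    """Hypothetical stand-in for AffectNet: 90 samples of class 0, 10 of class 1."""

    def __init__(self):
        self.labels = [0] * 90 + [1] * 10
        self.features = torch.randn(len(self.labels), 4)  # dummy features

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

    def get_labels(self):
        # One label per sample, in index order; read from stored labels
        # instead of calling __getitem__, so no sample has to be loaded
        return self.labels
```

Returning the stored label list directly matters for image datasets: if get_labels went through __getitem__, building the sampler would load and transform every image once just to read its label.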
3. Use it in the DataLoader
from torch.utils.data import DataLoader
from torchsampler import ImbalancedDatasetSampler

train_loader = DataLoader(train_dataset, sampler=ImbalancedDatasetSampler(train_dataset),
                          batch_size=batch_size, shuffle=False, num_workers=n_workers, pin_memory=True)
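[!attention]
shuffle must be False: the sampler already determines the (random) order of the samples, and DataLoader raises a ValueError if a sampler is passed together with shuffle=True.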
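A quick check of that restriction, sketched with a toy TensorDataset and torch's built-in RandomSampler standing in for ImbalancedDatasetSampler (any sampler triggers the same check; make_loader and shuffle_conflicts_with_sampler are hypothetical helpers):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset


def make_loader(shuffle):
    """Build a loader over 10 toy samples with an explicit sampler."""
    dataset = TensorDataset(torch.arange(10))
    return DataLoader(dataset, sampler=RandomSampler(dataset), shuffle=shuffle, batch_size=5)


def shuffle_conflicts_with_sampler():
    """Return True if DataLoader rejects sampler + shuffle=True."""
    try:
        make_loader(shuffle=True)
    except ValueError as err:
        print(f"ValueError: {err}")
        return True
    return False


if __name__ == "__main__":
    print(shuffle_conflicts_with_sampler())
    # shuffle=False is accepted; the sampler alone randomizes the order
    for (batch,) in make_loader(shuffle=False):
        print(batch.tolist())
```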
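Full code
from torch.utils.data import DataLoader
from torchvision import transforms
from torchsampler import ImbalancedDatasetSampler

# AffectNet is the project's own Dataset class (it implements get_labels);
# n_expression, batch_size, n_workers and test_dataset are defined elsewhere and not shown in this note

# Create the data loaders
train_image_transform = transforms.Compose([
    # transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
print('Loading the data')
train_dataset = AffectNet(root_path='dataset/AffectNetdataset/', subset='train', n_expression=n_expression,
                          transform_image_shape=None, transform_image=train_image_transform)
# Oversample only the training set; shuffle must stay False because a sampler is given
train_loader = DataLoader(train_dataset, sampler=ImbalancedDatasetSampler(train_dataset),
                          batch_size=batch_size, shuffle=False, num_workers=n_workers, pin_memory=True)
# No sampler for the test set: it is evaluated with its original class distribution
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=n_workers, pin_memory=True)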