2023-2-18 Oversampling Implementation

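The class distribution is imbalanced (some classes have far more samples than others), so the model overfits to the majority classes. Oversampling, i.e. drawing minority-class samples more often during training, helps mitigate this.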
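To make the idea concrete, here is a minimal pure-Python sketch of inverse-frequency oversampling: each sample gets the weight 1 / (number of samples in its class), and indices are drawn with replacement in proportion to those weights, so every class is expected to be drawn equally often. As far as I understand, torchsampler's ImbalancedDatasetSampler follows the same weighting idea, but the helper names below (`inverse_frequency_weights`, `oversample_indices`) are hypothetical and this is not the library's code.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each sample by 1 / (size of its class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def oversample_indices(labels, num_samples=None, seed=0):
    """Draw sample indices with replacement, proportional to the weights.

    By default one "epoch" draws as many indices as there are samples.
    """
    weights = inverse_frequency_weights(labels)
    if num_samples is None:
        num_samples = len(labels)
    rng = random.Random(seed)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


if __name__ == "__main__":
    # Toy imbalanced labels: 90 samples of class 0, only 10 of class 1
    labels = [0] * 90 + [1] * 10
    idx = oversample_indices(labels, num_samples=10000)
    # Each class should now account for roughly half of the draws
    print(Counter(labels[i] for i in idx))
```

Each class's weights sum to 1 (90 × 1/90 and 10 × 1/10), which is why the classes come out balanced; the price is that the 10 minority samples are repeated many times per epoch.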

Solution

1. Install torchsampler

pip install torchsampler

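2. Implement get_labels in the dataset

ImbalancedDatasetSampler needs the label of every sample to compute its weights; for a custom Dataset it obtains them by calling the dataset's get_labels() method.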

def get_labels(self):
	# Collect the label of every sample, in the same index order used by __getitem__
	return labels
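For illustration, a minimal dataset with get_labels might look like the sketch below. It is hypothetical: it stores `(path, label)` pairs in a list (an assumed layout, not AffectNet's real one) and returns the path instead of loading an image, to stay dependency-free. In real code the class should subclass `torch.utils.data.Dataset`, which is how the sampler recognises it as a custom dataset. The essential point is that `get_labels()` returns one label per sample, in the same order as the indices `__getitem__` accepts.

```python
class ToyImageDataset:
    """Hypothetical dataset storing (path, label) pairs.

    Real code would subclass torch.utils.data.Dataset and load the image
    in __getitem__; here the path itself is returned.
    """

    def __init__(self, samples, transform=None):
        self.samples = list(samples)  # list of (path, label) tuples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        item = self.transform(path) if self.transform else path
        return item, label

    def get_labels(self):
        # One label per sample, aligned with the indices of __getitem__
        return [label for _, label in self.samples]


ds = ToyImageDataset([("a.jpg", 0), ("b.jpg", 2), ("c.jpg", 0)])
print(ds.get_labels())  # [0, 2, 0]
```

Building the list from data already held in memory keeps get_labels cheap; calling `__getitem__` for every index instead would load every image just to read its label.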

3. Pass the sampler to the DataLoader

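from torch.utils.data import DataLoader
from torchsampler import ImbalancedDatasetSampler

# The sampler supplies the (rebalanced) index order, so shuffle stays False
train_loader = DataLoader(train_dataset, sampler=ImbalancedDatasetSampler(train_dataset),
                          batch_size=batch_size, shuffle=False, num_workers=n_workers, pin_memory=True)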

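[!attention]
shuffle must be False: DataLoader raises a ValueError when a sampler is passed together with shuffle=True, and the sampler already determines the (random) order of the samples.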
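Why shuffle has to be off: the sampler alone decides which indices are visited, how often, and in what order; DataLoader only groups those indices into batches. The hypothetical `toy_loader` below is a simplified pure-Python stand-in for that index handling (not DataLoader's real implementation) and rejects `sampler` plus `shuffle=True` the same way PyTorch does.

```python
import random


def toy_loader(dataset, batch_size, sampler=None, shuffle=False, seed=0):
    """Simplified stand-in for how DataLoader turns indices into batches."""
    if sampler is not None and shuffle:
        # torch.utils.data.DataLoader rejects this combination as well
        raise ValueError("sampler option is mutually exclusive with shuffle")
    if sampler is None:
        indices = list(range(len(dataset)))
        if shuffle:
            random.Random(seed).shuffle(indices)
    else:
        indices = list(sampler)  # order and repeats come from the sampler
    # Group consecutive indices into batches; the last batch may be smaller
    return [[dataset[i] for i in indices[k:k + batch_size]]
            for k in range(0, len(indices), batch_size)]


data = ["a", "b", "c", "d", "e"]
# An oversampling sampler may repeat minority samples ("e" twice here)
print(toy_loader(data, batch_size=2, sampler=[4, 0, 4, 2]))
# [['e', 'a'], ['e', 'c']]
```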

Full code

from torch.utils.data import DataLoader
from torchvision import transforms
from torchsampler import ImbalancedDatasetSampler

# AffectNet, n_expression, batch_size, n_workers and test_dataset are defined elsewhere in the training script

# Create the data loaders
train_image_transform = transforms.Compose([
        # transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

print('Loading the data')
train_dataset = AffectNet(root_path='dataset/AffectNetdataset/', subset='train', n_expression=n_expression,
                         transform_image_shape=None, transform_image=train_image_transform)

# Oversample the training set only; the test set keeps its natural distribution
train_loader = DataLoader(train_dataset, sampler=ImbalancedDatasetSampler(train_dataset),
                          batch_size=batch_size, shuffle=False, num_workers=n_workers, pin_memory=True)

test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=n_workers, pin_memory=True)

posted @ 2023-02-27 19:27  cyinen