基于cifar数据集合成含开集、闭集噪声的数据集

前言

噪声标签学习下的一个任务是:训练集上存在开集噪声和闭集噪声;然后在测试集上对闭集样本进行分类。

训练集中被加入的开集样本,会被均匀得打上闭集样本的标签充当开集噪声;而闭集噪声的设置与一般的噪声标签学习一致,分为对称噪声:随机将闭集样本的标签替换为其他类别;和非对称噪声:将闭集样本的标签替换为特定的类别。

论文实验中,常用cifar数据集模拟这类任务。目前已知有两类方法:

  • 第一类基于cifar100,将100个类的一部分,通常是20个类作为开集样本,将它们标签替换了前80个类作为开集噪声;然后对于后续80个类,选择部分样本设置为对称/非对称闭集噪声。CVPR2022的PNP: Robust Learning From Noisy Labels by Probabilistic Noise Prediction提供的代码中,使用了这种方法。但是,如果要考虑非对称噪声,在cifar10上就很难实现,cifar10的类的顺序不像cifar100那样有规律,不好设置闭集噪声。

  • 第二类方法适用cifar10和cifar100,保持原始数据集的样本数不变,使用额外的数据集(通常是imagenet32、places365)代替部分样本作为开集噪声,对于剩下的非开集噪声样本再设置闭集噪声。ECCV2022的Embedding contrastive unsupervised features to cluster in-and out-of-distribution noise in corrupted image datasets提供的代码使用了这种方式。

places365可以使用torchvision.datasets.Places365下载,由于训练集较大,通常是用它的验证集作为辅助数据集。

imagenet32是imagnet的32x32版本,同样是1k类,但是类的具体含义的顺序与imagenet不同,imagenet32类的具体含义可见此处。image32下载地址在对应论文A downsampled variant of imagenet as an alternative to the cifar datasets提供的链接

尝试构造一下

使用第二种方法构造含开集、闭集噪声数据集,开集噪声率\(r_{ood}=0.2\),闭集噪声率\(r_{id}=0.2\);辅助数据集使用imagenet32,基于cifar构造含开集闭集噪声的训练集。

设计imagenet32数据集

import os
import pickle
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

_train_list = ['train_data_batch_1',
               'train_data_batch_2',
               'train_data_batch_3',
               'train_data_batch_4',
               'train_data_batch_5',
               'train_data_batch_6',
               'train_data_batch_7',
               'train_data_batch_8',
               'train_data_batch_9',
               'train_data_batch_10']
_val_list = ['val_data']


def get_dataset(transform_train, transform_test):
    # prepare datasets

    # Train set
    train = Imagenet32(train=True, transform=transform_train)  # Load all 1000 classes in memory

    # Test set
    test = Imagenet32(train=False, transform=transform_test)  # Load all 1000 test classes in memory

    return train, test


class Imagenet32(Dataset):
    def __init__(self, root='~/data/imagenet32', train=True, transform=None):
        if root[0] == '~':
            root = os.path.expanduser(root)
        self.transform = transform
        size = 32
        # Now load the picked numpy arrays

        if train:
            data, labels = [], []

            for f in _train_list:
                file = os.path.join(root, f)

                with open(file, 'rb') as fo:
                    entry = pickle.load(fo, encoding='latin1')
                    data.append(entry['data'])
                    labels += entry['labels']
            data = np.concatenate(data)

        else:
            f = _val_list[0]
            file = os.path.join(root, f)
            with open(file, 'rb') as fo:
                entry = pickle.load(fo, encoding='latin1')
                data = entry['data']
                labels = entry['labels']

        data = data.reshape((-1, 3, size, size))
        self.data = data.transpose((0, 2, 3, 1))  # Convert to HWC
        labels = np.array(labels) - 1
        self.labels = labels.tolist()

    def __getitem__(self, index):

        img, target = self.data[index], self.labels[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, target, index

    def __len__(self):
        return len(self.data)

目录结构:

imagenet32
├─ train_data_batch_1
├─ train_data_batch_10
├─ train_data_batch_2
├─ train_data_batch_3
├─ train_data_batch_4
├─ train_data_batch_5
├─ train_data_batch_6
├─ train_data_batch_7
├─ train_data_batch_8
├─ train_data_batch_9
└─ val_data

设计cifar数据集

import torchvision
import numpy as np
from dataset.imagenet32 import Imagenet32


class CIFAR10(torchvision.datasets.CIFAR10):
    nb_classes = 10

    def __init__(self, root='~/data', train=True, transform=None,
                 r_ood=0.2, r_id=0.2, seed=0, corruption='imagenet', ):

        super().__init__(root, train=train, transform=transform)
        if train is False:
            return
        np.random.seed(seed)
        if r_ood > 0.:
            ids_ood = [i for i in range(len(self.targets)) if np.random.random() < r_ood]
            if corruption == 'imagenet':
                imagenet32 = Imagenet32(root='~/data/imagenet32', train=True)
                img_ood = imagenet32.data[np.random.permutation(range(len(imagenet32)))[:len(ids_ood)]]
            else:
                raise ValueError(f'Unknown corruption: {corruption}')
            self.ids_ood = ids_ood
            self.data[ids_ood] = img_ood

        if r_id > 0.:
            ids_not_ood = [i for i in range(len(self.targets)) if i not in ids_ood]
            ids_id = [i for i in ids_not_ood if np.random.random() < (r_id / (1 - r_ood))]
            for i, t in enumerate(self.targets):
                if i in ids_id:
                    self.targets[i] = int(np.random.random() * self.nb_classes)
            self.ids_id = ids_id


class CIFAR100(CIFAR10):
    base_folder = "cifar-100-python"
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    train_list = [
        ["train", "16019d7e3df5f24257cddd939b257f8d"],
    ]

    test_list = [
        ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
    ]
    meta = {
        "filename": "meta",
        "key": "fine_label_names",
        "md5": "7973b15100ade9c7d40fb424638fde48",
    }

    nb_classes = 100

    def __init__(self, root='~/data', train=True, transform=None,
                 r_ood=0.2, r_id=0.2, seed=0, corruption='imagenet'):
        super().__init__(root=root, train=train, transform=transform, r_ood=r_ood, r_id=r_id, seed=seed,
                         corruption=corruption)

查看统计结果

import pandas as pd
import altair as alt
from dataset.cifar import CIFAR10, CIFAR100

# Initialize CIFAR10 dataset
cifar10 = CIFAR10()
cifar100 = CIFAR100()


def statistics_samples(dataset):
    ids_ood = dataset.ids_ood
    ids_id = dataset.ids_id

    # Collect statistics
    statistics = []
    for i in range(dataset.nb_classes):
        statistics.append({
            'class': i,
            'id': 0,
            'ood': 0,
            'clear': 0
        })

    for i, t in enumerate(dataset.targets):
        if i in ids_ood:
            statistics[t]['ood'] += 1
        elif i in ids_id:
            statistics[t]['id'] += 1
        else:
            statistics[t]['clear'] += 1

    df = pd.DataFrame(statistics)

    # Melt the DataFrame for Altair
    df_melt = df.melt(id_vars='class', var_name='type', value_name='count')

    # Create the bar chart
    chart = alt.Chart(df_melt).mark_bar().encode(
        x=alt.X('class:O', title='Classes'),
        y=alt.Y('count:Q', title='Sample Count'),
        color='type:N'
    )
    return chart


chart1 = statistics_samples(cifar10)
chart2 = statistics_samples(cifar100)
chart1 = chart1.properties(
    title='cifar10',
    width=100,  # Adjust width to fit both charts side by side
    height=400
)
chart2 = chart2.properties(
    title='cifar100',
    width=800,
    height=400
)
combined_chart = alt.hconcat(chart1, chart2).configure_axis(
    labelFontSize=12,
    titleFontSize=14
).configure_legend(
    titleFontSize=14,
    labelFontSize=12
)
combined_chart

运行环境

# Name                    Version                   Build  Channel
altair                    5.3.0                    pypi_0    pypi
pytorch                   2.3.1           py3.12_cuda12.1_cudnn8_0    pytorch
pandas                    2.2.2                    pypi_0    pypi
posted @ 2024-06-29 11:45  zh-jp  阅读(86)  评论(0编辑  收藏  举报