使用cifar100上训练的resnet18进行ood测试
以cifar100作为闭集(closed-set)数据集,使用resnet18模型进行训练,然后在常见的开集(out-of-distribution)数据集上进行OOD检测。使用MSP(Maximum Softmax Probability)作为OOD检测的依据。
开集数据集使用gaussian, rademacher, blob, svhn四种类型。其中gaussian、rademacher、blob是随机生成的噪声图像,svhn是额外引入的真实图像数据集。
输出结果
Error Rate 46.3000
AUROC: 81.9790, AUPR: 85.7377, FPR95: 73.3909
ood type: gaussian
AUROC: 68.1596, AUPR: 92.9277, FPR95: 99.4000
ood type: rademacher
AUROC: 69.9099, AUPR: 93.1788, FPR95: 96.1500
ood type: blob
AUROC: 68.0615, AUPR: 92.7477, FPR95: 97.5500
ood type: svhn
AUROC: 66.9684, AUPR: 91.6508, FPR95: 89.0500
可以看到,仅使用简单的交叉熵损失训练、且不经过其他处理的resnet18,在开集检测上的表现并不好。
闭集数据集上训练一个resnet18
# train.py
import torch
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.datasets.cifar import CIFAR100
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
def get_transform(train=True):
    """Build the preprocessing pipeline applied to each input image.

    Args:
        train: When true, a random 32x32 crop (with 4 pixels of padding) and
            a random horizontal flip are applied before tensor conversion.
            When false, only tensor conversion and normalization are applied.

    Returns:
        A ``transforms.Compose`` object chaining the selected steps.
    """
    # Per-channel normalization statistics.
    # NOTE(review): these look like the commonly used CIFAR-10 values rather
    # than CIFAR-100 statistics — confirm this is intended.
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.2023, 0.1994, 0.2010]

    steps = []
    if train:
        # Augmentation is only used while training.
        steps.append(transforms.RandomCrop(32, padding=4))
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(channel_mean, channel_std))
    return transforms.Compose(steps)
def get_loader(train=True, batch_size=128, num_workers=8):
    """Create a ``DataLoader`` over the CIFAR-100 train or test split.

    The dataset is read from ``~/data``; no download is requested here, so
    the files are expected to be present already.

    Args:
        train: Selects the training split (and training-time transform) when
            true, the test split otherwise. Shuffling is enabled only for
            the training split.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used for loading.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    dataset = CIFAR100(root='~/data', train=train, transform=get_transform(train))
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        pin_memory=True,
    )
def train_model():
    """Train a ResNet-18 on CIFAR-100 and save its weights.

    Runs 100 epochs of SGD (lr 0.1, momentum 0.9, weight decay 1e-4) with the
    learning rate multiplied by 0.1 at epochs 50 and 75. After every epoch
    the model is evaluated on the test split and top-1 accuracy is printed.
    The final state dict is written to ``cifar100_resnet18.pth``.

    Requires a CUDA device: the model and all batches are moved with
    ``.cuda()``.
    """
    loader = get_loader(train=True)
    test_loader = get_loader(train=False)
    model = resnet18(num_classes=100)
    model = model.cuda()
    epochs = 100
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)

    for epoch in range(epochs):
        model.train()
        print('Training')
        for i, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print(f'Epoch[{epoch}] Iter: {i}/{len(loader)} Loss: {loss.item()}')
        scheduler.step()

        print('Testing')
        # Switch to eval mode for every evaluation pass; otherwise the
        # model.train() call above leaves BatchNorm updating its running
        # statistics on test data.
        model.eval()
        # Start from empty lists each epoch so the reported accuracy covers
        # only the current epoch's predictions.
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.cuda()
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.numpy())
        accuracy = accuracy_score(all_labels, all_preds)
        print(f'Epoch[{epoch}] acc@1 {accuracy:.4f}')

    torch.save(model.state_dict(), 'cifar100_resnet18.pth')
if __name__ == '__main__':
    # Run training only when this file is executed as a script, so that
    # importing its helpers from other modules has no side effects.
    train_model()
构建常用的开集数据集
# ood_data.py
import torch
import numpy as np
from torch.utils.data.dataset import TensorDataset
from torch.utils.data.dataloader import DataLoader
from skimage.filters import gaussian
from torchvision.datasets import SVHN
from train import get_transform
def build_ood_loader(noise_type, ood_num_examples, batch_size, worker):
    """Create a ``DataLoader`` over an out-of-distribution dataset.

    Args:
        noise_type: One of ``'gaussian'``, ``'rademacher'``, ``'blob'``
            (synthetic 3x32x32 images with values in [-1, 1]) or ``'svhn'``
            (the SVHN test split, preprocessed with the test-time transform).
        ood_num_examples: Number of samples to generate, or the maximum
            number of SVHN samples to keep.
        batch_size: Number of samples per batch.
        worker: Number of worker processes used for loading.

    Returns:
        A non-shuffling ``DataLoader`` yielding ``(images, targets)``
        batches. For the synthetic types the targets are dummy ones.

    Raises:
        ValueError: If ``noise_type`` is not one of the supported names.
    """
    if noise_type in ['gaussian', 'rademacher', 'blob']:
        if noise_type == 'gaussian':
            # Normal noise with std 0.5, clipped to [-1, 1].
            ood_data = torch.from_numpy(np.float32(np.clip(
                np.random.normal(size=(ood_num_examples, 3, 32, 32), scale=0.5), -1, 1)))
        elif noise_type == 'rademacher':
            # Each value is -1 or +1 with equal probability.
            ood_data = torch.from_numpy(np.random.binomial(
                n=1, p=0.5, size=(ood_num_examples, 3, 32, 32)).astype(np.float32)) * 2 - 1
        else:
            # Bernoulli(0.7) noise in HWC layout, blurred and thresholded
            # into blobs, then moved to CHW and rescaled to [-1, 1].
            ood_data = np.float32(np.random.binomial(n=1, p=0.7, size=(ood_num_examples, 32, 32, 3)))
            for i in range(ood_num_examples):
                ood_data[i] = gaussian(ood_data[i], sigma=1.5)
                ood_data[i][ood_data[i] < 0.75] = 0.0
            ood_data = torch.from_numpy(ood_data.transpose((0, 3, 1, 2))) * 2 - 1
        dummy_targets = torch.ones(ood_num_examples)
        dataset = TensorDataset(ood_data, dummy_targets)
    elif noise_type == 'svhn':
        transform = get_transform(train=False)
        dataset = SVHN(root='~/data/svhn', split='test', transform=transform, download=True)
        # Truncate images and labels together so the dataset stays
        # internally consistent after subsetting.
        dataset.data = dataset.data[:ood_num_examples]
        dataset.labels = dataset.labels[:ood_num_examples]
    else:
        raise ValueError(f'Unknown noise type: {noise_type}')

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=worker, pin_memory=True)
    return dataloader
使用常见的OOD检测评估指标
# ood_utils.py
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
@torch.no_grad()
def get_ood_scores(model, dataloader, closed_set=False):
    """Compute maximum-softmax-probability (MSP) scores for every sample.

    The model is put in eval mode and each batch is moved to the device the
    model's parameters live on (so this works on CPU as well as GPU).

    Args:
        model: Classifier returning logits of shape ``(batch, classes)``.
        dataloader: Iterable yielding ``(data, targets)`` batches.
        closed_set: When true, predictions are compared against ``targets``
            and the scores are additionally split into correctly and
            incorrectly classified samples.

    Returns:
        If ``closed_set`` is false, a 1-D array with one MSP score per
        sample. Otherwise a tuple ``(scores, right_scores, wrong_scores)``
        of 1-D arrays. Arrays are empty when there is nothing to report.
    """
    model.eval()

    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-free model: fall back to CUDA when it is available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    scores = []
    right_scores = []
    wrong_scores = []
    for data, targets in dataloader:
        output = model(data.to(device))
        smax = F.softmax(output, dim=1).cpu().numpy()
        msp = np.max(smax, axis=1)
        scores.append(msp)
        if closed_set:
            pred = np.argmax(smax, axis=1)
            # reshape(-1) keeps a 1-D array even for a batch of size one,
            # where squeeze() would collapse it to a 0-d scalar.
            labels = targets.cpu().numpy().reshape(-1)
            correct = pred == labels
            right_scores.append(msp[correct])
            wrong_scores.append(msp[~correct])

    def _concat(chunks):
        # np.concatenate raises on an empty list; return an empty array.
        if not chunks:
            return np.empty(0, dtype=np.float32)
        return np.concatenate(chunks)

    if closed_set:
        return _concat(scores), _concat(right_scores), _concat(wrong_scores)
    return _concat(scores)
def get_performance(pos, neg):
    """Compute AUROC, AUPR and FPR at 95% TPR for two groups of scores.

    Args:
        pos: Scores of the positive class (labelled 1).
        neg: Scores of the negative class (labelled 0).

    Returns:
        A tuple ``(auroc, aupr, fpr95)`` where ``fpr95`` is the false
        positive rate at the first ROC point whose true positive rate
        reaches 0.95.
    """
    pos_scores = np.array(pos).reshape(-1)
    neg_scores = np.array(neg).reshape(-1)

    all_scores = np.concatenate([pos_scores, neg_scores])
    all_labels = [1] * len(pos_scores) + [0] * len(neg_scores)

    fpr, tpr, _ = roc_curve(all_labels, all_scores)
    # argmax over a boolean array gives the index of the first True entry.
    first_hit = np.argmax(tpr >= 0.95)

    return (
        roc_auc_score(all_labels, all_scores),
        average_precision_score(all_labels, all_scores),
        fpr[first_hit],
    )
def show_performance(pos, neg):
    """Print AUROC, AUPR and FPR95 (as percentages) for the given scores.

    Args:
        pos: Scores of the positive class.
        neg: Scores of the negative class.
    """
    metrics = get_performance(pos, neg)
    auroc, aupr, fpr95 = metrics
    message = f"AUROC: {auroc * 100:.4f}, AUPR: {aupr * 100:.4f}, FPR95: {fpr95 * 100:.4f}"
    print(message)
测试模型的OOD检测性能
# test.py
import torch
from torchvision.models import resnet18
from train import get_loader
from ood_utils import get_ood_scores, show_performance
from ood_data import build_ood_loader
def evaluate():
    """Evaluate the trained ResNet-18 on closed-set and open-set detection.

    Loads the weights from ``cifar100_resnet18.pth``, then:

    1. Scores the CIFAR-100 test split, prints the error rate, and reports
       how well the scores separate correct from incorrect predictions.
    2. For each OOD type, builds an OOD loader with one fifth as many
       samples as the test split and reports how well the scores separate
       in-distribution samples (positive) from OOD samples (negative).

    Requires a CUDA device: the model is moved with ``.cuda()``.
    """
    model = resnet18(num_classes=100)
    # Load onto CPU first and restrict unpickling to tensors/containers;
    # the model is moved to the GPU afterwards.
    state_dict = torch.load('cifar100_resnet18.pth', map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model = model.cuda()

    # closed-set test
    test_loader = get_loader(train=False)
    in_score, right_score, wrong_score = get_ood_scores(model, test_loader, True)
    num_right, num_wrong = len(right_score), len(wrong_score)
    print(f'Error Rate {100 * num_wrong / (num_right + num_wrong):.4f}')
    show_performance(right_score, wrong_score)

    # open-set test
    ood_num_examples = len(test_loader.dataset) // 5
    ood_types = ['gaussian', 'rademacher', 'blob', 'svhn']
    for ood_type in ood_types:
        print(f'ood type: {ood_type}')
        ood_loader = build_ood_loader(ood_type, ood_num_examples, batch_size=128, worker=8)
        out_score = get_ood_scores(model, ood_loader)
        show_performance(in_score, out_score)
if __name__ == '__main__':
    # Run the evaluation only when this file is executed as a script.
    evaluate()
依赖
scikit-learn 1.5.2
scipy 1.14.1
torch 2.4.1