使用t-SNE可视化CIFAR-10的表征
t-SNE的相关理论可参见《t-SNE 算法》一文。本文通过PyTorch提供的预训练ResNet50提取CIFAR-10表征,并使用t-SNE进行可视化。
加载预训练ResNet50
import torch
from torchvision.models import resnet50, ResNet50_Weights
# Load ResNet50 with the default pretrained weights.
resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
# Drop the final fully connected layer so the remaining layers act as a
# feature extractor rather than a classifier.
resnet_fe = torch.nn.Sequential(*(list(resnet.children())[:-1]))
# Use the GPU when one is available; fall back to the CPU instead of raising,
# which an unconditional .cuda() would do on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet_fe.to(device)
# Switch to evaluation mode for inference.
resnet_fe.eval()
加载CIFAR-10数据集
from torchvision.datasets import CIFAR10
from torchvision import transforms
# Per-channel normalization constants.
# NOTE(review): presumably the CIFAR-10 training-set mean/std — confirm.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Convert each image to a tensor, then normalize it channel-wise.
preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
]
transformer = transforms.Compose(preprocess_steps)

# Training split of CIFAR-10, downloaded into ./data when not already present.
dataset = CIFAR10(root='./data', train=True, download=True, transform=transformer)
提取CIFAR-10表征
from torch.utils.data import DataLoader
# Extract one feature vector per CIFAR-10 image with the truncated ResNet50.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
features = []  # one 1-D feature array per sample
labels = []    # one integer class label per sample

# Send inputs to whichever device the feature extractor's weights live on,
# so this loop works on both GPU and CPU setups.
device = next(resnet_fe.parameters()).device

# Gradients are not needed for feature extraction.
with torch.no_grad():
    for x, y in dataloader:
        # ResNet50's pooled output has 2048 channels (not 512, which is the
        # width of ResNet18/34): (batch_size, 2048, 1, 1).
        feature = resnet_fe(x.to(device))
        # Flatten to (batch_size, 2048) and move back to host memory.
        feature = feature.view(feature.size(0), -1).cpu()
        # Iterating a 2-D array yields its rows, so extend() adds one entry
        # per sample without a Python-level per-sample loop.
        features.extend(feature.numpy())
        labels.extend(y.numpy())
训练t-SNE
from sklearn.manifold import TSNE
import numpy as np
# Stack the per-sample lists into arrays: features becomes 2-D
# (one row per sample), labels becomes 1-D.
features = np.asarray(features)
labels = np.asarray(labels)

# Project the features to 2-D; the fixed random_state makes the
# embedding reproducible.
reducer = TSNE(n_components=2, random_state=0)
tsne = reducer.fit_transform(features)
可视化
import altair as alt
import pandas as pd
# Plot only the first k points; with more rows Altair raises
# "MaxRowsError: The number of rows in your dataset is greater than the
# maximum allowed (5000)."
k = 5000
label = labels[:k]
# Columns 0 and 1 of the embedding are the x and y coordinates.
x, y = tsne[:k, 0], tsne[:k, 1]

# Gather coordinates and labels into one table for Altair.
df = pd.DataFrame({'x': x, 'y': y, 'label': label})

# Filled scatter plot, coloured by class label (treated as nominal).
scatter = alt.Chart(df).mark_point(filled=True)
scatter = scatter.encode(x="x", y="y", color="label:N")
scatter = scatter.properties(width=400, height=400)
# Hide the axes.
chart = scatter.configure_axis(disable=True)
chart
参考文献
运行环境
# Name Version Build Channel
altair 5.0.1 py312haa95532_0
jupyter 1.0.0 py312haa95532_9
pandas 2.2.1 py312h0158946_0
pytorch 2.2.2 py3.12_cuda12.1_cudnn8_0 pytorch
scikit-learn 1.3.0 py312hc7c4135_2
在PyCharm 2024.1 (Professional Edition)中运行Jupyter时,altair的图表有时无法显示,此时可以用浏览器打开.ipynb文件查看。