深度学习(ViT)
将Transformer引入图像领域之作,学习一下。
网络结构:
ViT结构有以下几个关键部分:
1. 图像分块:输入图像被划分为固定大小、互不重叠的小块(patches),每个小块被展平并线性嵌入为一个固定维度的向量。这里将32x32的图像按8x8的块大小划分成4x4的网格,共16个小块;每个小块包含8x8x3=192个像素值,通过一个核大小与步长均为8的卷积线性嵌入为128维向量。
2. 位置编码:由于Transformer的自注意力本身不感知输入顺序(对位置不敏感),需要添加位置编码来提供位置信息。每个图像块向量都会加上一个对应的可学习位置编码,以保留图像的空间信息。
3. Transformer编码:加上位置编码后的嵌入向量被输入到Transformer编码器中。编码器由多个相同的编码层堆叠而成,每层包含多头自注意力和MLP两个子层,每个子层前有LayerNorm、后有残差连接。
4. MLP分类:将编码器输出的所有图像块向量展平后,送入一个两层全连接网络(MLP)得到类别预测。(注:原论文使用额外的[CLS] token做分类,这里的实现直接使用全部图像块向量。)
测试代码如下:
import torch import torch.nn as nn import torch.optim as optim from torchvision import transforms from torchvision.datasets import CIFAR10 import torchvision.models class EmbedLayer(nn.Module): def __init__(self,channels, embed_dim,img_size,patch_size): super().__init__() self.embed_dim = embed_dim self.conv1 = nn.Conv2d(channels, embed_dim, patch_size, patch_size) self.pos_embedding = nn.Parameter(torch.zeros(1, (img_size // patch_size) ** 2,embed_dim), requires_grad=True) # Positional Embedding def forward(self, x): x = self.conv1(x) x = x.reshape([x.shape[0], self.embed_dim, -1]) x = x.transpose(1, 2) x = x + self.pos_embedding return x class SelfAttention(nn.Module): def __init__(self,embed_dim, heads): super().__init__() self.heads = heads self.embed_dim = embed_dim self.head_embed_dim = self.embed_dim // heads self.queries = nn.Linear(self.embed_dim, self.head_embed_dim * heads, bias=True) self.keys = nn.Linear(self.embed_dim, self.head_embed_dim * heads, bias=True) self.values = nn.Linear(self.embed_dim, self.head_embed_dim * heads, bias=True) def forward(self, x): m, s, e = x.shape q = self.queries(x).reshape(m, s, self.heads, self.head_embed_dim).transpose(1, 2) k = self.keys(x).reshape(m, s, self.heads, self.head_embed_dim).transpose(1, 2) v = self.values(x).reshape(m, s, self.heads, self.head_embed_dim).transpose(1, 2) q = q.reshape([-1, s, self.head_embed_dim]) k = k.reshape([-1, s, self.head_embed_dim]) v = v.reshape([-1, s, self.head_embed_dim]) k = k.transpose(1, 2) x_attention = q.bmm(k) x_attention = torch.softmax(x_attention, dim=-1) x = x_attention.bmm(v) x = x.reshape([-1, self.heads, s, self.head_embed_dim]) x = x.transpose(1, 2) x = x.reshape(m, s, e) return x class Encoder(nn.Module): def __init__(self, embed_dim,heads): super().__init__() self.attention = SelfAttention(embed_dim,heads) self.fc1 = nn.Linear(embed_dim, embed_dim * 2) self.activation = nn.GELU() self.fc2 = nn.Linear(embed_dim * 2,embed_dim) self.norm1 = nn.LayerNorm(embed_dim) 
self.norm2 = nn.LayerNorm(embed_dim) def forward(self, x): x = x + self.attention(self.norm1(x)) x = x + self.fc2(self.activation(self.fc1(self.norm2(x)))) return x class Classifier(nn.Module): def __init__(self, embed_dim,num_patches,classes): super().__init__() self.fc1 = nn.Linear(embed_dim*num_patches, embed_dim) self.activation = nn.Tanh() self.fc2 = nn.Linear(embed_dim, classes) def forward(self, x): x = x.view(x.shape[0],-1) x = self.fc1(x) x = self.activation(x) x = self.fc2(x) return x class VisionTransformer(nn.Module): def __init__(self,channels, embed_dim,n_layers,heads,img_size,patch_size,classes): super().__init__() self.embedding = EmbedLayer(channels,embed_dim,img_size,patch_size) self.encoder = nn.Sequential(*[Encoder(embed_dim,heads) for _ in range(n_layers)], nn.LayerNorm(embed_dim)) self.norm = nn.LayerNorm(embed_dim) self.classifier = Classifier(embed_dim,(img_size//patch_size)**2,classes) def forward(self, x): x = self.embedding(x) x = self.encoder(x) x = self.norm(x) x = self.classifier(x) return x if __name__ == '__main__': device = torch.device("cuda") trainTransforms = transforms.Compose([ transforms.ToTensor() , transforms.RandomHorizontalFlip(p=0.5) , transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) testTransforms = transforms.Compose([ transforms.ToTensor() , transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) trainset = CIFAR10(root='./data', train=True, download=True, transform=trainTransforms) trainLoader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True) testset = CIFAR10(root='./data', train=False,download=False, transform=testTransforms) testLoader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False) model = VisionTransformer(channels=3, embed_dim=128,n_layers=6,heads=8,img_size=32,patch_size=8,classes=10) # model = torchvision.models.resnet18(pretrained=True) # model.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False) # model.maxpool = 
nn.MaxPool2d(1, 1, 0) # model.fc = nn.Linear(model.fc.in_features, 10) criterion = nn.CrossEntropyLoss() optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-3) cos_decay = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100, verbose=True) model.to(device) for epoch in range(50): print("epoch :",epoch) model.train() correct = 0 total = 0 for images, labels in trainLoader: images = images.to(device) labels = labels.to(device) outputs = model(images) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(loss.item(),f" train Accuracy: {(100 * correct / total):.2f}%") cos_decay.step() model.eval() with torch.no_grad(): correct = 0 total = 0 for images,labels in testLoader: images = images.to(device) labels = labels.to(device) outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f"test Accuracy: {(100 * correct / total):.2f}%") # 保存模型 torch.save(model.state_dict(), 'vit.pth')