深度学习(蒸馏)

模型蒸馏是指通过训练一个小而简单的模型来复制和学习一个大模型的知识和性能。这种方法通常用于减少模型的计算资源需求,加速推理过程或者使模型适用于资源受限的设备上。

步骤如下:

1. 准备教师模型和学生模型:

  教师模型:一个复杂的模型,这里用的是resnet。

  学生模型:简化的卷积神经网络,较少的参数和层次结构。

2. 定义损失函数:

  交叉熵损失:使用Softmax激活函数输出的概率分布,以及温度参数来平衡模型的软化度。

3. 训练学生模型:

  在训练过程中,通过比较学生模型预测和教师模型预测之间的差异来优化模型参数。

4. 优化和调整:

  可以尝试不同的模型结构、损失函数设置和超参数调整来优化学生模型的性能和效率。

5. 评估和比较:

  使用测试数据集评估学生模型的性能,并与未经蒸馏的模型以及教师模型进行比较。

测试代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import os

# 设置是否使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义数据转换
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 定义学生模型(简单的卷积神经网络)
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 512, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(512)
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)
        self.dropout3 = nn.Dropout(0.2)
        self.dropout4 = nn.Dropout(0.2)

        self.fc1 = nn.Linear(512*4*4, 256)
        self.fc2 = nn.Linear(256, 10)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.dropout1(self.relu(self.bn1(self.conv1(x)))))
        x = self.pool(self.dropout2(self.relu(self.bn2(self.conv2(x)))))
        x = self.pool(self.dropout3(self.relu(self.bn3(self.conv3(x)))))
        x = x.view(x.size(0), -1)
        x = self.dropout4(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x


# 测试模型
def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = correct / total
    print(f'Accuracy on test set: {100 * accuracy:.2f}%')


def trainTecher(model,trainLoader,testloader,optimizer,criterion):

    for epoch in range(5):
        model.train()
        correct = 0
        total = 0
        for inputs, labels in trainLoader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            print(epoch,loss.item(),f" train teacher Accuracy: {(100 * correct / total):.2f}%")

        test(model,testloader)

def trainStudent(model,teacher_model, trainloader,testloader):
    
    criterion = nn.KLDivLoss()  # KL散度损失函数
    optimizer = optim.AdamW(student_model.parameters(), lr=5e-4, weight_decay=1e-3)

    for epoch in range(20):
        model.train()
        correct_stu = 0
        correct_teh = 0
        total = 0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            teacher_outputs = teacher_model(inputs).detach()  # 使用教师模型的输出作为软标签

            loss = criterion(nn.functional.log_softmax(outputs/5, dim=1),
                             nn.functional.softmax(teacher_outputs/5, dim=1))
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total += labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            correct_stu += (predicted == labels).sum().item()

            _, predicted = torch.max(teacher_outputs.data, 1)
            correct_teh += (predicted == labels).sum().item()
            print(epoch,loss.item(),f" train student Accuracy: {(100 * correct_stu / total):.2f}%",f"{(100 * correct_teh / total):.2f}%")
        
        test(model,testloader)

if __name__ == '__main__':

    # 加载数据集
    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

    # 定义教师模型
    teacher_model = models.resnet18(pretrained=True)
    teacher_model.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)  
    teacher_model.maxpool = nn.MaxPool2d(1, 1, 0)  
    teacher_model.fc = nn.Linear(teacher_model.fc.in_features, 10)
    teacher_model.to(device)

    student_model = StudentNet()
    student_model.to(device)

    total = sum([param.nelement() for param in teacher_model.parameters()])
    print("Number of parameter: %.2fM" % (total/1e6))   

    total = sum([param.nelement() for param in student_model.parameters()])
    print("Number of parameter: %.2fM" % (total/1e6))   

    if os.path.exists('resnet.pth'):    
        teacher_model.load_state_dict(torch.load('resnet.pth'))     
        teacher_model.eval()
    else:
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(teacher_model.parameters(), lr=5e-4, weight_decay=1e-3)
        trainTecher(teacher_model,trainloader,testloader,optimizer,criterion)
        torch.save(teacher_model.state_dict(), 'resnet.pth')

    trainStudent(student_model,teacher_model, trainloader,testloader)  
posted @ 2024-08-03 10:52  Dsp Tian  阅读(22)  评论(0编辑  收藏  举报