Deep Learning Model Compression: Knowledge Distillation in Practice
The student model, with far fewer parameters, learns to match the teacher's output distribution; under the teacher's guidance it reaches better performance than it would by training on hard labels alone, which makes knowledge distillation a practical route to model compression.
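For reference, the loss implemented by the distillation function in the script below can be written as follows, where z_s and z_t are the student and teacher logits, y is the hard label, T is the temperature, and α weights the two terms (the factor of 2 simply mirrors the constant used in the code):

L = 2\,\alpha\,T^{2}\,\mathrm{KL}\!\left(\mathrm{softmax}(z_t/T)\,\middle\|\,\mathrm{softmax}(z_s/T)\right) + (1-\alpha)\,\mathrm{CE}(z_s,\,y)

The T² factor compensates for the 1/T² scaling that the temperature introduces into the soft-target gradients, keeping the two terms on a comparable scale. The full example code follows: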
""" Function:knowledge distillation """ import math import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.data from torchvision import datasets, transforms import matplotlib.pyplot as plt torch.manual_seed(0) # torch.cuda.manual_seed(0) # 定义教师网络 class TeacherNet(nn.Module): def __init__(self): super(TeacherNet, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.dropout1 = nn.Dropout2d(0.3) self.dropout2 = nn.Dropout2d(0.5) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = F.relu(x) x = self.conv2(x) x = F.relu(x) x = F.max_pool2d(x,2) x = self.dropout1(x) x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout2(x) output = self.fc2(x) return output # 训练过程 def train_teacher(model, device, train_loader, optimizer, epoch): # 启用 BatchNormalization 和 Dropout model.train() trained_samples = 0 for batch_idx, (data, target) in enumerate(train_loader): # 搬到指定gpu或者cpu设备上运算 data, target = data.to(device), target.to(device) # 梯度清零 optimizer.zero_grad() # 前向传播 output = model(data) # 计算误差 loss = F.cross_entropy(output, target) # 误差反向传播 loss.backward() # 梯度更新一步 optimizer.step() # 统计已经训练的数据量 trained_samples += len(data) progress = math.ceil(batch_idx / len(train_loader) * 50) print('\rTrain epoch: {} {}/{} [{}]{}%'.format(epoch, trained_samples, len(train_loader.dataset), '-'*progress+'>', progress*2), end='') # 测试过程 def test_teacher(model, device, test_loader): # 不启用 BatchNormalization 和 Dropout model.eval() test_loss = 0 correct = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) test_loss += F.cross_entropy(output, target, reduction='sum').item() # 输出预测类别 pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) print('\nTest: average loss: {:.4f}, accuracy:{}/{},({:.0f}%)'.format( test_loss, correct,len(test_loader.dataset), 100* correct /len(test_loader.dataset) )) return test_loss, correct / len(test_loader.dataset) def teacher_main(): epochs = 3 batch_size = 64 torch.manual_seed(0) mnist_path = 'D:\\my_AI\\MNIST' # 动态设置硬件设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") train_loader = torch.utils.data.DataLoader( datasets.MNIST(mnist_path, train=True, download=False, transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,),(0.3081,)) ])), batch_size=batch_size, shuffle=True ) test_loader = torch.utils.data.DataLoader( datasets.MNIST(mnist_path, train=False, download=False, transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,),(0.3081,)) ])), batch_size=1000, shuffle=True ) # 实例化模型 model = TeacherNet().to(device) # 选取优化器 optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) teacher_history = [] for epoch in range(1, epochs+1): print(epoch) train_teacher(model, device, train_loader, optimizer, epoch) loss, acc = test_teacher(model, device, test_loader) teacher_history.append((loss, acc)) # 保存模型,state_dict:Returns a dictionary containing a whole state of the module. 
torch.save(model.state_dict(), 'model/teacher.pt') return model, teacher_history # 构建学生网络 class Student(nn.Module): def __init__(self): super(Student, self).__init__() self.fc1 = nn.Linear(28*28, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 10) def forward(self, x): x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) output = F.relu(self.fc3(x)) return output # 蒸馏部分:定义kd的loss def distillation(y, labels, teacher_scores, temp, alpha): """ :param y: 学生预测的概率分布 :param labels: 实际标签 :param teacher_scores: 老师预测的概率分布 :param temp: 温度系数 :param alpha: 损失调整因子 :return: """ kl_stu_tea = nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * temp * temp * 2.0 * alpha stu_loss = F.cross_entropy(y, labels) * (1-alpha) return kl_stu_tea+stu_loss # 训练学生网络 def train_student_kd(model, teacher_model, device, train_loader, optimizer, epoch): model.train() trained_samples = 0 for batch_idx, (data, target) in enumerate(train_loader): # 搬到指定gpu或者cpu设备上运算 data, target = data.to(device), target.to(device) # 梯度清零 optimizer.zero_grad() # 前向传播 output = model(data) # 老师输出 teacher_output = teacher_model(data) # 计算误差 loss = distillation(output, target, teacher_output, temp=5., alpha=.7) # 误差反向 loss.backward() # 梯度更新一步 optimizer.step() # 统计已经训练的数据量 trained_samples += len(data) progress = math.ceil(batch_idx / len(train_loader) * 50) print('\rTrain epoch: {} {}/{} [{}]{}%'.format(epoch, trained_samples, len(train_loader.dataset), '-'*progress+'>', progress*2), end='') # 测试学生网络 def test_student_kd(model, device, test_loader): # 不启用 BatchNormalization 和 Dropout model.eval() test_loss = 0 correct = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) test_loss += F.cross_entropy(output, target, reduction='sum').item() # 输出预测类别 pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) print('\nTest: average loss: {:.4f}, accuracy:{}/{},({:.0f}%)'.format( test_loss, correct,len(test_loader.dataset), 100* correct /len(test_loader.dataset) )) return test_loss, correct / len(test_loader.dataset) def student_kd_main(teacher_m): epochs = 2 batch_size = 64 torch.manual_seed(0) mnist_path = 'D:\\my_AI\\MNIST' # 动态设置硬件设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") train_loader = torch.utils.data.DataLoader( datasets.MNIST(mnist_path, train=True, download=False, transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,),(0.3081,)) ])), batch_size=batch_size, shuffle=True ) test_loader = torch.utils.data.DataLoader( datasets.MNIST(mnist_path, train=False, download=False, transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,),(0.3081,)) ])), batch_size=1000, shuffle=True ) # 实例化模型 model = Student().to(device) # 选取优化器 optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) student_history = [] for epoch in range(1, epochs+1): print(epoch) train_student_kd(model, teacher_m, device, train_loader, optimizer, epoch) loss, acc = test_student_kd(model, device, test_loader) student_history.append((loss, acc)) # 保存模型,state_dict:Returns a dictionary containing a whole state of the module. 
torch.save(model.state_dict(), 'model/student.pt') return model, student_history # 让学生自己学,不使用KD def train_student(model, device, train_loader, optimizer, epoch): model.train() trained_samples = 0 for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = F.cross_entropy(output, target) loss.backward() optimizer.step() trained_samples += len(data) progress = math.ceil(batch_idx / len(train_loader) * 50) print('\rTrain epoch: {} {}/{} [{}]{}%'.format(epoch, trained_samples, len(train_loader.dataset), '-'*progress+'>', progress*2), end='') def test_student(model, device, test_loader): model.eval() test_loss = 0 correct = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format( test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) return test_loss, correct / len(test_loader.dataset) def student_main(): epochs = 10 batch_size = 64 torch.manual_seed(0) mnist_path = 'D:\\my_AI\\MNIST' device = torch.device("cuda" if torch.cuda.is_available() else "cpu") train_loader = torch.utils.data.DataLoader( datasets.MNIST(mnist_path, train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=batch_size, shuffle=True) test_loader = torch.utils.data.DataLoader( datasets.MNIST(mnist_path, train=False, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=1000, shuffle=True) model = Student().to(device) optimizer = torch.optim.Adadelta(model.parameters()) student_history = [] for epoch in range(1, epochs + 1): train_student(model, device, train_loader, optimizer, epoch) loss, acc = test_student(model, device, test_loader) student_history.append((loss, acc)) torch.save(model.state_dict(), "student.pt") return model, student_history if __name__ == '__main__': teacher_model, teacher_history = teacher_main() student_kd_model, student_kd_history = student_kd_main(teacher_model) student_simple_model, student_simple_history = student_main() # 三个模型的loss和acc分析 epochs = 10 x = list(range(1, epochs + 1)) plt.subplot(2, 1, 1) plt.plot(x, [teacher_history[i][1] for i in range(epochs)], label='teacher') plt.plot(x, [student_kd_history[i][1] for i in range(epochs)], label='student with KD') plt.plot(x, [student_simple_history[i][1] for i in range(epochs)], label='student without KD') plt.title('Test accuracy') plt.legend() plt.subplot(2, 1, 2) plt.plot(x, [teacher_history[i][0] for i in range(epochs)], label='teacher') plt.plot(x, [student_kd_history[i][0] for i in range(epochs)], label='student with KD') plt.plot(x, [student_simple_history[i][0] for i in range(epochs)], label='student without KD') plt.title('Test loss') plt.legend()
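As a quick usage sketch (not part of the original script), the distilled student saved by student_kd_main can be loaded back for inference roughly as follows; this assumes the Student class defined above is in scope and that 'model/student.pt' exists:

import torch

# Rebuild the student architecture and load the distilled weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student = Student().to(device)
student.load_state_dict(torch.load('model/student.pt', map_location=device))
student.eval()  # switch to inference mode

with torch.no_grad():
    # Stand-in for a normalized 1x28x28 MNIST image
    x = torch.randn(1, 1, 28, 28, device=device)
    logits = student(x)
    print('predicted digit:', logits.argmax(dim=1).item())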
Comparing the three models, the KD-trained student (student_kd), guided by the teacher, performs considerably better on both test accuracy and test loss than the same student architecture trained from scratch on hard labels alone.
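To put a number on the compression itself, a small helper (not in the original post) can count trainable parameters. With the architectures defined above, TeacherNet works out to 1,199,882 parameters versus 109,386 for Student, roughly an 11x reduction:

def count_parameters(model):
    # Total number of elements across all trainable tensors
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print('teacher params:', count_parameters(TeacherNet()))  # 1,199,882
print('student params:', count_parameters(Student()))     # 109,386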
Always keep in mind the kind of person you want to become!