Deep Learning | Converging on a Better Model for Deployment via Distillation
Keras-Based Knowledge Distillation: Classification and Regression
How to converge on a better deployment model through distillation
Introduction to Knowledge Distillation
Knowledge distillation is a model-compression procedure in which a small (student) model is trained to match a large, pre-trained (teacher) model. Knowledge is transferred from the teacher to the student by minimizing a loss function that aims to match both the softened teacher logits and the ground-truth labels. The logits are softened by applying a "temperature" scaling inside the softmax, which smooths the probability distribution and reveals the inter-class relationships learned by the teacher (a small numeric illustration follows the imports below).
Import the basic packages
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
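Before building the distiller, a quick illustration of what the temperature does (this snippet is not part of the original tutorial and the logits are made-up numbers): dividing the logits by a larger temperature before the softmax flattens the resulting distribution and exposes how the teacher ranks the non-target classes.
# Illustration only: made-up logits for a 3-class problem
example_logits = tf.constant([[8.0, 2.0, 0.5]])
print(tf.nn.softmax(example_logits / 1.0, axis=1).numpy())   # ~[[0.997, 0.002, 0.001]] -> almost one-hot
print(tf.nn.softmax(example_logits / 10.0, axis=1).numpy())  # ~[[0.49, 0.27, 0.23]]  -> much softer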
Build the Distiller class
The custom Distiller() class overrides the Model methods train_step, test_step, and compile(). To use the distiller, we need:
A trained teacher model
A student model to train
A student loss function on the difference between the student predictions and the ground truth
A distillation loss function, together with a temperature, on the difference between the soft student predictions and the soft teacher labels
An alpha factor that weights the student loss against the distillation loss
An optimizer for the student and (optionally) metrics to evaluate performance
In the train_step method, we perform a forward pass of both the teacher and the student, compute the loss by weighting student_loss by alpha and distillation_loss by 1 - alpha, and then perform the backward pass. Note: only the student weights are updated, so we compute gradients only for the student weights.
In the test_step method, we evaluate the student model on the provided dataset.
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack the data
        x, y = data

        # Forward pass of the teacher (its weights stay frozen)
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of the student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, so multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            # Total loss: alpha * hard loss + (1 - alpha) * soft loss
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients (only for the student weights)
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
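As a quick, standalone sanity check of the scaled distillation term used in train_step (a sketch with made-up logits, not part of the original code), the same computation can be reproduced on two dummy logit tensors:
# Dummy logits for a 3-class problem, illustration only
dummy_teacher_logits = tf.constant([[8.0, 2.0, 0.5]])
dummy_student_logits = tf.constant([[4.0, 3.0, 1.0]])
temperature = 10.0
kl = keras.losses.KLDivergence()
soft_teacher = tf.nn.softmax(dummy_teacher_logits / temperature, axis=1)
soft_student = tf.nn.softmax(dummy_student_logits / temperature, axis=1)
print(float(kl(soft_teacher, soft_student) * temperature**2))  # small positive value; 0 only if the two distributions match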
Create the student and teacher models
First, create a teacher model and a smaller student model. Both models are convolutional neural networks built with Sequential(), but they could be any other Keras models.
# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
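As a quick check of the compression ratio (not in the original post), the parameter counts of the two models can be compared; the student defined above is far smaller than the teacher.
# Optional: compare model sizes
print("teacher parameters:", teacher.count_params())
print("student parameters:", student.count_params())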
Prepare the dataset
The dataset used for training the teacher and distilling it into the student is MNIST; the procedure is equivalent for any other dataset, e.g. CIFAR-10, given a suitable choice of models. Both the student and the teacher are trained on the training set and evaluated on the test set.
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
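Note that batch_size is defined above, but the fit() calls below pass the NumPy arrays directly and therefore use the Keras default batch size of 32 (hence the 1875 steps per epoch in the logs). If you want to use the defined batch size, one option is a small tf.data pipeline, sketched below under that assumption (it is not used in the rest of the post):
# Optional: build tf.data pipelines that use the batch_size defined above
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(60000)
    .batch(batch_size)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
# These datasets could then be passed to fit()/evaluate() instead of the raw arrays.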
Train the teacher model
In knowledge distillation we assume that the teacher has already been trained and is fixed. We therefore start by training the teacher model on the training set in the usual way.
# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/5
1875/1875 [==============================] - 162s 86ms/step - loss: 0.1438 - sparse_categorical_accuracy: 0.9553
Epoch 2/5
1875/1875 [==============================] - 172s 92ms/step - loss: 0.0905 - sparse_categorical_accuracy: 0.9732
Epoch 3/5
1875/1875 [==============================] - 172s 92ms/step - loss: 0.0798 - sparse_categorical_accuracy: 0.9768
Epoch 4/5
1875/1875 [==============================] - 171s 91ms/step - loss: 0.0767 - sparse_categorical_accuracy: 0.9785
Epoch 5/5
1875/1875 [==============================] - 179s 95ms/step - loss: 0.0699 - sparse_categorical_accuracy: 0.9808
313/313 [==============================] - 6s 20ms/step - loss: 0.0894 - sparse_categorical_accuracy: 0.9763
[0.08935610204935074, 0.9763000011444092]
Distill the teacher into the student
With the teacher model trained, we only need to initialize a Distiller(student, teacher) instance, compile() it with the desired losses, hyperparameters, and optimizer, and distill the teacher into the student. A copy of the student is also trained from scratch for comparison.
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # for regression, swap in an appropriate loss function
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
The results are as follows:
Epoch 1/3
1875/1875 [==============================] - 37s 19ms/step - sparse_categorical_accuracy: 0.8863 - student_loss: 0.5352 - distillation_loss: 8.6172
Epoch 2/3
1875/1875 [==============================] - 37s 20ms/step - sparse_categorical_accuracy: 0.9647 - student_loss: 0.1374 - distillation_loss: 1.8981
Epoch 3/3
1875/1875 [==============================] - 38s 20ms/step - sparse_categorical_accuracy: 0.9718 - student_loss: 0.1047 - distillation_loss: 1.2105
313/313 [==============================] - 1s 2ms/step - sparse_categorical_accuracy: 0.9732 - student_loss: 0.1035
[0.9732000231742859, 0.0381324402987957]
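The title also mentions regression. The same Distiller pattern can in principle be reused for a regression task by swapping both losses for a regression loss, as in the hypothetical sketch below (not from the original code; reg_teacher and reg_student are placeholder models). Note that a regression output is not a probability distribution, so the softmax/temperature softening inside train_step should be removed or bypassed in that case.
# Hypothetical regression setup: tiny placeholder models and MSE for both losses
reg_teacher = keras.Sequential([keras.Input(shape=(4,)), layers.Dense(64, activation="relu"), layers.Dense(1)])
reg_student = keras.Sequential([keras.Input(shape=(4,)), layers.Dense(8, activation="relu"), layers.Dense(1)])
reg_distiller = Distiller(student=reg_student, teacher=reg_teacher)
reg_distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.MeanAbsoluteError()],
    student_loss_fn=keras.losses.MeanSquaredError(),
    distillation_loss_fn=keras.losses.MeanSquaredError(),
    alpha=0.1,
    temperature=1,  # would be irrelevant once the softmax step is removed for regression
)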
Train the student from scratch for comparison
# Train student model from scratch for comparison
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
# student (trained from scratch) accuracy: 0.9778
# distilled student vs. student trained from scratch: 0.9896 vs. 0.9778
Epoch 1/3
1875/1875 [==============================] - 7s 4ms/step - loss: 0.0680 - sparse_categorical_accuracy: 0.9791
Epoch 2/3
1875/1875 [==============================] - 7s 4ms/step - loss: 0.0597 - sparse_categorical_accuracy: 0.9819
Epoch 3/3
1875/1875 [==============================] - 7s 4ms/step - loss: 0.0545 - sparse_categorical_accuracy: 0.9829
313/313 [==============================] - 1s 2ms/step - loss: 0.0640 - sparse_categorical_accuracy: 0.9797
[0.06404071301221848, 0.9797000288963318]
If the teacher was trained for 5 epochs and the student was distilled on this teacher for 3 epochs, then in this example you should see a performance boost compared with training the same student model from scratch, and even compared with the teacher itself.
You should expect the teacher to reach an accuracy of around 97.6%, the student trained from scratch to be around 97.6% as well, and the distilled student to be around 98.1%.
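Once distillation has finished, only the student is needed for deployment and the teacher can be discarded. A minimal sketch of exporting it (the file name is illustrative; older TensorFlow/Keras versions may require an .h5 file or a SavedModel directory instead):
# Save the distilled student (distiller.student is the same object as `student`)
student.save("distilled_student.keras")
# Later, for serving:
# deployed_student = keras.models.load_model("distilled_student.keras")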