论文阅读《Pruning Attention Heads of Transformer Models Using A* Search》
1. 现有算法的局限性
局部剪枝
遍历所有head,找出“剪掉这个head后精确度下降最小”的head,进行剪枝。虽然是精确度最高的剪枝算法,但剪枝过程开销太大。
全局剪枝
每次剪枝时,从每一层的某一个位置剪掉一个head。开销很小,但精确度不高。
2. A*算法概述
算法描述
是最小路径算法的一种,通过维护两个值:到起点的代价g(n)和到终点的预计代价h(n)(也叫做启发函数)
其中,f(n) = g(n) + h(n), f(n) 是节点n的综合优先级。
同时有了两个集合:待遍历的节点和已经遍过的节点。算法的具体流程是:
-初始化两个集合,简称为go和pass
-将起点加入go,并设置优先级为0(最高)
-如果go不为空,则从go里选出优先级最高的节点,判断:
-----if 节点n为终点
---------从终点开始追踪parent,找到起点,得出路径,算法结束
-----if 节点n不为终点
---------节点n从go删除,加入pass
---------遍历n的所有邻近节点
---------if n的邻近节点m在pass里面
-------------遍历过了,跳过
-------------if 邻近节点m也不在go中
-------------设置m的parent为n,计算m的优先级,放入go
算法思想
A*算法引入启发函数(预计代价),为优化方法提供在目标函数中添加启发式的思路。
3. A*算法应用于剪枝
如果是局部剪枝,计算步骤:
给定一个最大能接受的损失B
- 计算一个head被剪掉以后模型的accuracy P
- 计算相对于剪枝之前的accuracy损失 C
- 剪掉C最小的head
- 对剩余的head重复2,3操作,直到B被用完
如果是A*剪枝:
引入启发式H来估计下一次修剪head时,修剪剩余头部的成本
预测,在下一次迭代时,剪掉一个头部的开销和这一次迭代相当。但由于下一次迭代的开销一定大于这一次迭代。,所以不用担心超出了budget。
下面是算法的通俗版本解释:
(初始化已经被剪的集合L)
设置预算B
设置可剪区间S
B和S都大于0时,(开始循环)计算P和C;对C进行排序;找到C最小的head X
if X的C小于0,将其C设置为0;否则,将它剪掉,放入L,更新B和S
更新后如果S>0,初始化T(用于后续统计剪掉剩余head的总开销)
对S中的所有节点:用L集合中节点的开销作为估计值E,用E和已经剪完的C做差得到I。这个I就是心得节点cost的估计值。如果T + I <= B,更新T,否则开始移除一些不可能被剪掉的head(B不够用来剪他们了)
4. 实验结果
剪枝速度很快,相比于局部剪枝开销很小,而精确度比全局剪枝高很多。同时,有一些节点在被减去后,模型accuracy反而上升。
个人理解是,全局剪枝一定能维持模型的accuracy处于最高水平。A*算法的应用是在尽可能少的阻止accuracy降低的同时,大幅度减少时间开销。