论文阅读《Pruning Attention Heads of Transformer Models Using A* Search》

1. 现有算法的局限性

局部剪枝

遍历所有head,找出“剪掉这个head后精确度下降最小”的head,进行剪枝。虽然是精确度最高的剪枝算法,但剪枝过程开销太大。

全局剪枝

每次剪枝时,从每一层的某一个位置剪掉一个head。开销很小,但精确度不高。

2. A*算法概述

算法描述

是最小路径算法的一种,通过维护两个值:到起点的代价g(n)和到终点的预计代价h(n)(也叫做启发函数)

其中,f(n) = g(n) + h(n), f(n) 是节点n的综合优先级。

同时有了两个集合:待遍历的节点已经遍过的节点。算法的具体流程是:

-初始化两个集合,简称为go和pass
-将起点加入go,并设置优先级为0(最高)
-如果go不为空,则从go里选出优先级最高的节点,判断:
-----if 节点n为终点
---------从终点开始追踪parent,找到起点,得出路径,算法结束
-----if 节点n不为终点
---------节点n从go删除,加入pass
---------遍历n的所有邻近节点
---------if n的邻近节点m在pass里面
-------------遍历过了,跳过
-------------if 邻近节点m也不在go中
-------------设置m的parent为n,计算m的优先级,放入go

算法思想

A*算法引入启发函数(预计代价),为优化方法提供在目标函数中添加启发式的思路。

3. A*算法应用于剪枝

如果是局部剪枝,计算步骤:

给定一个最大能接受的损失B

  1. 计算一个head被剪掉以后模型的accuracy P
  2. 计算相对于剪枝之前的accuracy损失 C
  3. 剪掉C最小的head
  4. 对剩余的head重复2,3操作,直到B被用完

如果是A*剪枝:

引入启发式H来估计下一次修剪head时,修剪剩余头部的成本

预测,在下一次迭代时,剪掉一个头部的开销和这一次迭代相当。但由于下一次迭代的开销一定大于这一次迭代。,所以不用担心超出了budget。

下面是算法的通俗版本解释:

(初始化已经被剪的集合L)

设置预算B

设置可剪区间S

B和S都大于0时,(开始循环)计算P和C;对C进行排序;找到C最小的head X

if X的C小于0,将其C设置为0;否则,将它剪掉,放入L,更新B和S

更新后如果S>0,初始化T(用于后续统计剪掉剩余head的总开销)

对S中的所有节点:用L集合中节点的开销作为估计值E,用E和已经剪完的C做差得到I。这个I就是心得节点cost的估计值。如果T + I <= B,更新T,否则开始移除一些不可能被剪掉的head(B不够用来剪他们了)

4. 实验结果

剪枝速度很快,相比于局部剪枝开销很小,而精确度比全局剪枝高很多。同时,有一些节点在被减去后,模型accuracy反而上升。

个人理解是,全局剪枝一定能维持模型的accuracy处于最高水平。A*算法的应用是在尽可能少的阻止accuracy降低的同时,大幅度减少时间开销。

posted @ 2022-02-08 21:18  PaB式乌龙茶  阅读(155)  评论(0编辑  收藏  举报