Scalar evolution技术与i^n求和优化
(如果不想看一开始的引例,想直接看Scalar evolution,可以直接跳过这个“引例”部分。)
引例
考虑\(i^3\)求和
其C语言代码为
#include <stdio.h>
int main() {
int n, s = 0;
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
s += i * i * i;
}
printf("%d\n", s);
return 0;
}
非常神奇的是,这个算法本身的时间复杂度是\(O(n)\)的。但是观察打开\(O3\)优化下的汇编指令可以发现,编译器将该算法优化成了\(O(1)\)复杂度的直接公式求解。本文将介绍分析该优化的具体表现,并介绍该优化的本质——标量演化编译优化技术。值得一提的是,虽然我们人为知道\(\sum i^3\)的求和公式为\((\frac{n(n+1)}{2})^2\),但是编译器并不是靠”死记公式”去进行优化,而是以一种具有普遍意义的优化方式,循环幂次求和只是这种优化的很多种适用对象之一。
使用clang -S -O3观察其在O3优化下产生的汇编码,其计算\(s\)的主要部分为
leal -1(%rax), %ecx
leal -2(%rax), %edx
imulq %rcx, %rdx
leal -3(%rax), %ecx
imulq %rdx, %rcx
shrq %rdx
leal (%rdx,%rdx,8), %esi
leal (%rdx,%rsi,2), %edx
leal (%rdx,%rax,8), %edx
addl $-4, %eax
imulq %rcx, %rax
shrq %rcx
leal (%rcx,%rcx,2), %ecx
leal (%rdx,%rcx,2), %ecx
shrq $2, %rax
andl $-2, %eax
leal -7(%rax,%rcx), %ebx
可以分析出输入的n被存储在寄存器rax当中。n占据4个字节,所以其实eax当中已经被存有完整的n。我们用\(N\)来表示rax寄存器的值,于是我们可以假设
然后首先分析前两行,
leal -1(%rax), %ecx
leal -2(%rax), %edx
64位x86处理器的寻址模式如下图所示
这里GPR表示通用寄存器(General Purpose Register)。而lea指令表示取有效地址,将有效地址的运算结果存储到一个寄存器中。所以,lea字面意思是取有效地址,实际上还可以用来做一些简单的计算。
在这个代码中,
leal -1(%rax), %ecx
相对rax寄存器偏移量为-1的有效地址就是\(rax-1\)。也就是说这一句话等价于
因为ecx是32位的寄存器,实际上是把\(rax-1\)的低32位放入\(ecx\)。我们将\(rcx\)记为一个临时变量\(t1\)。进一步翻译为
下一行代码等价于
然后是一条有符号乘法指令
imulq %rcx, %rdx
等价于
利用相同的方法对剩余的代码进行翻译。最后得到等价的算法为
为了验证我们的人工转换是否正确,我们将这个算法写成另外一份C代码,
#include <stdio.h>
int main() {
int N, n, s = 0;
scanf("%d", &n);
long long t1, t2, t3, result;
N = n;
t1 = N - 1;
t2 = N - 2;
t2 = t2 * t1;
t1 = N - 3;
t1 = t1 * t2;
t2 = t2 / 2;
t3 = 9 * t2;
t2 = t2 + 2 * t3;
t2 = t2 + 8 * N;
N = N - 4;
N = N * t1;
t1 = t1 / 2;
t1 = 3 * t1;
t1 = t2 + 2 * t1;
N = N / 4;
N = N & (-2);
result = N + t1 - 7;
printf("%lld\n", result);
return 0;
}
测试了几组我们发现完全正确。那么这个算法是如何计算\(i^3\)求和的呢?分析如下:
下面还剩
没有分析,因为我们卡在了\(N\leftarrow N\&(-2)\)上。这一句是什么意思?
考虑\(-2\)这个二进制数。\(-2\)的原码为\((10)_2\),反码为\((1111...01)_2\),则补码为\((11111111110)_2\)。\(\&\)运算符是按每一个二进制位与。所以,\(N\leftarrow N\&(-2)\),本质上就是把\(N\)的前面的位都保持不变,而最后一位强制变为0。
实际上,因为我们知道,因为现在\(N\)已经被设置为
当\(n=2k\)的时候,\(N=(k-1)(k-2)(2k-1)(2k-3)\)。容易知道,\((k-1)(k-2)\)必然为偶数,而\((2k-1)(2k-3)\)必然为奇数。所以,\(N\)为偶数,而偶数的最后一位为0,也就是\(N\&(-2)=N\)。
当\(n=2k+1\)的时候,\(N=(2k-1)k(2k-3)(k-1)\)。同理,这个数仍然是偶数。也有\(N\&(-2)=N\)。
所以现在这个\(N\)一定是一个偶数。偶数的最后一位二进制位本来就是0。于是,这一句\(N\leftarrow N\&(-2)\),实际上并不会改变\(N\)的值。
最后一句,就相当于
下面证明
设\(F(n)=\frac{(n-1)(n-2)(n-3)(n-4)}{4}+\frac{19(n-1)(n-2)}{2}+8n+3(n-1)(n-2)(n-3)-7\)。
考虑数学归纳法。
当\(n=1\)的时候,\(1=F(1)=1\),结论成立。
当\(n=k\)的时候,需要证明
即证明
即证明
容易知道
可以看出,编译优化之后,编译器使用公式进行了求和,并且公式与直接求和是等价的。
本质探究
标量演化技术概述
Scalar evolution(SCEV)技术是一种现代化的高级编译优化技术。该技术主要用于分析循环中变量是如何被更新的,然后根据这个信息来进行优化。
引入
SCEV的核心是下列表示:
我们将这种表示称为循环链(chrec,Chains of Recurrences),至于为什么这么称呼,在后面就会介绍。其中,\(\phi\)为一个二元运算符且\(\phi\in\{+,*\}\)。也就是说,循环里的每一个标量变量\(var\),都可以用起始值(\(start\)),步长(\(step\))和更新方式(\(\phi\))三个参数来体现。值得注意的是,SCEV仅适用于标量(往往只针对整数类型)。而整数类型的变量在循环中用到的几率是非常大的。可以说,这种优化技术是“加速大概率事件”伟大思想的一种体现。
下面举个例子,考虑下面的循环:
int j = 10;
for (int i = 0; i < n; i++) {
k = i + j;
printf("%d\n", k);
j = j + 2;
}
这种表示方式可以清晰地体现一个变量在循环中的变化。\(i\)是一个变量,它从0开始,每次+1,于是记\(i=\{0,+,1\}\)。\(j\)是一个变量,从10开始,每次+2,则记\(j=\{10,+,2\}\)。
那么\(k\)怎么表示呢?实际上,循环链算式可以进行代数运算:
直观角度讲这也是显然的。\(k=i+j\),说明\(k\)将会从\(i+j\)的初值,也就是10开始,步长为\(i\)、\(j\)步长的和。
这样一来,循环就可以改写为:
int j = 10, k = 10;
for (int i = 0; i < n; i++) {
printf("%d\n", k);
k = k + 3;
j = j + 2;
}
这样你可能会说,这有什么意义呢?比如在MIPS汇编中,这个\(k=k+3\),和\(k=i+j\),都对应了一条加法指令,占用着一个时钟周期,看起来也没啥优化呀。但是万一情况变成下面这样,
int p = 1, q = 2, r = 3, s = 4;
for (int i = 0; i < n; i++) {
k = p + 2 * q + 3 * r + s;
printf("%d\n", k);
p = p + 2;
q = q + 3;
r = r + 1;
s = s + 5;
}
使用标量演化,则可知
则代码可以被优化为
int p = 1, q = 2, r = 3, s = 4, k = 18;
for (int i = 0; i < n; i++) {
printf("%d\n", k);
k = k + 16;
p = p + 2;
q = q + 3;
r = r + 1;
s = s + 5;
}
这下优化就明显了!原本计算\(k\)需要3条加法指令,2条乘法指令。而现在,只需要1条加法指令!这一点在矩阵运算中,有更明显的优势。比如计算矩阵加法时:
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
c[i][j] = a[i][j] + b[i][j];
}
}
编译器首先将数组运算转换为地址运算
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
p = i * n + j;
*(c + p) = *(a + p) + *(b + p);
}
}
对于内层循环,我们发现i和n是与这层循环的循环变量\(j\)无关的,我们称\(i\)、\(n\)为关于\(j\)的循环不变量(Loop invariant)。于是,考虑\(p\)的chrec:
于是,循环可以被优化为
for (int i = 0; i < n; i++) {
p = i * n;
for (int j = 0; j < n; j++) {
*(c + p) = *(a + p) + *(b + p);
p = p + 1;
}
}
下面对外层循环优化。此时,只有\(n\)为循环不变量。考虑\(p\)的chrec:
于是,将\(p\)进一步优化,
p = 0;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
*(c + p) = *(a + p) + *(b + p);
p = p + 1;
}
p = p + n;
}
我们分析一下,优化前,为了计算\(p\)地址,需要进行\(n^2\)次乘法运算,\(n^2\)次加法运算。而优化之后,只需要进行\(n^2+n\)次加法运算。并且我们知道,乘法运算的时钟周期数大于加法运算的周期数(乘法运算电路耗时较大)。可以看出,使用了标量演化优化之后,性能提升还是非常大的!
对于更加复杂的情况,又怎么办呢?比如这样的循环:
int j = 10;
for (int i = 0; i < n; i++) {
printf("%d\n", x);
x = i * j;
j = j + 2;
}
我们先从直观上理解。\(i\)和\(j\)的变化我们非常清楚。而\(x\)的变化(步长)怎么去体现?显然,
其中\(i_k\)表示循环到第\(k\)次时\(i\)的值,\(j_k\)表示循环到第\(k\)次时\(j\)的值。那么,\(x\)的步长应该是
可以看出,如果引入新的一个中间变量
那么\(x\)的步长可以表示为\(t\)。也就是
而对于\(t\),有
于是,可以记
在刚才,我们还有一个重要发现,也就是考虑
设\(x=\{\phi_0,+,\phi_1\}\),\(y=\{\delta_0,+,\delta_1\}\),则
故
也就是说,我们得到了公式
如果\(\delta_1\)和\(\phi_1\)为常数,则公式还可以表示为
我们将嵌套的大括号的层数成为标量演化公式的阶数。也就是说,乘积形式(高次)的循环链,可以被展开成高一阶的和形式的循环链。出于简化书写,将
简写为
类似地,
被简写为
此时,如果\(\delta_1\)和\(\phi_1\)为常数,刚才的乘积公式就变成
我们发现,这种结构是一种链式的结构。对于一个循环链,它的步长要么为一个常数,要么为另一个循环链。这就是“循环链”名字的含义(个人理解)。
深入
首先我们对刚才的研究做一些总结。我们定义几个基本公式:
循环链
其中\(\alpha\)必须要是一个常数,\(\phi\)是一个二元运算符,\(\phi\in\{+,*\}\)。\(\beta\)可能是一个循环链,也可能是一个常数。
线性
对于常数\(\alpha\)和\(\beta\),
乘积形式
有了这些定义,我们做一些更复杂的考虑。实际上,如果多项式次数更高,我们仍然可以用循环链来刻画。比如,
for (int i = 1; i <= n; i++) {
s = i * i * i;
printf("%d\n", s);
}
(注意这里讨论的是\(s=\),而不是前面那个+=)
此时,\(s=i^3\),考虑\(s\),
太优美了!于是,我们可以将其优化为
int t1 = 1, t2 = 7, t3 = 12;
int n = 5;
for (int i = 1; i <= n; i++) {
printf("%d\n", t1);
t1 += t2;
t2 += t3;
t3 += 6;
}
对比一下,之前需要\(2n\)条乘法指令,而现在,只需要\(3n\)条加法指令。而乘法指令比加法指令慢的远不止1.5倍。显然快多了。
s = 0;
for (int i = 1; i <= n; i++) {
s += i * i * i;
}
我们发现,
体现到代码上,
int t0 = 0, t1 = 1, t2 = 7, t3 = 12;
int n = 5;
for (int i = 1; i <= n; i++) {
t0 = t0 + t1;
t1 += t2;
t2 += t3;
t3 += 6;
}
最终的运算结果,将在\(t0\)当中保存。但是我们发现,这样虽然性能有很大提升,但还没有到之前连循环都没了的程度啊。这样实际上也是\(O(n)\)的,只不过复杂度当中的常数更小了。
我们需要观察这个求和。我们发现,这个求和的过程中,这个循环中间尽管\(s\)在不断变化,但是我们不需要追踪它的变化(也就是说我们不需要知道\(s\)的中间值),只需要知道循环结束之后\(s\)是多少!
现在再回来考虑我们的循环链
当我们不在意\(var\)的中间变化的时候,实际上我们完全可以用数学方法推出,该循环过程执行了\(k\)次之后,\(var\)是多少。而不需要借助于程序。比如最简单的
我们可以直接得到循环3次后,\(var=9\)。那么对于一般情况
为了方便从高中的数列递推角度考虑,我们设该循环链为\(n\)阶,并改写为
其中\(a_0\)为一个常数\(c\)。并且,用\(a_n^{(k)}\)表示循环变量\(a_n\)在循环\(k\)次之后的值。可以列出等式
全部写出来,
将上面这组算式求和得到
写到这里有的读者可能已经懵了,建议先展开算几项找找感觉再继续看。如果你写了几项,就会发现,最后的答案,一定是\(\alpha_1\),\(\alpha_2\),\(...\),\(\alpha_{n}\)的一个线性组合。
接下来就可以从每一个\(\alpha\)对答案的贡献考虑了。
对于\(\alpha_n\),只会在最开始被加1次。而\(1=C_i^0\)。
对于\(\alpha_{n-1}\),在\(a_{n-1}^{(0)}\)、\(a_{n-1}^{(1)}\)、\(\cdots\)、\(a_{n-1}^{(i-1)}\)中分别被加1次,故总共被加\(i=C_i^1\)次。
对于\(\alpha_{n-2}\),在\(a_{n-1}^{(1)}\)中被加1次,在\(a_{n-1}^{(2)}\)中被加2次,\(\cdots\),在\(a_{n-1}^{(i-1)}\)被加\(i-1\)次,故总共被加\(\frac{i(i-1)}{2}=C_i^2\)次。
数学直觉告诉我们,不妨猜想\(\alpha_{n-k}\)的贡献为\(C_i^k\alpha_{n-k}\)次。实际上事实就是这样。但是笔者还没有想出严谨的证明方式,只是直观理解。
于是,可知\(a_n^{(i)}=\sum_{j=0}^iC_i^j\alpha_{n-j}\)。
如果写成更加好看的形式,也就是设
则可知\(var\)在第\(i\)次循环中的值为
这是非常精妙的结论。于是,对于刚才的三次求和的\(s\),
可知\(s\)在\(n\)次循环后的值为
于是,编译器只要生成计算上面这个式子的代码即可。我们发现,这个式子和之前给出的三次方求和,我们人为反编译出来看到的式子是神似的(可能有些细节不同,但是是完全等价的)。可以说明编译器优化求和的本质原理就在这里。