多项式的高级运算
基础:FFT与NTT
参考博客:多项式和生成函数
分治FFT
(其实我觉得对于我这个几乎只写NTT的人来说,叫分治NTT比较好)
简单说一下,分治FFT用到了CDQ分治的思想,但不用非得学CDQ分治,毕竟这个思想还是比较好理解的,之前也经常用到。简而言之,这里的CDQ分治思想就是:每次只考虑左半段对右半段的贡献,先递归解决左半段,然后让右半段加上左半段的贡献,再递归解决右半段。这样,一次次贡献的加和就组成了每个位置的值。
将题意稍稍转化,f[i] = g[0 ~ i]与f[0 ~ i]的卷积(多项式乘法)。然后我们可以手玩,发现 \(i\) 的贡献可以拆成多个“线段树”上节点的贡献。
注意\(g\) 的取值是从 \(0\) 开始的。
注意,分治 FFT 要求必须有一个多项式是已知的(比如模板中的 \(g\))
至于到底是项数还是度数,其实多取值一两个应该是问题不大的,就都认为是项数吧。
最关键的几条语句:
int len = R - L + 1;//项数
for (register int i = 0; i < len; ++i) A[i] = g[i];
for (register int j = 0; j <= mid - L; ++j) B[j] = f[L + j];
len += (mid - L + 1);
\(Code:\) my code
(主要篇幅是NTT,其实重点就在于以上两条语句,NTT只是工具)
注意:必须对要求的函数(这里是 \(f\))进行 CDQ 分治,即将 \(f_i\) 的贡献拆为若干段 \(f_j\) 的贡献。否则 \(B\) 那部分可能还没有求出来。
拓展:两个函数相互卷
形如 \(f = \sum f \cdot g + ...\),\(g =\sum f \cdot g + ...\),一种方法是列生成函数暴力解,局限性较大(可能不好解)。一种方法是一块做,分别 CDQ。按 \(L\) 分类,\(L=0\) 的时候算 \(0...mid\) 的 \(f,g\) 卷的贡献,当 \(L > 0\) 的时候 \(f(L...mid) \times g(0...len)\) 和 \(g(L...mid) \times f(0...len)\) 都要算,因为都还没算过贡献呢。
例题:普通的计数题
//f=g*h,g=f*g
void sol(int L, int R) {
if (L == R) ...
sol(L, mid);
for (int i = L; i <= mid; ++i) A[i - L] = g[i];
for (int i = 1; i <= len; ++i) B[i - 1] = h[i];
...(->f)
if (L == 0) {
for (int i = 0; i <= mid; ++i) A[i] = f[i];
for (int i = 1; i <= mid; ++i) B[i - 1] = g[i];
...(->g)
} else {
for (int i = L; i <= mid; ++i) A[i - L] = f[i];
for (int i = 1; i <= len; ++i) B[i - 1] = g[i];
...(->g)
for (int i = L; i <= mid; ++i) A[i - L] = g[i];
for (int i = 1; i <= len; ++i) B[i - 1] = f[i];
...(->g)
}
sol(mid + 1, R);
}
任意模数NTT(MTT)
有的时候质数不能表示为 \(a * 2^k + 1\) 的形式,无法进行NTT。如果质数比较小(比如 \(10^4\) 左右),就可以直接用 FFT 搞过去;但是如果质数到达 \(10^9\) 左右,就需要一些技巧了。
我们可以把所有系数化为 \(ax+b\) 的形式,其中 \(x\) 为我们指定的一个数(通常是 \(2^{15}\)),这样的话 \(a, b\) 就都不超过 \(2^{15}\) 了,可以暴力求解。答案为 \(aa'x ^2+ (ab'+ba')x + bb'\),根据需要进行乘法即可。
一种减小常数的方法是:我们设复多项式 \(P = (a, a'),Q=(b,b'),R=(a,-a')\),然后设 \(F=PQ=(ab-a'b',ab'+a'b),F'=RQ=(ab+a'b',ab'-a'b)\),再加加减减:\(F+F'=(2ab,2ab'),F-F'=(-2a'b',2a'b)\),这样,我们需要的东西就都有了,并且只用做 5 次FFT。
关键代码:
for (register int i = 0; i <= n; ++i) {
int a; read(a);
P[i].x = a >> 15;
P[i].y = a & mask;
R[i].x = P[i].x, R[i].y = -P[i].y;
}
for (register int i = 0; i <= m; ++i) {
int a; read(a);
Q[i].x = a >> 15;
Q[i].y = a & mask;
}
fft(P, 1), fft(Q, 1), fft(R, 1);
for (register int i = 0; i < limi; ++i) {
Q[i].x /= limi, Q[i].y /= limi;//提前除以 limi,fft就不用除了
P[i] = P[i] * Q[i], R[i] = R[i] * Q[i];
}
fft(P, -1), fft(R, -1);
for (register int i = 0; i <= n + m; ++i) {
ll a1b1 = (ll)((P[i].x + R[i].x) / 2.0 + 0.5);
ll a1b2 = (ll)((P[i].y + R[i].y) / 2.0 + 0.5);
ll a2b1 = (ll)((P[i].y - R[i].y) / 2.0 + 0.5);
ll a2b2 = (ll)((R[i].x - P[i].x) / 2.0 + 0.5);
ll ans = a1b1 % Mod * ((1ll << 30) % Mod) % Mod;
ans += (a2b1 + a1b2) % Mod * ((1ll << 15) % Mod) % Mod;
ans += a2b2 % Mod;
ans = (ans % Mod + Mod) % Mod;
printf("%lld ", ans);
}
puts("");
下降幂多项式乘法
下降幂多项式:
其中下降幂多项式的系数的 OGF(没有特殊说明均为普通的多项式)和下降幂多项式的点值的 EGF 可以较为轻松地互相转化。设 \(f_i = \sum_{j} a_j i^{\underline{j}}\),则:
于是将系数多项式卷一个 \(e^x\) 即可得到点值 EGF,将点值 EGF 卷一个 \(e^{-x}\) 即可得到系数多项式。
其中 \(e^{-x} = \sum_i \frac{(-1)^i}{i!} x^i\)。
而点值是可以 \(O(n)\) 进行乘法的,于是可以把下降幂多项式转为点值多项式相乘(记得消掉多余的 \(\frac{1}{i!}\) )后在转化回来。复杂度是 \(O(n \log n)\)。
//C = A * B
for (int i = 0; i < n + m; ++i)
E[i] = jieni[i], E_[i] = (i & 1 ? P - 1 : 1) * jieni[i] % P;
Mul(A, E, Af, n + m);
Mul(B, E, Bf, n + m);
for (int i = 0; i < n + m; ++i) Af[i] = Af[i] * Bf[i] % P * jie[i] % P;//bug
Mul(Af, E_, C, n + m);
下面开始正式的多项式工业
多项式求乘法逆元
牛顿迭代法:
\(Code:\)
int n;
ll A[N], B[N], C[N], r[N];
ll limi, l;
inline ll quickpow(ll x, ll k)...
inline void ntt(ll *a, int type) {...//此处已经让type = 1的乘inv了
void sol(int deg, ll *a, ll *b) {//b is 逆元
if (deg == 1) {
b[0] = quickpow(a[0], P - 2);
return ;
}
sol((deg + 1) >> 1, a, b);
limi = 1, l = 0;
while (limi <= (deg << 1)) limi <<= 1, ++l;
for (register int i = 1; i <= limi; ++i)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
for (register int i = 0; i < deg; ++i) C[i] = a[i];//转移到C防变化
for (register int i = deg; i < limi; ++i) C[i] = 0;//多次清空更保险
ntt(b, 1); ntt(C, 1);
for (register int i = 0; i < limi; ++i)//B = 2B' - AB'B' = B'(2 - AB')
b[i] = ((2ll - b[i] * C[i]) % P + P) % P * b[i] % P;
ntt(b, -1);
for (register int i = deg; i <= limi; ++i)//多次清空更保险
b[i] = 0;
}
int main() {
read(n);
for (register int i = 0; i < n; ++i) read(A[i]);
sol(n, A, B);
for (register int i = 0; i < n; ++i)
printf("%lld ", B[i]);
return 0;
}
这里给一个简化的板子,供复习:
7.28 Update : 拿了个更简化的倍增版本
inline void get_inv(ll *a, ll *b, int d) {//d : 项数
b[0] = quickpow(a[0], P - 2);
L = 0;
for (register int len = 1; len < (d << 1); len <<= 1) {//一定求完长度为 len/2 的 b 数组,想求长度为 len 的 b 数组
limi = len << 1, ++L;
for (register int i = 1; i < limi; ++i)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
for (register int i = 0; i < len; ++i) C[i] = a[i], D[i] = b[i];
ntt(C, 1), ntt(D, 1);
for (register int i = 0; i < limi; ++i)
b[i] = D[i] * (2 - C[i] * D[i] % P) % P, C[i] = D[i] = 0;
ntt(b, -1);
for (register int i = len; i < limi; ++i) b[i] = 0;
}
}
多项式开根号
题意
- 给A(x),其中a0 = 1,求B(x),使得B(x)^2 = A(x) (mod x^n)
思路简析
与多项式求逆相同,由于一平方就mod x^m -> mod x^(2m)
,我们考虑递归求解。
表达式
同样假设我们已经求出B(x)的一半b(x),那么:
A * b = 1(mod x^m)
A * B = 1(mod x^m)
∴B - b = 0(mod x^m)
两边平方:
B^2 - 2 * B * b + b^2 = 0(mod x^(2*m))
据B^2 = A(mod x^(2*m)):
A - 2 * B * b + b^2 = 0(mod x^(2*m))
于是:
B = (A/b + b) / 2
配合多项式求逆解出B(x)。
(7.28 Update:)
这里给一个更方便,无需更多的聪明与技巧的方法:牛顿迭代:
我们可以一遍 NTT 算出 \(A, B, \frac{1}{B}\) 的点值表示,然后在运算并 NTT 搞回去,但是这需要 NTT 三次,因为第一个 \(B\) 与分式是加法关系,直接系数相加即可,这样就只用 NTT 两次。
递归边界
模板题非常友善,告诉我们 a0 = 1,于是递归边界为
if (deg == 1) {b[0] = 1; return ;}
如果题目没有那么友善,那么我们或许可以多random几个数 我们需要用二次剩余之类的麻烦的东西,或者考虑换一种算法。
Code:
void get_sqrt(ll *a, ll *b, int deg) {
if (deg == 1) {b[0] = 1; return ;}
get_sqrt(a, b, (deg + 1) >> 1);
//get_len
limi = 1, len = 0;
while (limi <= (deg << 1)) limi <<= 1, len++;
for (register int i = 0; i <= limi; ++i)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (len - 1));
//copy and multiply
for (register int i = 0; i <= limi; ++i) bn[i] = 0;//attention
get_inv(b, bn, deg);
for (register int i = 0; i < deg; ++i) C[i] = a[i];
for (register int i = deg; i <= limi; ++i) C[i] = 0;
ntt(C, 1); ntt(bn, 1);
for (register int i = 0; i <= limi; ++i)
C[i] = C[i] * bn[i] % P;
ntt(C, -1);
for (register int i = 0; i < deg; ++i) b[i] = (C[i] + b[i]) * inv2 % P;
for (register int i = deg; i <= limi; ++i) b[i] = 0;
}
(7.28 Update)
倍增版本:(附新求逆模板)(我怕两个板子之间的 \(limi\) 互相串出错,因此每个函数单独写了个 \(limi\) 和 \(L\)。NTT 的时候直接传入 \(limi\))
inline void get_inv(ll *a, ll *b, int d) {
b[0] = quickpow(a[0], P - 2);
int L = 0, limi = 1;
for (register int len = 1; len < (d << 1); len <<= 1) {
limi = (len << 1), ++L;
for (register int i = 1; i < limi; ++i)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
for (register int i = 0; i < len; ++i) E[i] = a[i], F[i] = b[i];
ntt(E, 1, limi), ntt(F, 1, limi);
for (register int i = 0; i < limi; ++i)
b[i] = F[i] * (2 - E[i] * F[i] % P) % P, E[i] = F[i] = 0;
ntt(b, -1, limi);
for (register int i = len; i < limi; ++i) b[i] = 0;
}
}
inline void Sqrt(ll *a, ll *b, int d) {
b[0] = 1;
int limi = 1, L = 0;
for (register int len = 1; len < (d << 1); len <<= 1) {
get_inv(b, C, len);
limi = len << 1, ++L;
for (register int i = 1; i < limi; ++i)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
for (register int i = 0; i < len; ++i) D[i] = a[i];
ntt(D, 1, limi); ntt(C, 1, limi);
for (register int i = 0; i < limi; ++i) C[i] = D[i] * C[i] % P;
ntt(C, -1, limi);
for (register int i = 0; i < len; ++i)
b[i] = INV2 * (b[i] + C[i]) % P, C[i] = D[i] = 0;
for (register int i = len; i < limi; ++i) b[i] = 0;
}
}
注意:
- 用数组前一定注意清空。我也不知道为什么,反正不清空就会出错。估计是NTT的祸吧。
多项式除法
-
\(A(x) * B(x) + C(x) = D(x)\),给出 \(D(x), A(x)\),求 \(B(x),C(x)\).(类似高精除)
-
\(m = deg(A) < 100,000, n = deg(D) <= 100,000, m <= n\)
思路简析
我们发现神奇的事情:把A,B,C,D都翻转过来,(把C加到D后面),等式仍然成立。并且还有一个好处:C转过来后D的0~n-m项都不受C的影响,而B又肯定超不过n-m+1项。因此我们可以借助反转后的A,D数组算出B数组,然后什么都好搞了。
2020.12.23 Update:
还是说点人话吧
定义翻转操作为如下操作:
其中 \(n\) 为 \(f(x)\) 的次数。显然次操作可以 \(O(n)\) 做完。
现在回到题目:
他们的项数(本题中)为:\(m\) 项,\(n-m+1\) 项,\(m-1\) 项,\(n\) 项。(此处的 \(n,m\) 为原题面的 \(n+1,m+1\))于是,现在尝试对它们进行翻转,应该有如下等式:
而我们恰好要的是 \(B^R\) 的前 \(n-m+1\) 项。于是直接在 \(\bmod x^{n-m+1}\) 的意义下多项式求逆整出 \(B^R(x)\),然后据此推出答案即可。
(注意区分好以上文字中的“项数”和“次数”)
Code:
int main() {
read(n); read(m); n++; m++;//n,m变成项数
for (register int i = 0; i < n; ++i) read(D[i]), Dbp[i] = D[i];
for (register int i = 0; i < m; ++i) read(A[i]), Abp[i] = A[i];//backup
Reverse(A, m);
Reverse(D, n);
for (register int i = n - m + 1; i < n; ++i)
A[i] = D[i] = 0;
get_inv(A, An, n - m + 1);
mul(D, An, B, n - m + 1, n - m + 1);
//D(n-m+1项) * An(n-m+1项) -> B
Reverse(B, n - m + 1);
mul(Abp, B, AB, m, n - m + 1);
for (register int i = 0; i < m - 1; ++i)
C[i] = (Dbp[i] - AB[i] + P) % P;
...
}
注意
-
在算D*inv(A)时,保险起见,只保留D和A的0~n-m项,且对A,D做备份,算C时用。
-
什么时候用n-m,什么时候用n-m+1,要分清楚!
-
此时多项式变量名逐渐增多,注意区分,不要把An写成A!
多项式求ln
在学习微积分后,我再学ln,感觉舒适了很多。
求ln很简单,两边求个导,用一下多项式求逆,再积分即可。
7.28 Update:
给一下推导:
思维难度低,代码量大。
好想好写,算是比较小清新的板子了。
(更新封装风格的代码)
inline void dao(ll *a, ll *b, int d) {
for (register int i = 0; i < d; ++i) b[i] = a[i + 1] * (i + 1) % P;
}
inline void ji(ll *a, ll *b, int d) {//利用线性推逆元,积分可以做到 O(n)
for (register int i = 1; i < d; ++i)//注意:如果想把b和a写成一个数组的话,要注意转移顺序
b[i] = a[i - 1] * quickpow(i, P - 2) % P;
}
inline void get_ln(ll *a, ll *b, int d) {
get_inv(a, C, d);
dao(a, D, d);
int limi = 1, L = 0;
while (limi < (d << 1)) limi <<= 1, ++L;
for (register int i = 1; i < limi; ++i)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
ntt(C, 1, limi), ntt(D, 1, limi);
for (register int i = 0; i < limi; ++i) C[i] = C[i] * D[i] % P;
ntt(C, -1, limi);
for (register int i = d; i < limi; ++i) C[i] = 0;
ji(C, b, d);
}
多项式求exp
继续牛顿迭代:
然后就终极套娃即可。
\(1 - ln(B(x)) + A(x)\) 可以用系数表示法直接运算,但是记住这是系数表示法,不是点值表示,不是每一项都 + 1,而是只有常数项 +1
多项式快速幂
自然可以倍增快速幂,不过那是 \(n~logn~logk\) 的。既然多项式可以快速取 \(ln,exp\),我们可以利用这一点来做到 \(n~logn\)
套 ln 和 exp 即可。
下面给一个 快速幂 + exp + ln + 求逆求导积分 的代码,这份代码实现的是求:
的系数。注意到常数项不一定是 \(1\),可以把常数项提出来最后乘,剩下的部分常数项就是一了。
(之前代码可能会出现爆掉四倍数组的情况,这份代码应该是没问题的)
namespace jzpac {
const int G = 3;
const int Gi = (P + 1) / G;
int r[N], lst;
inline void ntt(ll *a, int limi, int type) {
if (lst != limi) {
lst = limi;
for (int i = 0; i < limi; ++i)
r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (limi >> 1) : 0);
}
for (int i = 0; i < limi; ++i)
if (i < r[i]) swap(a[i], a[r[i]]);
for (int i = 1; i < limi; i <<= 1) {
ll T = quickpow(type == 1 ? G : Gi, (P - 1) / (i << 1));
for (int j = 0; j < limi; j += (i << 1)) {
ll t = 1;
for (int k = 0; k < i; ++k, t = t * T % P) {
ll nx = a[j + k], ny = a[i + j + k] * t % P;
a[j + k] = Mod(nx + ny);
a[i + j + k] = Mod(nx - ny + P);
}
}
}
if (type == -1) {
const ll inv = quickpow(limi, P - 2);
for (int i = 0; i < limi; ++i) a[i] = a[i] * inv % P;
}
}
ll inv_A[N], inv_B[N];
inline void get_inv(ll *a, ll *b, int d) {
b[0] = quickpow(a[0], P - 2);
for (int len = 1; (len >> 1) < d; len <<= 1) {
int limi = len << 1;
for (int i = 0; i < len; ++i) inv_A[i] = a[i], inv_B[i] = b[i];
ntt(inv_A, limi, 1), ntt(inv_B, limi, 1);
for (int i = 0; i < limi; ++i) inv_A[i] = inv_B[i] * (2ll - inv_A[i] * inv_B[i] % P + P) % P;
ntt(inv_A, limi, -1);
for (int i = 0; i < len; ++i) b[i] = inv_A[i];
for (int i = 0; i < limi; ++i) inv_A[i] = inv_B[i] = 0;
}
}
inline void get_dao(ll *a, int d) {
a[d] = 0;
for (int i = 0; i < d; ++i) a[i] = a[i + 1] * (i + 1) % P;
}
inline ll Inv(ll x) { return jie[x - 1] * jieni[x] % P; }
inline void get_ji(ll *a, int d) {
for (int i = d - 1; ~i; --i) a[i + 1] = a[i] * Inv(i + 1) % P;
a[0] = 0;
}
ll ln_A[N], ln_B[N];
inline void get_ln(ll *a, ll *b, int d) {//a[0] = 1
for (int i = 0; i < d; ++i) ln_A[i] = a[i];
get_inv(ln_A, ln_B, d);//bug
get_dao(ln_A, d);
int limi = 1; while (limi < d + d) limi <<= 1;
ntt(ln_A, limi, 1), ntt(ln_B, limi, 1);
for (int i = 0; i < limi; ++i) ln_A[i] = ln_A[i] * ln_B[i] % P;
ntt(ln_A, limi, -1);
get_ji(ln_A, d);
for (int i = 0; i < d; ++i) b[i] = ln_A[i];
for (int i = 0; i < limi; ++i) ln_A[i] = ln_B[i] = 0;
}
ll exp_A[N], exp_B[N];
inline void get_exp(ll *a, ll *b, int d) {//a[0] = 0
b[0] = 1;
for (int len = 1; (len >> 1) < d; len <<= 1) {
int limi = len << 1;
for (int i = 0; i < len; ++i) exp_A[i] = b[i];
get_ln(exp_A, exp_B, len);
exp_A[0] = (a[0] + 1 - exp_B[0] + P) % P;
for (int i = 1; i < len; ++i) exp_A[i] = (a[i] - exp_B[i] + P) % P;//bug
for (int i = 0; i < len; ++i) exp_B[i] = b[i];
ntt(exp_A, limi, 1), ntt(exp_B, limi, 1);
for (int i = 0; i < limi; ++i) exp_A[i] = exp_A[i] * exp_B[i] % P;
ntt(exp_A, limi, -1);
for (int i = 0; i < len; ++i) b[i] = exp_A[i];//bug!
for (int i = 0; i < limi; ++i) inv_A[i] = inv_B[i] = 0;
}
}
ll f[N], g[N];
inline void sol() {
for (int i = 0; i < m; ++i) f[i] = jieni[i];
++f[0];
for (int i = 0; i < m; ++i) f[i] = f[i] * inv2 % P;
get_ln(f, g, m);
for (int i = 0; i < m; ++i) g[i] = g[i] * n % P;
memset(f, 0, sizeof(f));
get_exp(g, f, m);
ll mi = quickpow(2, n);
for (int i = 0; i < m; ++i) f[i] = f[i] * mi % P;
}
}
暴力多项式算法
没有 FFT 的多项式工业。
通常是对某个关系式观察第 \(m\) 项得到递推式子。
求逆
直接展开第 \(m\) 项:
得到递推式。
\(O(n^2)\)
ln/exp
观察第 \(m\) 项:
\(O(n^2)\)
快速幂
观察第 \(m\) 项:
\(O(n^2)\)
当 \(B\) 的项数为 \(k\),结果模 \(x^m\) 的时候,复杂度可以做到 \(O(km)\)。