多项式学习笔记(三): 多项式全家桶
1.多项式求逆
给你 \(A(x)\) 求 \(A(x)B(x) \equiv 1 \pmod {x^n}\) 。 (模 \(x^n\) 是为了把高次项舍掉)
假设我们已经得到了满足 \(C(x)A(x) \equiv 1 \pmod {x^{n\over 2}}\) 的一个多项式 \(C\) 。
那么由题意可得 \(A(x)B(x)\equiv 1 \pmod {x^{n\over 2}}\) 。
两式联立可得:
\(B(x) \equiv C(x) \pmod {x^{n\over 2}}\)
\(B(x) - C(x) \equiv 0 \pmod {x^{n\over 2}}\)
两边同时平方可得:
\(B^2(x) + C^2(x) - 2B(x)C(x) \equiv 0 \pmod {x^{n}}\)
在同时乘上一个 \(A(x)\) 得:
\(A(x)B^2(x) + A(x)C^2(x)-2A(x)B(x)C(x)\equiv 0 \pmod {x^{n}}\)
然后由题意可得 \(A(x)B(x)\equiv 1 \pmod {x^n}\) ,代入化简可得:
\(B(x) + A(x)C^2(x)-2C(x) \equiv 0 \pmod {x^n}\)
\(B(x) = 2C(x) - A(x)C^2(x)\)
然后,我们每次都可以把项数减半递归求解, 如果项数为 \(1\) 的话结果显然是零次项的逆元。
复杂度 \(T(n) = T({n\over 2}) + nlogn = nlogn\)
Code
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int N = 1e6+10;
const int p = 998244353;
int n,a[N],b[N],rev[N],c[N];
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
return s * w;
}
int ksm(int a,int b)
{
int res = 1;
for(; b; b >>= 1)
{
if(b & 1) res = res * a % p;
a = a * a % p;
}
return res;
}
void NTT(int *a,int len,int opt)
{
for(int i = 0; i < len; i++)
{
if(i < rev[i]) swap(a[i],a[rev[i]]);
}
for(int h = 1; h < len; h <<= 1)
{
int wn = ksm(3,(p-1)/(h<<1));
if(opt == -1) wn = ksm(wn,p-2);
for(int j = 0; j < len; j += (h<<1))
{
int w = 1;
for(int k = 0; k < h; k++)
{
int u = a[j + k];
int v = w * a[j + h + k] % p;
a[j + k] = (u + v) % p;
a[j + h + k] = (u - v + p) % p;
w = w * wn % p;
}
}
}
if(opt == -1)
{
int inv = ksm(len,p-2);
for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
}
}
void Inv(int n,int *a,int *b)//求 A(x)B(x) = 1 mod x^n
{
if(n == 1)//项数为1的情况
{
b[0] = ksm(a[0],p-2);
return;
}
Inv((n+1)>>1,a,b);//递归求 C(x)
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++)//预处理NTT的反转数组
{
rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
}
//注意,不能用 a 来做多项式乘法,因为如果拿 a 做了多项式乘法,那么 a 的值在递归过程中,就会发生改变。
for(int i = 0; i < n; i++) c[i] = a[i];//把 a 赋给 c,用 c 来做多项式乘法
for(int i = n; i < lim; i++) c[i] = 0;//多余的高次项舍去
//此时的 B 数组存的是 B(x)A(x) = 1 mod x^{n/2},C数组存的是 A(x)
NTT(c,lim,1); NTT(b,lim,1);//求 B 和 C 的点值表示法
for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;//计算 B的点值
NTT(b,lim,-1);//把B转化为系数表示法
for(int i = n; i < lim; i++) b[i] = 0;//高次项舍去
}
signed main()
{
n = read();
for(int i = 0; i < n; i++) a[i] = read();
Inv(n,a,b);
for(int i = 0; i < n; i++) printf("%lld ",b[i]);
printf("\n");
return 0;
}
2.多项式开根
求 \(B^2(x) \equiv A(x) \pmod {x^n}\)
假设,我们得到了满足 \(C^2(x) \equiv A(x) \pmod {x^{n\over 2}}\) 的一个多项式 \(C(x)\)。
又因为 \(B^2(x) \equiv A(x) \pmod {x^{n\over 2}}\) 。
两式联立可得:
\(B^2(x) \equiv C^2(x) \pmod {x^{n\over 2}}\)
\(B^2(x)-C^2(x) \equiv 0 \pmod {x^{n\over 2}}\)
两边同时平方可得:
\(B^4(x) + C^4(x) - 2B^2(x)C^2(x) \equiv 0 \pmod {x^n}\)
两边同时加上 \(4B^2(x)C^2(x)\) 可得:
\(B^4(x) + C^4(x) + 2B^2(x)C^2(x) \equiv 4B^2(x)C^2(x) \pmod {x^n}\)
\((B^2(x) + C^2(x))^2 \equiv 4B^2(x)C^2(x) \pmod {x^n}\)
把右边的 \(4C^2(x)\) 除过去可得:
\({(B^2(x) + C^2(x))^2 \over 4C^2(x)} \equiv B^2(x) \pmod {x^n}\)
\(B(x) \equiv {B^2(x) + C^2(x)\over 2C(x)} \pmod {x^n}\)
又因为 \(B^2(x) \equiv A(x) \pmod {x^n}\) ,代入可得:
\(B(x) \equiv {A(x) + C^2(x)\over 2C(x)} \pmod {x^n}\)
还是像求逆一样每次项数减半,递归求解,当项数为 \(1\) 的时候答案为 \(\sqrt {常数项}\) 。
多项式求逆加NTT即可。
复杂度 \(O(nlogn)\) 。
Code(常数爆炸):
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
#define int long long
const int N = 1e6+10;
const int p = 998244353;
int n,a[N],b[N],c[N],d[N],rev[N];
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
return s * w;
}
int ksm(int a,int b)
{
int res = 1;
for(; b; b >>= 1)
{
if(b & 1) res = res * a % p;
a = a * a % p;
}
return res;
}
void NTT(int *a,int len,int opt)//NTT 板子
{
for(int i = 0; i < len; i++)
{
if(i < rev[i]) swap(a[i],a[rev[i]]);
}
for(int h = 1; h < len; h <<= 1)
{
int wn = ksm(3,(p-1)/(h<<1));
if(opt == -1) wn = ksm(wn,p-2);
for(int j = 0; j < len; j += (h<<1))
{
int w = 1;
for(int k = 0; k < h; k++)
{
int u = a[j + k];
int v = w * a[j + h + k] % p;
a[j + k] = (u + v) % p;
a[j + h + k] = (u - v + p) % p;
w = w * wn % p;
}
}
}
if(opt == -1)
{
int inv = ksm(len,p-2);
for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
}
}
void Inv(int n,int *a,int *b)//多项式求逆板子
{
if(n == 1)
{
b[0] = ksm(a[0],p-2);
return;
}
Inv((n+1)>>1,a,b);
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
for(int i = 0; i < n; i++) c[i] = a[i];
for(int i = n; i < lim; i++) c[i] = 0;
NTT(c,lim,1); NTT(b,lim,1);
for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
NTT(b,lim,-1);
for(int i = n; i < lim; i++) b[i] = 0;//记得清空
}
void sqrt(int n,int *a,int *b)
{
if(n == 1)//项数为 1的情况
{
b[0] = (int) sqrt(a[0]);
return;
}
sqrt((n+1)>>1,a,b);
Inv(n,b,d);//这里求 mod x^n 下的逆元,而不是 mod x^lim 下的逆元
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
for(int i = 0; i < n; i++) c[i] = a[i];//用c数组代替a来做多项式乘法
for(int i = n; i < lim; i++) c[i] = 0;
//这里 b 数组存的是 C^2(x) = A(x) mod x^{n/2}
// c数组 存的是 A(x), d数组存的是 C(x) 的乘法逆
NTT(b,lim,1); NTT(c,lim,1); NTT(d,lim,1);
int inv2 = ksm(2,p-2);
for(int i = 0; i < lim; i++) b[i] = (b[i] * b[i] % p + c[i] % p) * d[i] % p * inv2 % p;//根据柿子算出 B(x) 的点值
NTT(b,lim,-1);//转换为系数表示法
for(int i = n; i < lim; i++) b[i] = 0;
for(int i = 0; i < lim; i++) d[i] = 0;//多次调用要清空
}
signed main()
{
n = read();
for(int i = 0; i < n; i++) a[i] = read();
sqrt(n,a,b);
for(int i = 0; i < n; i++) printf("%lld ",b[i]);
return 0;
}
3.多项式求导
若 \(A(x) = \displaystyle\sum_{i=0}^{n} a_ix^i\) , 则 \(A^\prime(x) = \displaystyle\sum_{i=0}^{n} ia_{i}x^{i-1}\)
void qiudao(int len,int *a,int *b)
{
for(int i = 1; i < len; i++) b[i-1] = i * a[i] % p;
b[len-1] = 0;
}
5.多项式积分
若 \(A(x) = \displaystyle\sum_{i=0}^{n}a_ix^i\) ,则 \(\int A(x) = \displaystyle\sum_{i=1}^{n} {a_i\over i+1} x^{i+1}\)
void jifen(int len,int *a,int *b)
{
for(int i = 1; i < len; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
b[0] = 0;
}
6.多项式 ln
求 \(B(x) \equiv lnA(x) \pmod {x^n}\)
设 \(F(x) = lnA(x)\) ,则 对等式两边同时求导可得:
\(B^\prime(x) \equiv F^\prime(x) \pmod {x^n}\)
根据复合函数求导公式 \(f^\prime(g(x)) = f^\prime(g(x)) g^\prime(x)\) 可得:
\(B^\prime(x) \equiv {A^\prime (x)\over A(x)} \pmod {x^n}\)
先求出 \(A(x)\) 的导函数和乘法逆,在相乘得到 \(B^\prime(x)\) ,最后在积分回去即可。
多项式求逆,多项式求导,多项式积分,多项式乘法。
复杂度 \(O(nlogn)\)
code
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 998244353;
const int N = 1e6+10;
int n,a[N],b[N],c[N],rev[N],A[N],B[N];
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
return s * w;
}
int ksm(int a,int b)
{
int res = 1;
for(; b; b >>= 1)
{
if(b & 1) res = res * a % p;
a = a * a % p;
}
return res;
}
void NTT(int *a,int len,int opt)
{
for(int i = 0; i < len; i++)
{
if(i < rev[i]) swap(a[i],a[rev[i]]);
}
for(int h = 1; h < len; h <<= 1)
{
int wn = ksm(3,(p-1)/(h<<1));
if(opt == -1) wn = ksm(wn,p-2);
for(int j = 0; j < len; j += (h<<1))
{
int w = 1;
for(int k = 0; k < h; k++)
{
int u = a[j + k];
int v = w * a[j + h + k] % p;
a[j + k] = (u + v) % p;
a[j + h + k] = (u - v + p) % p;
w = w * wn % p;
}
}
}
if(opt == -1)
{
int inv = ksm(len,p-2);
for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
}
}
void Inv(int n,int *a,int *b)
{
if(n == 1)
{
b[0] = ksm(a[0],p-2);
return;
}
Inv((n+1)>>1,a,b);
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
for(int i = 0; i < n; i++) c[i] = a[i];
for(int i = n; i < lim; i++) c[i] = 0;
NTT(b,lim,1); NTT(c,lim,1);
for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
NTT(b,lim,-1);
for(int i = n; i < lim; i++) b[i] = 0;
}
void qiudao(int len,int *a,int *b)
{
for(int i = 1; i < len; i++) b[i-1] = i * a[i] % p;
b[len-1] = 0;
}
void jifen(int len,int *a,int *b)
{
for(int i = 1; i < len; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
b[0] = 0;
}
void Ln(int n,int *a,int *b)
{
Inv(n,a,A); qiudao(n,a,B);//A 存的是 a的乘法逆,B存的是 a的导函数
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
NTT(A,lim,1); NTT(B,lim,1);
for(int i = 0; i < lim; i++) B[i] = B[i] * A[i] % p;
NTT(B,lim,-1); jifen(lim,B,b);//B存的是 b 的导函数
for(int i = n; i < lim; i++) b[i] = 0;
}
signed main()
{
n = read();
for(int i = 0; i < n; i++) a[i] = read();
Ln(n,a,b);
for(int i = 0; i < n; i++) printf("%lld ",b[i]);
return 0;
}
7.多项式除法
给你一个 \(n\) 次多项式 \(A(x)\) 和一个 \(m\) 次的多项式 \(B(x)\),求多项式 \(C(x)\) 和 \(D(x)\) 满足:
- \(C(x)\) 的次数为 \(n-m\), \(D(x)\) 的次数小于 \(m\)
- \(A(x) = C(x) * B(x) + D(x)\)
设 \(f(x)\) 是一个 \(n\) 次多项式,则定义 \(inv(f(x)) = x^nf({1\over x})\)
\(inv(f(x)) = x^n f({1\over x}) = x^n(a_0+a_1x^{-1}+...a_nx^{-n}) = a_{n} + a_{n-1}x^1 + a_{n-2}x^2+....a_{1}x^{n-1} + a_0x^{n}\)
所以 \(inv(f(x))\) 其实就是把 \(f(x)\) 的系数反转过来得到的结果。
\(\because A(x) = C(x) * B(x) + D(x)\)
所以有 \(inv(A(x)) = inv(C(x) * B(x) + D(x))\) 。
展开可得:
\(x^nA({1\over x}) = x^{n} (C({1\over x}) * B({1\over x}) + D({1\over x}))\)
\(x^nA({1\over x}) = x^mB({1\over x}) x^{n-m} C({1\over x}) + x^{n-m+1} x^{m-1} D({1\over x})\)
在转化为 \(inv(f(X))\) 可得:
\(inv(A(x)) = inv(B(x))inv(C(x)) + x^{n-m+1}inv(D(x))\)
两边同时模上 \(x^{n-m+1}\) 可得:
\(invA(x) \equiv inv(B(x))inv(C(x)) \pmod {x^{n-m+1}}\)
\(inv(C(x)) \equiv {inv(A(x))\over invB(x)} \pmod {x^{n-m+1}}\)
多项式乘法和多项式求逆可以求出来 \(inv(C(x))\), 在把系数反转得到 \(C(x)\).
最后把 \(C(x)\) 代入原式可得到 \(D(x)\).
复杂度 \(O(nlogn)\)
一定要注意清空数组(我这个沙比就因为这个卡在了50分好多次 )
Code:
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 998244353;
const int N = 1e6+10;
int n,m,rev[N],a[N],b[N],c[N],d[N],A[N],B[N],invB[N];
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
return s * w;
}
int ksm(int a,int b)
{
int res = 1;
for(; b; b >>= 1)
{
if(b & 1) res = res * a % p;
a = a * a % p;
}
return res;
}
void NTT(int *a,int len,int opt)
{
for(int i = 0; i < len; i++)
{
if(i < rev[i]) swap(a[i],a[rev[i]]);
}
for(int h = 1; h < len; h <<= 1)
{
int wn = ksm(3,(p-1)/(h<<1));
if(opt == -1) wn = ksm(wn,p-2);
for(int j = 0; j < len; j += (h<<1))
{
int w = 1;
for(int k = 0; k < h; k++)
{
int u = a[j + k];
int v = w * a[j + h + k] % p;
a[j + k] = (u + v) % p;
a[j + h + k] = (u - v + p) % p;
w = w * wn % p;
}
}
}
if(opt == -1)
{
int inv = ksm(len,p-2);
for(int i = 0; i < len; i++) a[i] = (a[i] * inv % p + p) % p;
}
}
void Inv(int n,int *a,int *b)
{
if(n == 1)
{
b[0] = ksm(a[0],p-2);
return;
}
Inv((n+1)>>1,a,b);
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
for(int i = 0; i < n; i++) c[i] = a[i];
for(int i = n; i < lim; i++) c[i] = 0;
NTT(c,lim,1); NTT(b,lim,1);
for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
NTT(b,lim,-1);
for(int i = n; i < lim; i++) b[i] = 0;
}
void mul(int n,int m,int *a,int *b)
{
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i <lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
NTT(a,lim,1); NTT(b,lim,1);
for(int i = 0; i < lim; i++) a[i] = a[i] * b[i] % p;
NTT(a,lim,-1);
for(int i = n; i < lim; i++) a[i] = 0;
}
void Chu(int n,int m,int *a,int *b)
{
for(int i = 0; i < n; i++) A[i] = a[n-i-1];//A 数组存的是 inv(A(x))
for(int i = 0; i < m; i++) B[i] = b[m-i-1];//B 数组存的是 inv(B(x))
Inv(n-m+1,B,invB);
for(int i = n-m+1; i < (n<<2); i++) A[i] = invB[i] = 0;
mul(n-m+1,n-m+1,A,invB);
for(int i = 0; i < n-m+1; i++) c[i] = (A[n-m-i] % p + p) % p;
for(int i = 0; i < n-m+1; i++) printf("%lld ",c[i]);
printf("\n");
for(int i = n-m+1; i < (n<<2); i++) c[i] = 0;
mul(n,n,c,b);
for(int i = 0; i < m-1; i++) d[i] = ((a[i] - c[i]) % p + p) % p;
for(int i = 0; i < m-1; i++) printf("%lld ",d[i]);
}
signed main()
{
n = read() + 1; m = read() + 1;
for(int i = 0; i < n; i++) a[i] = read();
for(int i = 0; i < m; i++) b[i] = read();
Chu(n,m,a,b);
return 0;
}
8.多项式 exp
求 \(B(x) \equiv e^{A(x)} \pmod {x^n}\)
设 \(C(x) \equiv e^{A(x)} \pmod {x^{n\over 2}}\) ,则 \(B(x) = C(x) (1-lnC(x) + A(x))\)
多项式求逆即可。
注意:每次求 \(exp\) 的时候,一定要把求 \(ln\) 所有用到的数组都清空掉。
code:
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 998244353;
const int N = 1e6+10;
int n,a[N],b[N],c[N],invB[N],invA[N],A[N],B[N],rev[N];
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
return s * w;
}
int ksm(int a,int b)
{
int res = 1;
for(; b; b >>= 1)
{
if(b & 1) res = res * a % p;
a = a * a % p;
}
return res;
}
void NTT(int *a,int len,int opt)
{
for(int i = 0; i < len; i++)
{
if(i < rev[i]) swap(a[i],a[rev[i]]);
}
for(int h = 1; h < len; h <<= 1)
{
int wn = ksm(3,(p-1)/(h<<1));
if(opt == -1) wn = ksm(wn,p-2);
for(int j = 0; j < len; j += (h<<1))
{
int w = 1;
for(int k = 0; k < h; k++)
{
int u = a[j + k];
int v = w * a[j + h + k] % p;
a[j + k] = (u + v) % p;
a[j + h + k] = (u - v + p) % p;
w = w * wn % p;
}
}
}
if(opt == -1)
{
int inv = ksm(len,p-2);
for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
}
}
void Inv(int n,int *a,int *b)
{
if(n == 1)
{
b[0] = ksm(a[0],p-2);
return;
}
Inv((n+1)>>1,a,b);
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
for(int i = 0; i < n; i++) c[i] = a[i];
for(int i = n; i < lim; i++) c[i] = 0;
NTT(c,lim,1); NTT(b,lim,1);
for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
NTT(b,lim,-1);
for(int i = n; i < lim; i++) b[i] = 0;
}
void qiudao(int n,int *a,int *b)
{
for(int i = 1; i < n; i++) b[i-1] = i * a[i] % p;
b[n-1] = 0;
}
void jifen(int n,int *a,int *b)
{
for(int i = 1; i < n; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
b[0] = 0;
}
void Ln(int n,int *a,int *b)
{
Inv(n,a,invA); qiudao(n,a,A);
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
NTT(invA,lim,1); NTT(A,lim,1);
for(int i = 0; i < lim; i++) B[i] = invA[i] * A[i] % p;
NTT(B,lim,-1); jifen(lim,B,b);
for(int i = n; i < lim; i++) b[i] = 0;
}
void Exp(int n,int *a,int *b)
{
if(n == 1)
{
b[0] = 1;
return;
}
Exp((n+1)>>1,a,b);
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
for(int i = 0; i < lim; i++) B[i] = A[i] = invA[i]= invB[i] = 0;
Ln(n,b,invB);
for(int i = 0; i < n; i++) c[i] = a[i];
for(int i = n; i < lim; i++) c[i] = 0;
NTT(c,lim,1); NTT(invB,lim,1); NTT(b,lim,1);
for(int i = 0; i < lim; i++) b[i] = b[i] * (1LL - invB[i] + c[i] + p) % p;
NTT(b,lim,-1);
for(int i = n; i < lim; i++) b[i] = 0;
}
signed main()
{
n = read();
for(int i = 0; i < n; i++) a[i] = read();
Exp(n,a,b);
for(int i = 0; i < n; i++) printf("%lld ",b[i]);
printf("\n");
return 0;
}
9.多项式快速幂
求 \(B(x) \equiv A^k(x) \pmod {x^n}\)
做法1: 倍增多项式乘法
和普通的快速幂一样,只不过在相乘的时候是把两个多项式乘起来。
常数比较大。
做法2: 多项式求 ln 多项式exp
等式两边同时取对数可得:
\(ln B(x) \equiv klnA(x) \pmod {x^n}\)
在同时取指数可得:
\(B(x) \equiv e^{klnA(x)} \pmod {x^n}\)
多项式求逆求出 \(lnA(x)\) ,把每一项乘个 \(k\), 最后在 exp回去即可。
复杂度 \(O(nlogn)\)
注意:每次求 exp 的时候,一定要把 求 \(ln\) 用到的数组清空。
Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define int long long
const int p = 998244353;
const int N = 1e6+10;
int n,k,a[N],b[N],c[N],rev[N],A[N],B[N],invA[N],invB[N],F[N];
char s[N];
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
return s * w;
}
int ksm(int a,int b)
{
int res = 1;
for(; b; b >>= 1)
{
if(b & 1) res = res * a % p;
a = a * a % p;
}
return res;
}
void NTT(int *a,int len,int opt)
{
for(int i = 0; i < len; i++)
{
if(i < rev[i]) swap(a[i],a[rev[i]]);
}
for(int h = 1; h < len; h <<= 1)
{
int wn = ksm(3,(p-1)/(h<<1));
if(opt == -1) wn = ksm(wn,p-2);
for(int j = 0; j < len; j += (h<<1))
{
int w = 1;
for(int k = 0; k < h; k++)
{
int u = a[j + k];
int v = w * a[j + h + k] % p;
a[j + k] = (u + v) % p;
a[j + h + k] = (u - v + p) % p;
w = w * wn % p;
}
}
}
if(opt == -1)
{
int inv = ksm(len,p-2);
for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
}
}
void Inv(int n,int *a,int *b)
{
if(n == 1)
{
b[0] = ksm(a[0],p-2);
return;
}
Inv((n+1)>>1,a,b);
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
for(int i = 0; i < n; i++) c[i] = a[i];
for(int i = n; i < lim; i++) c[i] = 0;
NTT(c,lim,1); NTT(b,lim,1);
for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
NTT(b,lim,-1);
for(int i = n; i < lim; i++) b[i] = 0;
}
void qiudao(int n,int *a,int *b)
{
for(int i = 1; i < n; i++) b[i-1] = i * a[i] % p;
b[n-1] = 0;
}
void jifen(int n,int *a,int *b)
{
for(int i = 1; i < n; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
b[0] = 0;
}
void Ln(int n,int *a,int *b)
{
Inv(n,a,invA); qiudao(n,a,A);
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
NTT(invA,lim,1); NTT(A,lim,1);
for(int i = 0; i < lim; i++) B[i] = A[i] * invA[i] % p;
NTT(B,lim,-1); jifen(n,B,b);
for(int i = n; i < lim; i++) b[i] = 0;
}
void Exp(int n,int *a,int *b)
{
if(n == 1)
{
b[0] = 1;
return;
}
Exp((n+1)>>1,a,b);
int lim = 1, tim = 0;
while(lim < (n<<1)) lim <<= 1, tim++;
for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
for(int i = 0; i < lim; i++) B[i] = A[i] = invA[i] = invB[i] = 0;//这里一定要清空ln所有用到的所有的数组,我在exp板子的时候只清空了两个数组既然还过了,就nm离谱
Ln(n,b,invB);
for(int i = 0; i < n; i++) c[i] = a[i];
for(int i = n; i < lim; i++) c[i] = 0;
NTT(b,lim,1); NTT(invB,lim,1); NTT(c,lim,1);
for(int i = 0; i < lim; i++) b[i] = b[i] * (1LL - invB[i] + c[i] + p) % p;
NTT(b,lim,-1);
for(int i = n; i < lim; i++) b[i] = 0;
}
void kuaisumi(int n,int k,int *a)
{
Ln(n,a,F);//F 存的是 ln(A(x))
for(int i = 0; i < n; i++) F[i] = F[i] * k % p;
Exp(n,F,b);//exp回去
}
signed main()
{
n = read(); scanf("%s",s+1);
for(int i = 1; i <= (int) strlen(s+1); i++) k = (k * 10 + s[i] - '0') % p;
for(int i = 0; i < n; i++) a[i] = read();
kuaisumi(n,k,a);
for(int i = 0; i < n; i++) printf("%lld ",b[i]);
printf("\n");
return 0;
}