多项式总结
Part1:FFT(fast fast tle)
前置知识:复数,单位根,多项式的系数表达法,多项式的点值表达法
- 复数:
可以表示为\((a+bi)\),可以看做原点到\((a,b)\)一个向量,其中\(i=\sqrt{-1}\)。
复数可以进行加,减,乘(向量的除法有点问题),其中
即:
同时复数的乘还有和向量一样的几何意义:模长相乘,幅角相加
- 单位根:
在OI中,经常用到2的正整数次幂相关的数,因为这样方便处理,为方便,我们规定下文的\(n\)为2的正整数次幂。
定义:如果\(w_n^n=1\)那么\(w_n\)为\(n\)次单位根
因为\(w_n^n=1\),根据复数乘的几何意义,可知模长为1,幅角为\(\frac{2\pi}{n}\),易得单位根
然后\(w^k_n\)的幅角为\(\frac{2k\pi}{n}\),所以
单位根这里还需要两个性质:
性质一:
即
性质二:
即
- 多项式的系数表达法:
就是平时的表达方法,用\(n+1\)个系数表示一个\(n\)次多项式,比如:
该方法易读,也易求值,但很难快速求卷积。
- 多项式的点值表达法:
就是用\(n+1\)个点来表示一个\(n\)次多项式,比如:
该方法不易理解,但很容易求卷积。
如果两个多项式\(f,g\)满足\(fx_0==gx_0,fx_1==gx_1,...,fx_n==gx_n\),则新多项式
正题:快速傅里叶变换(FFT)以及快速傅里叶逆变换(IFFT)
从上面两种多项式的表达方式中,我们可以发现如果能快速的把多项式在系数与点值中转换,就可以快速的获取两个多项式的卷积。
- 1、系数多项式转点值多项式(快速傅里叶变换)
给出多项式
我们需要快速求出\(f(1),f(w_n),...,f(w^{n-1}_{n-1})\)
先将\(f\)按奇偶分类分为
我们设
那么有
带入\(x=w^k_n\),
带入\(x=w^{k+\frac{n}{2}}_n\)
根据性质二,有
可以发现\(f(w^k_n)\)和\(f(w^{k+\frac{n}{2}}_n)\)两项只差第二项的系数,所以我们可以只用一半的时间就处理整个多项式。
可以发现,如果我们递归的处理,就可以用\(O(n\log n)\)的复杂度实现FFT。
- 2、点值多项式转系数多项式(快速傅里叶逆变换)
我们可以把式子列成一个矩阵(没学过矩阵可以先学再看或跳过推导):
其中\((a_0,a_1,a_2,...,a_{n-1})\)为系数表达法,\((y_0,y_1,y_2,...,y_{n-1})\)为省略\(x\)的点值表达法。
如果我们能快速求出左边这个矩阵的逆矩阵,我们就能快速转换。
考虑矩阵求逆(\(O(n^3)\)完全负优化)
但我们可以发现原矩阵中所有数之间是有关联的,我们可以考虑转换。
设\(V\)为原矩阵,\(G\)为逆矩阵,考虑最终矩阵\(E\)在\((i,j)\)上的值:
因为\(V\)和\(G\)互逆,所以\(E\)是单位矩阵,只有当\(i=j\)时才会有值1,否则为0。
我们先证明一个引理:当\(k\)不是\(n\)的倍数时
由等比数列求和得
因为\(k\)不是\(n\)的倍数,所以\(w_n^k\not=1\),即分母不为0,所以该引理成立。
根据这个引理,可以发现矩阵\(G\)有一个比较简单的构造方式,即\(G(i,j)=w_n^{-ij}\)
这时
当\(i-j\)不为\(n\)的倍数(不为0时),\(E(i,j)=1\),但当\(i=j\)时,已知\(E(i,j)=n\),跟单位矩阵有点偏差,我们在前面加一个\(\frac{1}{n}\)。
好吧,这个推导其实有些牵强,只用把他当做结论记就可以了。
这样,我们就有:
这样我们就可以用类似系数转点值的方法转换了,只是这边的单位根要取反,其实在使用起来时就是
非常简单,只用在FFT的基础上略作修改即可。
最后因为直接乘出来的答案是真实值的\(n\)倍,所以要除以\(n\)。
然后我们就可以写出最基本的递归版FFT(虽然C++有自带complex类型,但用起来会比较慢,建议手写一个):
#include<bits/stdc++.h>
using namespace std;
const int N=1000010;
const double pi=acos(-1);
int n,m,lg[N<<1];
struct Complex{
double x,y;
Complex(double x=0,double y=0):x(x),y(y){}
};
Complex operator+(Complex a,Complex b){return Complex(a.x+b.x,a.y+b.y);}
Complex operator-(Complex a,Complex b){return Complex(a.x-b.x,a.y-b.y);}
Complex operator*(Complex a,Complex b){return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
Complex c[N<<2],b[N<<2],t[N<<2],s[50];
void fft(Complex *f,int l,int len,int op){
if(!len)return;
for(int i=l;i<l+(len<<1);++i)t[i]=f[i];
for(int i=l;i<l+len;++i)f[i]=t[l+((i-l)<<1)],f[i+len]=t[l+((i-l)<<1|1)];
fft(f,l,len>>1,op);
fft(f,l+len,len>>1,op);
Complex tmp=s[lg[len]],buf=Complex(1,0),d;
tmp.y*=op;
for(int i=l;i<l+len;++i){
d=buf*f[i+len];
t[i]=f[i]+d;
t[i+len]=f[i]-d;
buf=buf*tmp;
}
for(int i=l;i<l+(len<<1);++i)f[i]=t[i];
}
int main(){
int n,m;
cin>>n>>m;
for(int i=2;i<=n+m;++i)lg[i]=lg[i>>1]+1;
for(int i=0;i<=n;++i)scanf("%lf",&b[i].x);
for(int i=0;i<=m;++i)scanf("%lf",&c[i].x);
for(m+=n,n=1;n<=m;n<<=1);
for(int i=1,j=0;i<=n;i<<=1,++j)s[j]=Complex(cos(pi/i),sin(pi/i));
fft(b,0,n>>1,1);
fft(c,0,n>>1,1);
for(int i=0;i<n;++i)b[i]=b[i]*c[i];
fft(b,0,n>>1,-1);
for(int i=0;i<=m;++i)printf("%.0lf ",fabs(b[i].x)/n);
}
但递归版的着实很慢,这时就要用到神奇的:
二进制反转
因为开始的时候我们把所有数按奇偶性分类,所以我们不能直接枚举区间长度然后处理。但正是因为我们按奇偶分类,我们可以直接按二进制分类,然后一个数的真实位置就是该数的下标按二进制反转后的值。
求真实位置的方法其实很简单:
for(int i=0;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
一个数的二进制反转为这个数除2的反转除2(得到后(l-1)位),如果这个数第一位为1,那么把最后一位加1。
然后就可以用非递归的写法优化递归的大常数:
#include<bits/stdc++.h>
#define eps 1e-6
using namespace std;
const int N=4e6+10;
const double pi=acos(-1);
struct Complex{
double x,y;
Complex(double x=0,double y=0):x(x),y(y){}
};
Complex operator+(Complex a,Complex b){return Complex(a.x+b.x,a.y+b.y);}
Complex operator-(Complex a,Complex b){return Complex(a.x-b.x,a.y-b.y);}
Complex operator*(Complex a,Complex b){return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
int limit,l;
int r[N];
void FFT(Complex*f,int op){
for(int i=0;i<limit;++i){
if(i<r[i])swap(f[i],f[r[i]]);
}
for(int mid=2;mid<=limit;mid*=2){
Complex wn=Complex(cos(2*pi/mid),op*sin(2*pi/mid));
for(int j=0;j<limit;j+=mid){
Complex w=Complex(1,0);
for(int k=j;k<j+mid/2;++k,w=w*wn){
Complex x=f[k],y=w*f[k+mid/2];
f[k]=x+y;
f[k+mid/2]=x-y;
}
}
}
if(op==-1){
for(int i=0;i<limit;++i){
f[i].x/=limit;
}
}
}
int n,m;
Complex a[N],b[N];
int main(){
cin>>n>>m;
for(int i=0;i<=n;++i)scanf("%lf",&a[i].x);
for(int i=0;i<=m;++i)scanf("%lf",&b[i].x);
for(limit=1,l=0;limit<=n+m;limit*=2,l++);
for(int i=0;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
FFT(a,1),FFT(b,1);
for(int i=0;i<limit;++i)a[i]=a[i]*b[i];
FFT(a,-1);
for(int i=0;i<=n+m;++i){
if(fabs(a[i].x)<eps)printf("0 ");
else printf("%.0lf ",a[i].x);
}
}
例题:[ZJOI2014]力
Part2:NTT(快速数论变换)
在FFT的时候,因为会用到大量sin和cos以及double的乘法,会让精度有巨大的损失,所以我们要用一些其他的方法来让精度不损失。这时就要用到快速数论变换。
前置知识:原根
- 原根
原根:是一个数学符号。设\(m\)是正整数,\(a\)是整数,若\(a\)模\(m\)的阶等于\(\varphi(m)\),则称\(a\)为模\(m\)的一个原根。
阶:使\(a^n\equiv 1(\mod p)\)成立的最小正整数\(n\)叫做\(a\)模\(p\)的阶。这里\(\equiv\)指同余符号,代表\(a^n\)除以\(p\)的余数跟1除以\(p\)的余数相等。
一般情况下模数为998244353,而998244353的原根为3。因为有:
而且对于\(1\leq i< 998244352\),没有\(3^i\equiv 1(mod 998244353)\)
下文中我们默认\(p=998244353,G=3\)(\(p\)为模数,\(G\)为原根)
然后我们要尝试用原根相关的东西替换掉单位根:
我们需要知道\(w_n\)怎么求:
因为\(w_n^n\equiv 1(\mod p)\)
所以\(w_n^n\equiv G^{p-1}(\mod p)\),
所以有\(w_n\equiv G^{\frac{p-1}{n}}(\mod p)\)
先证明几个性质,
性质一:\(w^{k+\frac{n}{2}}=-w^k\)
性质二:\(w^{2k}_{2n}=w^k_n\)
这两个性质的证明方法和FFT一致
所以就跟FFT完全一样了
#include<bits/stdc++.h>
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
int add(int x,int y){return(x+y)%p;}
int mul(int x,int y){return 1ll*x*y%p;}
int mpow(int a,int n){
int ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int n,m;
int a[N],b[N];
int r[N],limit;
void ntt(int*f,int op){
for(int i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(int len=2;len<=limit;len*=2){
int wn=mpow(op==1?G:Gi,(p-1)/len);
for(int j=0;j<limit;j+=len){
int w=1;
for(int k=j;k<j+len/2;++k,w=mul(w,wn)){
int x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
int inv=mpow(limit,p-2);
for(int i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
int main(){
cin>>n>>m;
for(int i=0;i<=n;++i)scanf("%d",a+i);
for(int i=0;i<=m;++i)scanf("%d",b+i);
int l=0;limit=1;
while(limit<=n+m)limit*=2,l++;
for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
ntt(a,1);
ntt(b,1);
for(int i=0;i<limit;++i)a[i]=mul(a[i],b[i]);
ntt(a,-1);
for(int i=0;i<=n+m;++i)printf("%d ",a[i]);
}
Part3:多项式求逆
我们要求
现在已经知道了:
然后可以转化:
根据这个理论基础,我们可以做出多项式求逆:
#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)>=p?x+y-p:x+y;}
inline int mul(res x,res y){return 1ll*x*y-1ll*x*y/p*p;}
inline int mpow(res a,res n){
res ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int g[2][N];
void init(){
for(int i=1;i<N;i*=2){
g[0][i]=mpow(G,(p-1)/i);
g[1][i]=mpow(Gi,(p-1)/i);
}
}
int n,m;
int ls[5][N];
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,int limit,int op){
for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(res len=2;len<=limit;len*=2){
res wn=op==1?g[0][len]:g[1][len];
for(res j=0;j<limit;j+=len){
res w=1;
for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
res x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
res inv=mpow(limit,p-2);
for(res i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
void mul(int*a,int*b,int*c,int n,int m){
int limit=1;
while(limit<n+m-1)limit*=2;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
ntt(ls[0],limit,1);
ntt(ls[1],limit,1);
for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
b[0]=mpow(a[0],p-2);
for(int len=1,l=0,limit;len<2*n;len*=2){
limit=len*2,l++;
for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
for(int i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
ntt(ls[0],limit,1),ntt(ls[1],limit,1);
for(int i=0;i<limit;++i){
b[i]=add(mul(2,ls[1][i]),p-(mul(ls[0][i],mul(ls[1][i],ls[1][i]))));
}
ntt(b,limit,-1);
for(int i=len;i<limit;++i)b[i]=0;
}
}
inline int read(){
res ret=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
return ret;
}
int main(){
init();
cin>>n;
for(res i=0;i<n;++i)a[i]=read();
inv(a,c,n);
for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}
Part 4:多项式ln
我们要求
可推导:
#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)>=p?x+y-p:x+y;}
inline int mul(res x,res y){return 1ll*x*y-1ll*x*y/p*p;}
inline int mpow(res a,res n){
res ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int g[2][N];
void init(){
for(int i=1;i<N;i*=2){
g[0][i]=mpow(G,(p-1)/i);
g[1][i]=mpow(Gi,(p-1)/i);
}
}
int n,m;
int ls[5][N],used;
//0,1 mul
//2,3,4 ln
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,int limit,int op){
for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(res len=2;len<=limit;len*=2){
res wn=op==1?g[0][len]:g[1][len];
for(res j=0;j<limit;j+=len){
res w=1;
for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
res x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
res inv=mpow(limit,p-2);
for(res i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
void mul(int*a,int*b,int*c,int n,int m){
int limit=1;
while(limit<n+m-1)limit*=2;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
ntt(ls[0],limit,1);
ntt(ls[1],limit,1);
for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
b[0]=mpow(a[0],p-2);
for(int len=1,l=0,limit;len<2*n;len*=2){
limit=len*2,l++;
for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
for(int i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
ntt(ls[0],limit,1),ntt(ls[1],limit,1);
for(int i=0;i<limit;++i){
b[i]=add(mul(2,ls[1][i]),p-(mul(ls[0][i],mul(ls[1][i],ls[1][i]))));
}
ntt(b,limit,-1);
for(int i=len;i<limit;++i)b[i]=0;
}
}
void direv(int*a,int*b,int n){
for(int i=1;i<n;++i){
b[i-1]=mul(a[i],i);
}
}
void inter(int*a,int*b,int n){
b[0]=0;
for(int i=1;i<n;++i){
b[i]=mul(a[i-1],mpow(i,p-2));
}
}
void ln(int*a,int*b,int n){
direv(a,ls[2],n);
inv(a,ls[3],n);
mul(ls[2],ls[3],ls[4],n,n);
inter(ls[4],b,2*n);
for(int i=n;i<2*n;++i)b[i]=0;
}
void sqrt(int*a,int*b,int n){
b[0]=1;
}
inline int read(){
res ret=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
return ret;
}
int main(){
init();
cin>>n;
for(res i=0;i<n;++i)a[i]=read();
ln(a,c,n);
for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}
Part ex:黑科技
突然学会了一个黑科技:
牛顿迭代,泰勒展开
设\(F,F_0\)是多项式,\(G\)是某个多项式函数。
现在要求
现在已经知道了
我们对\(G(F)\)在\(F_0\)处泰勒展开
因为\(F\)和\(F_0\)的前\(\frac n2\)相同,所以\((F-F_0)\)的前\(\frac n2\)为0,所以对于\(n>1\)的情况\((F-F_0)^n\)的前n为必定为0,对答案无意义,可舍去。
所以有
因为\(G(F)\equiv 0(\mod x^n)\),所以有
这里要注意当求\(G'(F)\)时,我们要把\(F\)当成一个未知数,这样\(G'(F)=G'F\)
Part 4:多项式exp
用黑科技可求解。
给出多项式\(A\)
#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)>=p?x+y-p:x+y;}
inline int mul(res x,res y){return 1ll*x*y-1ll*x*y/p*p;}
inline int mpow(res a,res n){
res ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int g[2][N];
void init(){
for(int i=1;i<N;i*=2){
g[0][i]=mpow(G,(p-1)/i);
g[1][i]=mpow(Gi,(p-1)/i);
}
}
int n,m;
int ls[10][N],used;
//0,1 mul
//2,3,4 ln
//5,6 exp
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,int limit,int op){
for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(res len=2;len<=limit;len*=2){
res wn=op==1?g[0][len]:g[1][len];
for(res j=0;j<limit;j+=len){
res w=1;
for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
res x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
res inv=mpow(limit,p-2);
for(res i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
void mul(int*a,int*b,int*c,int n,int m){
int limit=1;
while(limit<n+m-1)limit*=2;
for(int i=0;i<limit;++i)ls[0][i]=ls[1][i]=0;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
ntt(ls[0],limit,1);
ntt(ls[1],limit,1);
for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
b[0]=mpow(a[0],p-2);
for(int len=1,limit;len<2*n;len*=2){
limit=len*2;
for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
for(int i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
ntt(ls[0],limit,1),ntt(ls[1],limit,1);
for(int i=0;i<limit;++i){
b[i]=add(mul(2,ls[1][i]),p-(mul(ls[0][i],mul(ls[1][i],ls[1][i]))));
}
ntt(b,limit,-1);
for(int i=len;i<limit;++i)b[i]=0;
}
}
void direv(int*a,int*b,int n){
for(int i=1;i<n;++i){
b[i-1]=mul(a[i],i);
}
}
void inter(int*a,int*b,int n){
b[0]=0;
for(int i=1;i<n;++i){
b[i]=mul(a[i-1],mpow(i,p-2));
}
}
void ln(int*a,int*b,int n){
direv(a,ls[2],n);
inv(a,ls[3],n);
mul(ls[2],ls[3],ls[4],n,n);
inter(ls[4],b,2*n);
for(int i=n;i<2*n;++i)b[i]=0;
}
void exp(int*a,int*b,int n){
b[0]=1;
for(int len=1;len<2*n;len*=2){
int limit=len*2;
ln(b,ls[5],len);
for(int i=0;i<len;++i){
ls[5][i]=add(p-ls[5][i],a[i]);
}
ls[5][0]=add(ls[5][0],1);
for(int i=0;i<len;++i)ls[6][i]=b[i];
mul(ls[5],ls[6],b,len,len);
for(int i=len;i<limit;++i)b[i]=0;
}
}
inline int read(){
res ret=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
return ret;
}
int main(){
init();
cin>>n;
for(res i=0;i<n;++i)a[i]=read();
exp(a,c,n);
for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}
Part 5:多项式开根
也使用上文黑科技:
给出多项式\(A\)
#include<bits/stdc++.h>
#define res register int
#define ll long long
ll js;
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)%p;}
inline int mul(res x,res y){return 1ll*x*y%p;}
inline int mpow(res a,res n){
res ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int g[2][N];
int inv2;
void init(){
inv2=mpow(2,p-2);
for(int i=1;i<N;i*=2){
g[0][i]=mpow(G,(p-1)/i);
g[1][i]=mpow(Gi,(p-1)/i);
}
}
int n,m;
int ls[7][N],used;
//0,1 mul inv
//2,3,4 ln sqrt
//5,6 exp
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,res limit,res op){
for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(res len=2;len<=limit;len*=2){
res wn=op==1?g[0][len]:g[1][len];
for(res j=0;j<limit;j+=len){
res w=1;
for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
res x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
res inv=mpow(limit,p-2);
for(res i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
void mul(int*a,int*b,int*c,int n,int m){
int limit=1;
while(limit<n+m-1)limit*=2;
for(res i=0;i<limit;++i)ls[0][i]=ls[1][i]=0;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(res i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
ntt(ls[0],limit,1);
ntt(ls[1],limit,1);
for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
b[0]=mpow(a[0],p-2);
for(res len=2,limit;len<2*n;len*=2){
limit=len*2;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(res i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
for(res i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
ntt(ls[0],limit,1),ntt(ls[1],limit,1);
for(res i=0;i<limit;++i){
b[i]=mul(add(2,p-(mul(ls[0][i],ls[1][i]))),ls[1][i]);
}
ntt(b,limit,-1);
for(res i=len;i<limit;++i)b[i]=0;
}
}
inline void direv(int*a,int*b,int n){
for(res i=1;i<n;++i){
b[i-1]=mul(a[i],i);
}
}
inline void inter(int*a,int*b,int n){
b[0]=0;
for(res i=1;i<n;++i){
b[i]=mul(a[i-1],mpow(i,p-2));
}
}
void ln(int*a,int*b,int n){
direv(a,ls[2],n);
inv(a,ls[3],n);
mul(ls[2],ls[3],ls[4],n,n);
inter(ls[4],b,2*n);
for(res i=n;i<2*n;++i)b[i]=0;
}
void exp(int*a,int*b,int n){
b[0]=1;
for(res len=2;len<2*n;len*=2){
res limit=len*2;
ln(b,ls[5],len);
for(res i=0;i<len;++i){
ls[5][i]=add(p-ls[5][i],a[i]);
}
ls[5][0]=add(ls[5][0],1);
for(res i=0;i<len;++i)ls[6][i]=b[i];
mul(ls[5],ls[6],b,len,len);
for(res i=len;i<limit;++i)b[i]=0;
}
}
void sqrt(int*a,int*b,int n){
b[0]=1;
for(res len=2;len<2*n;len*=2){
res limit=len*2;
inv(b,ls[2],len);
for(res i=0;i<len;++i)ls[3][i]=a[i];
mul(ls[2],ls[3],ls[4],len,len);
for(res i=0;i<len;++i)b[i]=mul(add(b[i],ls[4][i]),inv2);
for(res i=len;i<limit;++i)b[i]=0;
}
}
inline int read(){
res ret=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
return ret;
}
int main(){
init();
cin>>n;
for(res i=0;i<n;++i)a[i]=read();
sqrt(a,c,n);
for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}
Part 6:多项式快速幂
直接换底公式即可
#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)%p;}
inline int mul(res x,res y){return 1ll*x*y%p;}
inline int mpow(res a,res n){
res ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int g[2][N];
int inv2;
void init(){
inv2=mpow(2,p-2);
for(int i=1;i<N;i*=2){
g[0][i]=mpow(G,(p-1)/i);
g[1][i]=mpow(Gi,(p-1)/i);
}
}
int n,m;
int ls[7][N],used;
//0,1 mul inv
//2,3,4 ln sqrt
//5,6 exp
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,res limit,res op){
for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(res len=1;len<=limit;len*=2){
res wn=op==1?g[0][len]:g[1][len];
for(res j=0;j<limit;j+=len){
res w=1;
for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
res x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
res inv=mpow(limit,p-2);
for(res i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
void mul(int*a,int*b,int*c,int n,int m){
int limit=1;
while(limit<n+m-1)limit*=2;
for(res i=0;i<limit;++i)ls[0][i]=ls[1][i]=0;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(res i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
ntt(ls[0],limit,1);
ntt(ls[1],limit,1);
for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
b[0]=mpow(a[0],p-2);
res limit;
for(res len=1;len<2*n;len*=2){
limit=len*2;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(res i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
for(res i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
ntt(ls[0],limit,1),ntt(ls[1],limit,1);
for(res i=0;i<limit;++i){
b[i]=mul(add(2,p-(mul(ls[0][i],ls[1][i]))),ls[1][i]);
}
ntt(b,limit,-1);
for(res i=len;i<limit;++i)b[i]=0;
}
for(int i=n;i<limit;++i)b[i]=0;
}
inline void direv(int*a,int*b,int n){
for(res i=1;i<n;++i){
b[i-1]=mul(a[i],i);
}
b[n-1]=0;
}
inline void inter(int*a,int*b,int n){
b[0]=0;
for(res i=1;i<n;++i){
b[i]=mul(a[i-1],mpow(i,p-2));
}
}
void ln(int*a,int*b,int n){
direv(a,ls[2],n);
inv(a,ls[3],n);
mul(ls[2],ls[3],ls[4],n,n);
inter(ls[4],b,2*n);
for(res i=n;i<2*n;++i)b[i]=0;
for(res i=0;i<n;++i)ls[2][i]=ls[3][i]=ls[4][i]=0;
}
void exp(int*a,int*b,int n){
b[0]=1;
for(res len=1;len<2*n;len*=2){
res limit=len*2;
ln(b,ls[5],len);
for(res i=0;i<len;++i){
ls[5][i]=add(p-ls[5][i],a[i]);
}
ls[5][0]=add(ls[5][0],1);
for(res i=0;i<len;++i)ls[6][i]=b[i];
mul(ls[5],ls[6],b,len,len);
for(res i=len;i<limit;++i)b[i]=0;
}
}
void sqrt(int*a,int*b,int n){
b[0]=1;
for(res len=2;len<2*n;len*=2){
res limit=len*2;
inv(b,ls[2],len);
for(res i=0;i<len;++i)ls[3][i]=a[i];
mul(ls[2],ls[3],ls[4],len,len);
for(res i=0;i<len;++i)b[i]=mul(add(b[i],ls[4][i]),inv2);
for(res i=len;i<limit;++i)b[i]=0;
}
}
inline int read(){
res ret=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);ret=add(mul(ret,10),c-'0'),c=getchar());
return ret;
}
int main(){
init();
n=read(),m=read();
for(res i=0;i<n;++i)a[i]=read();
ln(a,b,n);
for(int i=0;i<n;++i)b[i]=mul(b[i],m);
exp(b,c,n);
for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}