多项式总结

Part1:FFT(fast fast tle)

LuoguP3803 【模板】多项式乘法(FFT)

前置知识:复数,单位根,多项式的系数表达法,多项式的点值表达法

  • 复数:

可以表示为\((a+bi)\),可以看做原点到\((a,b)\)一个向量,其中\(i=\sqrt{-1}\)
复数可以进行加,减,乘(向量的除法有点问题),其中

\[(a+bi)+(c+di)=((a+c)+(b+d)i) \]

\[(a+bi)-(c+di)=((a-c)+(b-d)i) \]

\[(a+bi)(c+di)=a(c+di)+bi(c+di)=ac+adi+bci+bdi^2=((ac-bd)+(ad+bc)i) \]

即:

\[(a+bi)(c+di)=((ac-bd)+(ad+bc)i) \]

同时复数的乘还有和向量一样的几何意义:模长相乘,幅角相加

  • 单位根:

在OI中,经常用到2的正整数次幂相关的数,因为这样方便处理,为方便,我们规定下文的\(n\)为2的正整数次幂。
定义:如果\(w_n^n=1\)那么\(w_n\)\(n\)次单位根
因为\(w_n^n=1\),根据复数乘的几何意义,可知模长为1,幅角为\(\frac{2\pi}{n}\),易得单位根

\[w_n=(\cos(\frac{2\pi}{n})+\sin(\frac{2\pi}{n})i) \]

然后\(w^k_n\)的幅角为\(\frac{2k\pi}{n}\),所以

\[w^k_n=(\cos(\frac{2k\pi}{n})+\sin(\frac{2k\pi}{n})i) \]

单位根这里还需要两个性质:
性质一:

\[w^{2k}_{2n}=(\cos(\frac{4k\pi}{2n})+\sin(\frac{4k\pi}{2n})i)=(\cos(\frac{2k\pi}{n})+\sin(\frac{2k\pi}{n})i)=w^k_n \]

\[w^{2k}_{2n}=w^k_n \]

性质二:

\[w^{k+\frac{n}{2}}_n=w^kw^{\frac{n}{2}}=w^k(\cos(\frac{n\pi}{n})+\sin(\frac{n\pi}{n})i)=-w^k \]

\[w^{k+\frac{n}{2}}=-w^k \]

  • 多项式的系数表达法:

就是平时的表达方法,用\(n+1\)个系数表示一个\(n\)次多项式,比如:

\[f(x)=a_0+a_1x+a_2x^2+a_3x^3+...+a_nx^n \]

该方法易读,也易求值,但很难快速求卷积。

  • 多项式的点值表达法:

就是用\(n+1\)个点来表示一个\(n\)次多项式,比如:

\[f(x)=\{(x_0,y_0),(x_1,y_1),(x_2,y_2),...,(x_n,y_n)\} \]

该方法不易理解,但很容易求卷积
如果两个多项式\(f,g\)满足\(fx_0==gx_0,fx_1==gx_1,...,fx_n==gx_n\),则新多项式

\[h(x)=\{(x_0,fx_0 gx_0),...,(x_n,fx_n gx_n)\} \]

正题:快速傅里叶变换(FFT)以及快速傅里叶逆变换(IFFT)

从上面两种多项式的表达方式中,我们可以发现如果能快速的把多项式在系数与点值中转换,就可以快速的获取两个多项式的卷积。

  • 1、系数多项式转点值多项式(快速傅里叶变换)

给出多项式

\[f(x)=a_0+a_1x+a_2x^2+...+a_nx^n \]

我们需要快速求出\(f(1),f(w_n),...,f(w^{n-1}_{n-1})\)
先将\(f\)按奇偶分类分为

\[f(x)=(a_0+a_2x^2+...+a_nx^n)+(a_1x+a_3x^3+...+a_{n-1}x^{n-1}) \]

我们设

\[f1(x)=a_0+a_2x+...+a_nx^{\frac{n}{2}} \]

\[f2(x)=a_1+a_3x+...+a_{n-1}x^{\frac{n-2}{2}} \]

那么有

\[f(x)=f1(x^2)+xf2(x^2) \]

带入\(x=w^k_n\),

\[f(w^k_n)=f1(w^{2k}_n)+w^k_nf2(w^{2k}_n) \]

带入\(x=w^{k+\frac{n}{2}}_n\)

\[f(w^{k+\frac{n}{2}}_n)=f1(w^{2k+n}_n)+w^{k+\frac{n}{2}}_nf2(w^{2k+n}_n) \]

根据性质二,有

\[f(w^{k+\frac{n}{2}}_n)=f1(w^{2k}_n)-w^k_nf2(w^{2k}_n) \]

可以发现\(f(w^k_n)\)\(f(w^{k+\frac{n}{2}}_n)\)两项只差第二项的系数,所以我们可以只用一半的时间就处理整个多项式。
可以发现,如果我们递归的处理,就可以用\(O(n\log n)\)的复杂度实现FFT。

  • 2、点值多项式转系数多项式(快速傅里叶逆变换)

我们可以把式子列成一个矩阵(没学过矩阵可以先学再看或跳过推导):
其中\((a_0,a_1,a_2,...,a_{n-1})\)为系数表达法,\((y_0,y_1,y_2,...,y_{n-1})\)为省略\(x\)的点值表达法。

\[\left|\begin{matrix}1&1&1&\cdots&1\\1&w_n&w_n^2&\cdots&w_n^{n-1}\\1&w_n^2&w_n^4&\cdots&w_n^{2(n-1)}\\\vdots&\vdots&\vdots&\ddots&\vdots\\1&w_n^{n-1}&w_n^{2(n-1)}&\cdots&w_n^{(n-1)(n-1)}\end{matrix} \right|\left|\begin{matrix}a_0\\a_1\\a_2\\\vdots\\\ a_{n-1}\end{matrix} \right|=\left|\begin{matrix}y_0\\y_1\\y_2\\\vdots\\\ y_{n-1}\end{matrix} \right|\]

如果我们能快速求出左边这个矩阵的逆矩阵,我们就能快速转换。
考虑矩阵求逆(\(O(n^3)\)完全负优化)
但我们可以发现原矩阵中所有数之间是有关联的,我们可以考虑转换。

\(V\)为原矩阵,\(G\)为逆矩阵,考虑最终矩阵\(E\)\((i,j)\)上的值:

\[E(i,j)=\sum\limits_{k=0}^{n-1}G(j,k) V(k,i)\\ =\sum\limits_{k=0}^{n-1}G(j,k)w_n^{ki} \]

因为\(V\)\(G\)互逆,所以\(E\)是单位矩阵,只有当\(i=j\)时才会有值1,否则为0。

\[\sum\limits_{k=0}^{n-1}G(j,k)w_n^{ki}=[i==j] \]

我们先证明一个引理:当\(k\)不是\(n\)的倍数时

\[\sum\limits_{j=0}^{n-1}w_n^{kj}=0 \]

由等比数列求和得

\[\sum\limits_{j=0}^{n-1}w_n^{kj}=\frac{w_n^{kn}-1}{1-w_n^k}=\frac{1-1}{w_n^k-1} \]

因为\(k\)不是\(n\)的倍数,所以\(w_n^k\not=1\),即分母不为0,所以该引理成立。

根据这个引理,可以发现矩阵\(G\)有一个比较简单的构造方式,即\(G(i,j)=w_n^{-ij}\)
这时

\[E(i,j)=\sum\limits_{k=0}w_n^{-jk}w_n^{ki} \]

\[=\sum\limits_{k=0}w_n^{i-j} \]

\(i-j\)不为\(n\)的倍数(不为0时),\(E(i,j)=1\),但当\(i=j\)时,已知\(E(i,j)=n\),跟单位矩阵有点偏差,我们在前面加一个\(\frac{1}{n}\)

好吧,这个推导其实有些牵强,只用把他当做结论记就可以了。

这样,我们就有:

\[\left|\begin{matrix}1&1&1&\cdots&1\\1&w_n^{-1}&w_n^{-2}&\cdots&w_n^{1-n}\\1&w_n^{-2}&w_n^{-4}&\cdots&w_n^{2(1-n)}\\\vdots&\vdots&\vdots&\ddots&\vdots\\1&w_n^{1-n}&w_n^{2(1-n)}&\cdots&w_n^{(1-n)(1-n)}\end{matrix} \right|\left|\begin{matrix}y_0\\y_1\\y_2\\\vdots\\\ y_{n-1}\end{matrix} \right|=\left|\begin{matrix}na_0\\na_1\\na_2\\\vdots\\\ na_{n-1}\end{matrix} \right|\]

这样我们就可以用类似系数转点值的方法转换了,只是这边的单位根要取反,其实在使用起来时就是

\[(a+bi)^{-1}=(a-bi) \]

非常简单,只用在FFT的基础上略作修改即可。

最后因为直接乘出来的答案是真实值的\(n\)倍,所以要除以\(n\)

然后我们就可以写出最基本的递归版FFT(虽然C++有自带complex类型,但用起来会比较慢,建议手写一个):

#include<bits/stdc++.h>
using namespace std;
const int N=1000010;
const double pi=acos(-1);
int n,m,lg[N<<1];
struct Complex{
    double x,y;
    Complex(double x=0,double y=0):x(x),y(y){}
};
Complex operator+(Complex a,Complex b){return Complex(a.x+b.x,a.y+b.y);}
Complex operator-(Complex a,Complex b){return Complex(a.x-b.x,a.y-b.y);}
Complex operator*(Complex a,Complex b){return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
Complex c[N<<2],b[N<<2],t[N<<2],s[50];
void fft(Complex *f,int l,int len,int op){
    if(!len)return;
    for(int i=l;i<l+(len<<1);++i)t[i]=f[i];
    for(int i=l;i<l+len;++i)f[i]=t[l+((i-l)<<1)],f[i+len]=t[l+((i-l)<<1|1)];
    fft(f,l,len>>1,op);
    fft(f,l+len,len>>1,op);
    Complex tmp=s[lg[len]],buf=Complex(1,0),d;
    tmp.y*=op;
    for(int i=l;i<l+len;++i){
        d=buf*f[i+len];
        t[i]=f[i]+d;
        t[i+len]=f[i]-d;
        buf=buf*tmp;
    }
    for(int i=l;i<l+(len<<1);++i)f[i]=t[i];
}
int main(){
    int n,m;
    cin>>n>>m;
    for(int i=2;i<=n+m;++i)lg[i]=lg[i>>1]+1;
    for(int i=0;i<=n;++i)scanf("%lf",&b[i].x);
    for(int i=0;i<=m;++i)scanf("%lf",&c[i].x);
    for(m+=n,n=1;n<=m;n<<=1);
    for(int i=1,j=0;i<=n;i<<=1,++j)s[j]=Complex(cos(pi/i),sin(pi/i));
    fft(b,0,n>>1,1);
    fft(c,0,n>>1,1);
    for(int i=0;i<n;++i)b[i]=b[i]*c[i];
    fft(b,0,n>>1,-1);
    for(int i=0;i<=m;++i)printf("%.0lf ",fabs(b[i].x)/n);
}

但递归版的着实很慢,这时就要用到神奇的:

二进制反转

因为开始的时候我们把所有数按奇偶性分类,所以我们不能直接枚举区间长度然后处理。但正是因为我们按奇偶分类,我们可以直接按二进制分类,然后一个数的真实位置就是该数的下标按二进制反转后的值。

求真实位置的方法其实很简单:

for(int i=0;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));

一个数的二进制反转为这个数除2的反转除2(得到后(l-1)位),如果这个数第一位为1,那么把最后一位加1。
然后就可以用非递归的写法优化递归的大常数:

#include<bits/stdc++.h>
#define eps 1e-6
using namespace std;
const int N=4e6+10;
const double pi=acos(-1);
struct Complex{
    double x,y;
    Complex(double x=0,double y=0):x(x),y(y){}
};
Complex operator+(Complex a,Complex b){return Complex(a.x+b.x,a.y+b.y);}
Complex operator-(Complex a,Complex b){return Complex(a.x-b.x,a.y-b.y);}
Complex operator*(Complex a,Complex b){return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
int limit,l;
int r[N];
void FFT(Complex*f,int op){
    for(int i=0;i<limit;++i){
        if(i<r[i])swap(f[i],f[r[i]]); 
    }
    for(int mid=2;mid<=limit;mid*=2){
        Complex wn=Complex(cos(2*pi/mid),op*sin(2*pi/mid));
        for(int j=0;j<limit;j+=mid){
            Complex w=Complex(1,0);
            for(int k=j;k<j+mid/2;++k,w=w*wn){
                Complex x=f[k],y=w*f[k+mid/2];
                f[k]=x+y;
                f[k+mid/2]=x-y;
            }
        }
    }
    if(op==-1){
        for(int i=0;i<limit;++i){
            f[i].x/=limit;
        }
    }
}
int n,m;
Complex a[N],b[N]; 
int main(){
    cin>>n>>m;
    for(int i=0;i<=n;++i)scanf("%lf",&a[i].x);
    for(int i=0;i<=m;++i)scanf("%lf",&b[i].x);
    for(limit=1,l=0;limit<=n+m;limit*=2,l++);
    for(int i=0;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    FFT(a,1),FFT(b,1);
    for(int i=0;i<limit;++i)a[i]=a[i]*b[i];
    FFT(a,-1);
    for(int i=0;i<=n+m;++i){
    	if(fabs(a[i].x)<eps)printf("0 ");
        else printf("%.0lf ",a[i].x);
    }
}

例题:[ZJOI2014]力

Part2:NTT(快速数论变换)

在FFT的时候,因为会用到大量sin和cos以及double的乘法,会让精度有巨大的损失,所以我们要用一些其他的方法来让精度不损失。这时就要用到快速数论变换。

前置知识:原根

  • 原根

原根:是一个数学符号。设\(m\)是正整数,\(a\)是整数,若\(a\)\(m\)的阶等于\(\varphi(m)\),则称\(a\)为模\(m\)的一个原根。

阶:使\(a^n\equiv 1(\mod p)\)成立的最小正整数\(n\)叫做\(a\)\(p\)的阶。这里\(\equiv\)指同余符号,代表\(a^n\)除以\(p\)的余数跟1除以\(p\)的余数相等。

一般情况下模数为998244353,而998244353的原根为3。因为有:

\[3^{\varphi(998233453)}\equiv 3^{998244352}\equiv 1(mod 998244353) \]

而且对于\(1\leq i< 998244352\),没有\(3^i\equiv 1(mod 998244353)\)
下文中我们默认\(p=998244353,G=3\)\(p\)为模数,\(G\)为原根)

然后我们要尝试用原根相关的东西替换掉单位根:

我们需要知道\(w_n\)怎么求:
因为\(w_n^n\equiv 1(\mod p)\)
所以\(w_n^n\equiv G^{p-1}(\mod p)\)
所以有\(w_n\equiv G^{\frac{p-1}{n}}(\mod p)\)

先证明几个性质,

性质一:\(w^{k+\frac{n}{2}}=-w^k\)
性质二:\(w^{2k}_{2n}=w^k_n\)
这两个性质的证明方法和FFT一致

所以就跟FFT完全一样了

#include<bits/stdc++.h>
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
int add(int x,int y){return(x+y)%p;}
int mul(int x,int y){return 1ll*x*y%p;}
int mpow(int a,int n){
    int ret=1;
    while(n){
        if(n&1)ret=mul(ret,a);
        a=mul(a,a);
        n/=2;
	}
    return ret;
}
int n,m;
int a[N],b[N];
int r[N],limit;
void ntt(int*f,int op){
    for(int i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
    for(int len=2;len<=limit;len*=2){
        int wn=mpow(op==1?G:Gi,(p-1)/len);
        for(int j=0;j<limit;j+=len){
            int w=1;
            for(int k=j;k<j+len/2;++k,w=mul(w,wn)){
                int x=f[k],y=mul(w,f[k+len/2]);
                f[k]=add(x,y);
                f[k+len/2]=add(x,p-y);
			}
		}
	}
	if(op==-1){
	    int inv=mpow(limit,p-2);
	    for(int i=0;i<limit;++i){
	        f[i]=mul(f[i],inv);
		}
	}
}
int main(){
	cin>>n>>m;
	for(int i=0;i<=n;++i)scanf("%d",a+i);
	for(int i=0;i<=m;++i)scanf("%d",b+i);
	int l=0;limit=1;
	while(limit<=n+m)limit*=2,l++;
	for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	ntt(a,1);
	ntt(b,1);
	for(int i=0;i<limit;++i)a[i]=mul(a[i],b[i]);
	ntt(a,-1);
	for(int i=0;i<=n+m;++i)printf("%d ",a[i]);
}

例题:[AH2017/HNOI2017]礼物

Part3:多项式求逆

我们要求

\[A\times B\equiv 1(\mod x^n) \]

现在已经知道了:

\[A\times B'\equiv 1(\mod x^{\frac{n}{2}}) \]

然后可以转化:

\[A\times (B-B')\equiv0(\mod x^{\frac{n}{2}}) \]

\[B-B'\equiv 0(\mod x^{\frac{n}{2}}) \]

\[(B-B')^2\equiv0(\mod x^n) \]

\[B^2+B'^2-2BB'\equiv0(\mod x^n) \]

\[AB^2+AB'^2-2ABB'\equiv0(\mod x^n) \]

\[B+AB'^2-2B'\equiv0(\mod x^n) \]

\[B\equiv2B'-AB'^2(\mod x^n) \]

根据这个理论基础,我们可以做出多项式求逆:

#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)>=p?x+y-p:x+y;}
inline int mul(res x,res y){return 1ll*x*y-1ll*x*y/p*p;}
inline int mpow(res a,res n){
    res ret=1;
    while(n){
        if(n&1)ret=mul(ret,a);
        a=mul(a,a);
        n/=2;
    }
    return ret;
}    
int g[2][N];
void init(){
    for(int i=1;i<N;i*=2){
        g[0][i]=mpow(G,(p-1)/i);
        g[1][i]=mpow(Gi,(p-1)/i);
    }
}
int n,m;
int ls[5][N];
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,int limit,int op){
    for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
    for(res len=2;len<=limit;len*=2){
        res wn=op==1?g[0][len]:g[1][len];
        for(res j=0;j<limit;j+=len){
            res w=1;
            for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
                res x=f[k],y=mul(w,f[k+len/2]);
                f[k]=add(x,y);
                f[k+len/2]=add(x,p-y);
            }
        }
    }
    if(op==-1){
        res inv=mpow(limit,p-2);
        for(res i=0;i<limit;++i){
            f[i]=mul(f[i],inv);
        }
    }
}
void mul(int*a,int*b,int*c,int n,int m){
    int limit=1;
    while(limit<n+m-1)limit*=2;
    for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
    for(int i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
    ntt(ls[0],limit,1);
    ntt(ls[1],limit,1);
    for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
    ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
	b[0]=mpow(a[0],p-2);
    for(int len=1,l=0,limit;len<2*n;len*=2){
        limit=len*2,l++;
        for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
        for(int i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
        for(int i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
        ntt(ls[0],limit,1),ntt(ls[1],limit,1);
        for(int i=0;i<limit;++i){
            b[i]=add(mul(2,ls[1][i]),p-(mul(ls[0][i],mul(ls[1][i],ls[1][i]))));
		}
		ntt(b,limit,-1);
		for(int i=len;i<limit;++i)b[i]=0;
	}
}
inline int read(){
    res ret=0;char c;
    for(c=getchar();!isdigit(c);c=getchar());
    for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
    return ret;
}
int main(){
	init();
    cin>>n;
    for(res i=0;i<n;++i)a[i]=read();
    inv(a,c,n);
    for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}

Part 4:多项式ln

我们要求

\[A=ln(B) \]

可推导:

\[A'=(ln(B))' \]

\[A'=\frac{1}{B}B' \]

\[A=\int\frac{1}{B}B'dx \]

#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)>=p?x+y-p:x+y;}
inline int mul(res x,res y){return 1ll*x*y-1ll*x*y/p*p;}
inline int mpow(res a,res n){
    res ret=1;
    while(n){
        if(n&1)ret=mul(ret,a);
        a=mul(a,a);
        n/=2;
    }
    return ret;
}    
int g[2][N];
    void init(){
        for(int i=1;i<N;i*=2){
            g[0][i]=mpow(G,(p-1)/i);
            g[1][i]=mpow(Gi,(p-1)/i);
        }
    }
int n,m;
int ls[5][N],used;
//0,1 mul
//2,3,4 ln
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,int limit,int op){
    for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
    for(res len=2;len<=limit;len*=2){
        res wn=op==1?g[0][len]:g[1][len];
        for(res j=0;j<limit;j+=len){
            res w=1;
            for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
                res x=f[k],y=mul(w,f[k+len/2]);
                f[k]=add(x,y);
                f[k+len/2]=add(x,p-y);
            }
        }
    }
    if(op==-1){
        res inv=mpow(limit,p-2);
        for(res i=0;i<limit;++i){
            f[i]=mul(f[i],inv);
        }
    }
}
void mul(int*a,int*b,int*c,int n,int m){
    int limit=1;
    while(limit<n+m-1)limit*=2;
    for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
    for(int i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
    ntt(ls[0],limit,1);
    ntt(ls[1],limit,1);
    for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
    ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
	b[0]=mpow(a[0],p-2);
    for(int len=1,l=0,limit;len<2*n;len*=2){
        limit=len*2,l++;
        for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
        for(int i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
        for(int i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
        ntt(ls[0],limit,1),ntt(ls[1],limit,1);
        for(int i=0;i<limit;++i){
            b[i]=add(mul(2,ls[1][i]),p-(mul(ls[0][i],mul(ls[1][i],ls[1][i]))));
		}
		ntt(b,limit,-1);
		for(int i=len;i<limit;++i)b[i]=0;
	}
}
void direv(int*a,int*b,int n){
    for(int i=1;i<n;++i){
        b[i-1]=mul(a[i],i);
	}
}
void inter(int*a,int*b,int n){
	b[0]=0;
    for(int i=1;i<n;++i){
        b[i]=mul(a[i-1],mpow(i,p-2));
	}
}
void ln(int*a,int*b,int n){
	direv(a,ls[2],n);
	inv(a,ls[3],n);
	mul(ls[2],ls[3],ls[4],n,n);
	inter(ls[4],b,2*n);
	for(int i=n;i<2*n;++i)b[i]=0;
}
void sqrt(int*a,int*b,int n){
    b[0]=1;
}
inline int read(){
    res ret=0;char c;
    for(c=getchar();!isdigit(c);c=getchar());
    for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
    return ret;
}
int main(){
	init();
    cin>>n;
    for(res i=0;i<n;++i)a[i]=read();
    ln(a,c,n);
    for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}

Part ex:黑科技

突然学会了一个黑科技:

牛顿迭代,泰勒展开

\(F,F_0\)是多项式,\(G\)是某个多项式函数。

现在要求

\[G(F)\equiv 0(\mod x^n) \]

现在已经知道了

\[G(F_0)\equiv0(\mod x^\frac n2) \]

我们对\(G(F)\)\(F_0\)处泰勒展开

\[G(F)\equiv G(F_0)+\frac{G'(F_0)}{1!}(F-F_0)+\frac{G''(F_0)}{2!}(F-F_0)+\dots(\mod x^n) \]

因为\(F\)\(F_0\)的前\(\frac n2\)相同,所以\((F-F_0)\)的前\(\frac n2\)为0,所以对于\(n>1\)的情况\((F-F_0)^n\)的前n为必定为0,对答案无意义,可舍去。

所以有

\[G(F)\equiv G(F_0)+G'(F_0)(F-F_0)(\mod x^n) \]

因为\(G(F)\equiv 0(\mod x^n)\),所以有

\[0\equiv G(F_0)+G'(F_0)(F-F_0)(\mod x^n) \]

\[F=\frac{-G(F_0)+G'(F_0)F_0}{G'(F_0)}(\mod x^n) \]

\[F=F_0-\frac{G(F_0)}{G'(F_0)}(\mod x^n) \]

这里要注意当求\(G'(F)\)时,我们要把\(F\)当成一个未知数,这样\(G'(F)=G'F\)

Part 4:多项式exp

用黑科技可求解。

给出多项式\(A\)

\[G(F)=\ln(F)-A \]

\[G'(F)=\frac1F \]

\[F=F_0-\frac{\ln(F_0)-A}{\frac 1{F_0}} \]

\[F=F_0(1-\ln(F_0)+A) \]

#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)>=p?x+y-p:x+y;}
inline int mul(res x,res y){return 1ll*x*y-1ll*x*y/p*p;}
inline int mpow(res a,res n){
    res ret=1;
    while(n){
        if(n&1)ret=mul(ret,a);
        a=mul(a,a);
        n/=2;
    }
    return ret;
}    
int g[2][N];
void init(){
    for(int i=1;i<N;i*=2){
        g[0][i]=mpow(G,(p-1)/i);
        g[1][i]=mpow(Gi,(p-1)/i);
    }
}
int n,m;
int ls[10][N],used;
//0,1 mul
//2,3,4 ln
//5,6 exp
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,int limit,int op){
    for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
    for(res len=2;len<=limit;len*=2){
        res wn=op==1?g[0][len]:g[1][len];
        for(res j=0;j<limit;j+=len){
            res w=1;
            for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
                res x=f[k],y=mul(w,f[k+len/2]);
                f[k]=add(x,y);
                f[k+len/2]=add(x,p-y);
            }
        }
    }
    if(op==-1){
        res inv=mpow(limit,p-2);
        for(res i=0;i<limit;++i){
            f[i]=mul(f[i],inv);
        }
    }
}
void mul(int*a,int*b,int*c,int n,int m){
    int limit=1;
    while(limit<n+m-1)limit*=2;
    for(int i=0;i<limit;++i)ls[0][i]=ls[1][i]=0;
    for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
    for(int i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
    ntt(ls[0],limit,1);
    ntt(ls[1],limit,1);
    for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
    ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
	b[0]=mpow(a[0],p-2);
    for(int len=1,limit;len<2*n;len*=2){
        limit=len*2;
        for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
        for(int i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
        for(int i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
        ntt(ls[0],limit,1),ntt(ls[1],limit,1);
        for(int i=0;i<limit;++i){
            b[i]=add(mul(2,ls[1][i]),p-(mul(ls[0][i],mul(ls[1][i],ls[1][i]))));
		}
		ntt(b,limit,-1);
		for(int i=len;i<limit;++i)b[i]=0;
	}
}
void direv(int*a,int*b,int n){
    for(int i=1;i<n;++i){
        b[i-1]=mul(a[i],i);
	}
}
void inter(int*a,int*b,int n){
	b[0]=0;
    for(int i=1;i<n;++i){
        b[i]=mul(a[i-1],mpow(i,p-2));
	}
}
void ln(int*a,int*b,int n){
	direv(a,ls[2],n);
	inv(a,ls[3],n);
	mul(ls[2],ls[3],ls[4],n,n);
	inter(ls[4],b,2*n);
	for(int i=n;i<2*n;++i)b[i]=0;
}
void exp(int*a,int*b,int n){
	b[0]=1;
    for(int len=1;len<2*n;len*=2){
        int limit=len*2;
		ln(b,ls[5],len);
		for(int i=0;i<len;++i){
		    ls[5][i]=add(p-ls[5][i],a[i]);
		}
		ls[5][0]=add(ls[5][0],1);
		for(int i=0;i<len;++i)ls[6][i]=b[i];
		mul(ls[5],ls[6],b,len,len);
		for(int i=len;i<limit;++i)b[i]=0;
	}
}
inline int read(){
    res ret=0;char c;
    for(c=getchar();!isdigit(c);c=getchar());
    for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
    return ret;
}
int main(){
	init();
    cin>>n;
    for(res i=0;i<n;++i)a[i]=read();
    exp(a,c,n);
    for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}

Part 5:多项式开根

也使用上文黑科技:

给出多项式\(A\)

\[G(F)=F^2-A \]

\[G'(F)=2F \]

\[F=F_0-\frac{F_0^2-A}{2F_0} \]

\[F=\frac{F_0^2+A}{2F_0} \]

\[F=\frac{F_0}2+\frac{A}{2F_0} \]

#include<bits/stdc++.h>
#define res register int
#define ll long long
ll js;
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)%p;}
inline int mul(res x,res y){return 1ll*x*y%p;}
inline int mpow(res a,res n){
    res ret=1;
    while(n){
        if(n&1)ret=mul(ret,a);
        a=mul(a,a);
        n/=2;
    }
    return ret;
}    
int g[2][N];
int inv2;
void init(){
	inv2=mpow(2,p-2);
    for(int i=1;i<N;i*=2){
        g[0][i]=mpow(G,(p-1)/i);
        g[1][i]=mpow(Gi,(p-1)/i);
    }
}
int n,m;
int ls[7][N],used;
//0,1 mul inv
//2,3,4 ln sqrt
//5,6 exp
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,res limit,res op){
    for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
    for(res len=2;len<=limit;len*=2){
        res wn=op==1?g[0][len]:g[1][len];
        for(res j=0;j<limit;j+=len){
            res w=1;
            for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
                res x=f[k],y=mul(w,f[k+len/2]);
                f[k]=add(x,y);
                f[k+len/2]=add(x,p-y);
            }
        }
    }
    if(op==-1){
        res inv=mpow(limit,p-2);
        for(res i=0;i<limit;++i){
            f[i]=mul(f[i],inv);
        }
    }
}
void mul(int*a,int*b,int*c,int n,int m){
    int limit=1;
    while(limit<n+m-1)limit*=2;
    for(res i=0;i<limit;++i)ls[0][i]=ls[1][i]=0;
    for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
    for(res i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
    ntt(ls[0],limit,1);
    ntt(ls[1],limit,1);
    for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
    ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
	b[0]=mpow(a[0],p-2);
    for(res len=2,limit;len<2*n;len*=2){
        limit=len*2;
        for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
        for(res i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
        for(res i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
        ntt(ls[0],limit,1),ntt(ls[1],limit,1);
        for(res i=0;i<limit;++i){
            b[i]=mul(add(2,p-(mul(ls[0][i],ls[1][i]))),ls[1][i]);
		}
		ntt(b,limit,-1);
		for(res i=len;i<limit;++i)b[i]=0;
	}
}
inline void direv(int*a,int*b,int n){
    for(res i=1;i<n;++i){
        b[i-1]=mul(a[i],i);
	}
}
inline void inter(int*a,int*b,int n){
	b[0]=0;
    for(res i=1;i<n;++i){
        b[i]=mul(a[i-1],mpow(i,p-2));
	}
}
void ln(int*a,int*b,int n){
	direv(a,ls[2],n);
	inv(a,ls[3],n);
	mul(ls[2],ls[3],ls[4],n,n);
	inter(ls[4],b,2*n);
	for(res i=n;i<2*n;++i)b[i]=0;
}
void exp(int*a,int*b,int n){
	b[0]=1;
    for(res len=2;len<2*n;len*=2){
        res limit=len*2;
		ln(b,ls[5],len);
		for(res i=0;i<len;++i){
		    ls[5][i]=add(p-ls[5][i],a[i]);
		}
		ls[5][0]=add(ls[5][0],1);
		for(res i=0;i<len;++i)ls[6][i]=b[i];
		mul(ls[5],ls[6],b,len,len);
		for(res i=len;i<limit;++i)b[i]=0;
	}
}
void sqrt(int*a,int*b,int n){
	b[0]=1;
    for(res len=2;len<2*n;len*=2){
        res limit=len*2;
        inv(b,ls[2],len);
        for(res i=0;i<len;++i)ls[3][i]=a[i];
        mul(ls[2],ls[3],ls[4],len,len);
        for(res i=0;i<len;++i)b[i]=mul(add(b[i],ls[4][i]),inv2);
        for(res i=len;i<limit;++i)b[i]=0;
    }
}
inline int read(){
    res ret=0;char c;
    for(c=getchar();!isdigit(c);c=getchar());
    for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
    return ret;
}
int main(){

	init();
    cin>>n;
    for(res i=0;i<n;++i)a[i]=read();
    sqrt(a,c,n);
    for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");

}

Part 6:多项式快速幂

直接换底公式即可

#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)%p;}
inline int mul(res x,res y){return 1ll*x*y%p;}
inline int mpow(res a,res n){
    res ret=1;
    while(n){
        if(n&1)ret=mul(ret,a);
        a=mul(a,a);
        n/=2;
    }
    return ret;
}    
int g[2][N];
int inv2;
void init(){
	inv2=mpow(2,p-2);
    for(int i=1;i<N;i*=2){
        g[0][i]=mpow(G,(p-1)/i);
        g[1][i]=mpow(Gi,(p-1)/i);
    }
}
int n,m;
int ls[7][N],used;
//0,1 mul inv
//2,3,4 ln sqrt
//5,6 exp
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,res limit,res op){
    for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
    for(res len=1;len<=limit;len*=2){
        res wn=op==1?g[0][len]:g[1][len];
        for(res j=0;j<limit;j+=len){
            res w=1;
            for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
                res x=f[k],y=mul(w,f[k+len/2]);
                f[k]=add(x,y);
                f[k+len/2]=add(x,p-y);
            }
        }
    }
    if(op==-1){
        res inv=mpow(limit,p-2);
        for(res i=0;i<limit;++i){
            f[i]=mul(f[i],inv);
        }
    }
}
void mul(int*a,int*b,int*c,int n,int m){
    int limit=1;
    while(limit<n+m-1)limit*=2;
    for(res i=0;i<limit;++i)ls[0][i]=ls[1][i]=0;
    for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
    for(res i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
    ntt(ls[0],limit,1);
    ntt(ls[1],limit,1);
    for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
    ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
	b[0]=mpow(a[0],p-2);
	res limit;
    for(res len=1;len<2*n;len*=2){
        limit=len*2;
        for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
        for(res i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
        for(res i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
        ntt(ls[0],limit,1),ntt(ls[1],limit,1);
        for(res i=0;i<limit;++i){
            b[i]=mul(add(2,p-(mul(ls[0][i],ls[1][i]))),ls[1][i]);
		}
		ntt(b,limit,-1);
		for(res i=len;i<limit;++i)b[i]=0;
	}
	for(int i=n;i<limit;++i)b[i]=0;
}
inline void direv(int*a,int*b,int n){
    for(res i=1;i<n;++i){
        b[i-1]=mul(a[i],i);
	}
	b[n-1]=0;
}
inline void inter(int*a,int*b,int n){
	b[0]=0;
    for(res i=1;i<n;++i){
        b[i]=mul(a[i-1],mpow(i,p-2));
	}
}
void ln(int*a,int*b,int n){
	direv(a,ls[2],n);
	inv(a,ls[3],n);
	mul(ls[2],ls[3],ls[4],n,n);
	inter(ls[4],b,2*n);
	for(res i=n;i<2*n;++i)b[i]=0;
	for(res i=0;i<n;++i)ls[2][i]=ls[3][i]=ls[4][i]=0;
}
void exp(int*a,int*b,int n){
	b[0]=1;
    for(res len=1;len<2*n;len*=2){
        res limit=len*2;
		ln(b,ls[5],len);
		for(res i=0;i<len;++i){
		    ls[5][i]=add(p-ls[5][i],a[i]);
		}
		ls[5][0]=add(ls[5][0],1);
		for(res i=0;i<len;++i)ls[6][i]=b[i];
		mul(ls[5],ls[6],b,len,len);
		for(res i=len;i<limit;++i)b[i]=0;
	}
}
void sqrt(int*a,int*b,int n){
	b[0]=1;
    for(res len=2;len<2*n;len*=2){
        res limit=len*2;
        inv(b,ls[2],len);
        for(res i=0;i<len;++i)ls[3][i]=a[i];
        mul(ls[2],ls[3],ls[4],len,len);
        for(res i=0;i<len;++i)b[i]=mul(add(b[i],ls[4][i]),inv2);
        for(res i=len;i<limit;++i)b[i]=0;
    }
}
inline int read(){
    res ret=0;char c;
    for(c=getchar();!isdigit(c);c=getchar());
    for(;isdigit(c);ret=add(mul(ret,10),c-'0'),c=getchar());
    return ret;
}
int main(){
	init();
	n=read(),m=read();
    for(res i=0;i<n;++i)a[i]=read();
    ln(a,b,n);
    for(int i=0;i<n;++i)b[i]=mul(b[i],m);
	exp(b,c,n); 
    for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}
posted @ 2019-05-20 10:33  整理者  阅读(819)  评论(0编辑  收藏  举报