NTT(快速数论变换)

假设质数p满足\(p=r\cdot 2^l +1\),g是p的原根
使用\(g_n=g^{\frac{p-1}{n}}\)代替FFT中的\(\omega_n\)。
同理,\(g_n\)有以下性质:

  • \(g_{2n}^{2k}\equiv g_n^k (mod \: p), (2n\leq 2^l)\)
  • \(g_{2n}^n \equiv -1 (mod \: p),(2n\leq 2^l)\)
    因为\((g_{2n}^n)^2=(g^{\frac{p-1}{2n}\cdot n})^2=(g^{\frac{p-1}{2}})^2=g^{p-1}\equiv 1\),所以\(g_{2n}^n=g^{\frac{p-1}{2}}\equiv \pm 1\);
    又因为g是原根,其阶为\(p-1\),而\(0<\frac{p-1}{2}<p-1\),所以\(g^{\frac{p-1}{2}}\not\equiv 1\),故\(g_{2n}^n\equiv -1\)

\[\sum_{k=0}^{n-1} g_n^{ik}g_n^{-kj}\equiv\left\{ \begin{aligned} n & , &if \quad j=i \\ 0 &, &otherwise \end{aligned} \quad (mod \: p)其中0\leq i,j <n \right. \]

NTT

把FFT中的\(\omega_n\)换成\(g_n\),关于DFT、IDFT的推导过程依然成立(只是运算从复数域\(\mathbb{C}\)变成了模\(p\)的剩余类环\(\mathbb{Z}_p\))

  • NTT的优点:快、精确
  • NTT的限制:模数需要是满足\(p=r\cdot 2^l+1\)的质数p

常见模数

  • \(65537=2^{16}+1,g=3\)
  • \(998244353=119\cdot 2^{23}+1,g=3\)
  • \(1004535809=479\cdot 2^{21}+1,g=3\)
  • \(4179340454199820289=29\cdot 2^{57}+1,g=3\)

模板:A * B Problem Plus

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
// Shared type alias and constants for the NTT big-integer multiplication below.
#define ll long long 
using namespace std;
// Capacity of the global arrays; must exceed the NTT length chosen in main()
// (smallest power of two greater than the total digit count).
const int maxn=2e5+10101;
// NTT prime 998244353 = 119*2^23+1; dft() uses 3 as its primitive root.
const int MOD=998244353;
// inf and pi are not used anywhere in this program.
const int inf=2147483647;
const double pi=acos(-1);
// Reads one signed decimal integer from stdin.
// Every non-digit character before the first digit is skipped; a '-' among
// those characters makes the result negative. If end-of-file is reached
// before a digit, the function returns 0 instead of looping forever.
int read(){
    int x=0,f=1;
    int ch=getchar();   // int, not char: it must be able to hold EOF
    for(;ch!=EOF&&!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;ch!=EOF&&isdigit(ch);ch=getchar())x=x*10+(ch-'0');
    return x*f;
}
typedef complex<double> cd;   // not used by this program
// n first holds the digit count of the first factor and later the NTT length;
// m first holds the second factor's digit count and later the total count;
// len = log2(transform length); rev = bit-reversal permutation table.
int n,m,rev[maxn],len;
// Fills rev[0 .. 2^bit) with the bit-reversal permutation on `bit` bits,
// the index order dft() needs before its butterfly passes.
void get(int bit){
    rev[0]=0;
    if(bit<=0)return;   // only index 0 exists; (i&1)<<(bit-1) would shift by -1 (UB)
    for(int i=1;i<(1<<bit);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}
// Returns x^y mod MOD by square-and-multiply, consuming y one bit at a time.
ll power(ll x,ll y){
    ll res=1;
    for(ll base=x,e=y;e!=0;e>>=1,base=base*base%MOD)
        if(e&1)res=res*base%MOD;
    return res;
}
// In-place iterative NTT of u[0..n) using the global length n and the
// bit-reversal table rev[] built by get().
//   val ==  1 : forward transform (roots are powers of 3)
//   val == -1 : inverse roots; the 1/n scaling is left to the caller
// Results are congruent mod MOD but may be negative.
void dft(ll *u,int val){
    for(int i=0;i<n;i++){
        const int r=rev[i];
        if(i<r)swap(u[i],u[r]);
    }
    for(int half=1;half<n;half*=2){
        const int blk=half*2;
        ll step=power(3,(MOD-1)/blk);         // root of unity of order blk
        if(val==-1)step=power(step,MOD-2);    // inverse root for the inverse transform
        for(int base=0;base<n;base+=blk){
            ll w=1;
            for(int k=0;k<half;k++){
                const ll lo=u[base+k];
                const ll hi=w*u[base+k+half]%MOD;
                u[base+k]=(lo+hi)%MOD;
                u[base+k+half]=(lo-hi)%MOD;
                w=w*step%MOD;
            }
        }
    }
}
ll c[maxn],d[maxn];
ll a[maxn],b[maxn];
int main(){
    char ch[maxn];cin>>ch;n=strlen(ch);
    for(int i=0;i<n;i++)a[i]=(ch[n-i-1]-'0');
    cin>>ch;m=strlen(ch);
    for(int i=0;i<m;i++)b[i]=(ch[m-i-1]-'0');
    m+=n;for(n=1;n<=m;n<<=1)len++;
    get(len);dft(a,1);dft(b,1); 
    for(int i=0;i<(1<<len);i++)(a[i]*=b[i])%=MOD;
    dft(a,-1);
    for(int i=0;i<m;i++)c[i]=(a[i]*power(n,MOD-2)%MOD+MOD)%MOD;
    int jin=0,tot=0;
    for(int i=0;i<m;i++){
        d[i]=(c[i]+jin)%10;
        if(c[i]+jin>=10)jin=(c[i]+jin)/10;
        else jin=0;
    }
    bool fa=false;
    for(int i=m-1;i>=0;i--){
        if(d[i]!=0)fa=true;
        if(fa)printf("%lld",(d[i]%MOD+MOD)%MOD);
    }
    return 0;
}

多项式求逆

求A模\(x^n\)的逆元,假设先求出了模\(x^{\lceil \frac{n}{2} \rceil }\)的逆元
设A模\(x^n\)的逆元为\(B\),模\(x^{\lceil \frac{n}{2} \rceil }\)的逆元为\(B'\)

\(A*B'\equiv 1 (mod \:x^{\lceil \frac{n}{2} \rceil })\)
\(A*B \equiv 1 (mod \: x^n)\)
所以\(B'-B \equiv 0 (mod \: x^{\lceil \frac{n}{2} \rceil } )\)
两边平方得\(B'^2-2B'B+B^2 \equiv 0(mod\: x^n)\)(平方后模数变为\(x^{2\lceil \frac{n}{2} \rceil}\),而\(2\lceil \frac{n}{2} \rceil\geq n\))
左右同乘A得\(AB'^2-2B'+B \equiv 0(mod\: x^n)\)
因此得到\(B\equiv 2B'-AB'^2(mod \: x^n)\)的递推式,由下向上,从\(x^1\)开始推至\(x^{2^z}(n\leq 2^z)\)即可
初值\(B=A(0)^{-1}\)
利用NTT可将时间复杂度优化至\(O(n\log n)\)(\(T(n)=T(\frac{n}{2})+O(n\log n)\))
另外注意:\(F(x)\)存在逆元当且仅当\([x^0]F(x)\not =0\)

点击查看代码
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<complex>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
// Shared type alias and constants for the polynomial-inverse program below.
#define ll long long
using namespace std;
// Capacity of the global arrays (rev, F1, F2, a); must cover the largest NTT length used.
const int maxn=1000000+10101;
// NTT prime 998244353 = 119*2^23+1; ntt() uses 3 as its primitive root.
const int MOD=998244353;
// Reads one signed decimal integer (64-bit) from stdin.
// Every non-digit character before the first digit is skipped; a '-' among
// those characters makes the result negative. If end-of-file is reached
// before a digit, the function returns 0 instead of looping forever.
inline long long read(){
    long long x=0,f=1;
    int ch=getchar();   // int, not char: it must be able to hold EOF
    for(;ch!=EOF&&!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;ch!=EOF&&isdigit(ch);ch=getchar())x=x*10+(ch-'0');
    return x*f;
}
// Returns x^y mod MOD by binary exponentiation.
ll power(ll x,ll y){
    ll result=1;
    for(ll base=x;y;y>>=1){
        if(y&1)result=result*base%MOD;
        base=base*base%MOD;
    }
    return result;
}


typedef vector<ll> Poly;   // coefficient list; index i holds the x^i coefficient
int rev[maxn];             // bit-reversal permutation filled by get()
// Fills rev[0 .. 2^bit) with the bit-reversal permutation on `bit` bits,
// the index order ntt() needs before its butterfly passes.
void get(int bit){
    rev[0]=0;
    if(bit<=0)return;   // only index 0 exists; (i&1)<<(bit-1) would shift by -1 (UB)
    for(int i=1;i<(1<<bit);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}
// In-place number-theoretic transform of a[0..n) modulo MOD with root 3.
// n must be a power of two.
//   f ==  1 : forward transform
//   f == -1 : inverse transform, including the final multiplication by n^{-1}
// Outputs are congruent to the exact values mod MOD but may be negative.
void ntt(ll *a,int n,int f){
    if(n<=1)return;   // a length-1 transform and its inverse are the identity
    int bit=0;
    while((1<<bit)<n)bit++;   // exact integer log2(n); no floating-point log2
    get(bit);
    for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int i=1;i<n;i<<=1){
        ll wn=power(3,(MOD-1)/(i<<1));   // root of unity of order 2i
        if(f==-1)wn=power(wn,MOD-2);     // inverse transform uses inverse roots
        for(int j=0;j<n;j+=i<<1){
            ll w=1;
            for(int k=0;k<i;k++,w=w*wn%MOD){
                const ll x=a[j+k],y=a[j+k+i]*w%MOD;
                a[j+k]=(x+y)%MOD;
                a[j+k+i]=(x-y)%MOD;
            }
        }
    }
    if(f==1)return;
    const ll nv=power(n,MOD-2);
    for(int i=0;i<n;i++)a[i]=a[i]*nv%MOD;
}
ll F1[maxn],F2[maxn];   // NTT scratch buffers
// One Newton step of polynomial inversion. Despite its name, this does NOT
// return the plain product A*B: with A = current inverse B' and B = the input
// polynomial truncated, it returns A*(2 - A*B) = 2B' - B*B'^2.
// The work is a cyclic NTT of length `lens` rounded up to a power of two, so
// the caller must choose lens large enough that A*A*B does not wrap around.
// All transform-length coefficients are returned; the caller truncates.
Poly mul(const Poly& A,const Poly& B,ll lens){
    int size=1;
    while(size<lens)size<<=1;   // integer round-up; avoids floating-point log2
    fill(F1,F1+size,0);
    fill(F2,F2+size,0);
    copy(A.begin(),A.end(),F1);
    copy(B.begin(),B.end(),F2);
    ntt(F1,size,1);ntt(F2,size,1);
    // Point-wise Newton update: F1 <- F1*(2 - F1*F2)
    for(int i=0;i<size;i++)F1[i]=(2-F1[i]*F2[i]%MOD)%MOD*F1[i]%MOD;
    ntt(F1,size,-1);
    return Poly(F1,F1+size);
}
ll n,a[maxn];   // n = number of coefficients; a = input polynomial A
// Returns the inverse of the global polynomial a modulo x^len, where len is
// the smallest power of two >= limit. It uses the Newton iteration
// B = 2B' - A*B'^2, which doubles the precision each round.
// Requires a[0] != 0 (mod MOD). The returned vector may be longer than len;
// entries at index >= len are 0.
Poly getinv(ll limit){
    ll len=1;
    while(len<limit)len<<=1;   // integer round-up; limit <= 0 yields len = 1
    Poly b(1,power(a[0],MOD-2));   // inverse modulo x^1
    for(int s=2;s<=len;s<<=1){
        Poly f(a,a+s);                  // A mod x^s
        b=mul(b,f,s<<1);                // Newton step, returns 2s coefficients
        fill(b.begin()+s,b.end(),0);    // keep only B mod x^s
    }
    return b;
}
int main(){
    n=read();
    for(int i=0;i<n;i++)a[i]=read();
    Poly b=getinv(n);
    for(int i=0;i<n;i++)printf("%lld ",(b[i]%MOD+MOD)%MOD);
    return 0;
}

分治NTT

求多个多项式相乘,采用分治思想
对于一个区间分成左右两半,分别求出左右区间的多项式乘积再相乘就是当前区间的多项式乘积

应用

1.挑选队友
对于第i群可以写成表达式\(F_i(x)=0x^0+s_ix^1+{s_i \choose 2}x^2+···+{s_i \choose s_i}x^{s_i}\)
\(F(x)=\prod_{i=1}^mF_i(x)\),用分治NTT即可求出
\(ans=[x^k]F(x)\)

点击查看代码
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
// Shared type alias and constants for the divide-and-conquer NTT program below.
#define ll long long 
using namespace std;
// Capacity of the global arrays (inv, pre, rev, F1, F2, s, a).
const int maxn=400000+10101;
// NTT prime 998244353 = 119*2^23+1; ntt() uses 3 as its primitive root.
const int MOD=998244353;
// Reads one signed decimal integer (64-bit) from stdin.
// Every non-digit character before the first digit is skipped; a '-' among
// those characters makes the result negative. If end-of-file is reached
// before a digit, the function returns 0 instead of looping forever.
inline long long read(){
    long long x=0,f=1;
    int ch=getchar();   // int, not char: it must be able to hold EOF
    for(;ch!=EOF&&!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;ch!=EOF&&isdigit(ch);ch=getchar())x=x*10+(ch-'0');
    return x*f;
}
// Polynomial stored as its coefficient list; index i holds the x^i coefficient.
typedef vector<ll> Poly;
// Returns x^y mod MOD using square-and-multiply; the loop ends when y hits 0.
ll power(ll x,ll y){
    ll result=1;
    for(;y!=0;y>>=1){
        if(y&1)result=result*x%MOD;
        x=x*x%MOD;
    }
    return result;
}
// pre[i] = i! and inv[i] = (i!)^{-1} mod MOD for i >= 1 (filled by init()).
ll inv[maxn],pre[maxn];
// Binomial coefficient C(n, m) mod MOD, except that C(n, 0) is defined as 0:
// init() calls this for j = 0..s_i, so each group polynomial has no constant term.
ll C(int n,int m){
    if(m!=0&&n!=m){
        const ll numer=pre[n];
        const ll denom=inv[m]*inv[n-m]%MOD;
        return numer*denom%MOD;
    }
    return m==0?0:1;
}
int rev[maxn];   // bit-reversal permutation filled by get()
// Fills rev[0 .. 2^bit) with the bit-reversal permutation on `bit` bits,
// the index order ntt() needs before its butterfly passes.
void get(int bit){
    rev[0]=0;
    if(bit<=0)return;   // only index 0 exists; (i&1)<<(bit-1) would shift by -1 (UB)
    for(int i=1;i<(1<<bit);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}
// In-place number-theoretic transform of a[0..n) modulo MOD with root 3.
// n must be a power of two.
//   f ==  1 : forward transform
//   f == -1 : inverse transform, including the final multiplication by n^{-1}
// Outputs are congruent to the exact values mod MOD but may be negative.
void ntt(ll *a,int n,int f){
    if(n<=1)return;   // a length-1 transform and its inverse are the identity
    int bit=0;
    while((1<<bit)<n)bit++;   // exact integer log2(n); no floating-point log2
    get(bit);
    for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int i=1;i<n;i<<=1){
        ll wn=power(3,(MOD-1)/(i<<1));   // root of unity of order 2i
        if(f==-1)wn=power(wn,MOD-2);     // inverse transform uses inverse roots
        for(int j=0;j<n;j+=i<<1){
            ll w=1;
            for(int k=0;k<i;k++,w=w*wn%MOD){
                const ll x=a[j+k],y=a[j+k+i]*w%MOD;
                a[j+k]=(x+y)%MOD;
                a[j+k+i]=(x-y)%MOD;
            }
        }
    }
    if(f==1)return;
    const ll nv=power(n,MOD-2);
    for(int i=0;i<n;i++)a[i]=a[i]*nv%MOD;
}
ll F1[maxn],F2[maxn];   // NTT scratch buffers
// Returns the product A*B (|A|+|B|-1 coefficients) using a cyclic NTT whose
// length is the smallest power of two >= |A|+|B|-1. Coefficients are
// congruent mod MOD but may be negative. An empty factor yields an empty result.
Poly mul(const Poly& A,const Poly& B){
    if(A.empty()||B.empty())return Poly();   // avoids a non-positive product length
    const int n=static_cast<int>(A.size()),m=static_cast<int>(B.size());
    const int need=n+m-1;
    int lens=1;
    while(lens<need)lens<<=1;   // integer round-up; avoids floating-point log2
    fill(F1,F1+lens,0);
    fill(F2,F2+lens,0);
    copy(A.begin(),A.end(),F1);
    copy(B.begin(),B.end(),F2);
    ntt(F1,lens,1);ntt(F2,lens,1);
    for(int i=0;i<lens;i++)F1[i]=F1[i]*F2[i]%MOD;
    ntt(F1,lens,-1);
    return Poly(F1,F1+need);
}
int n,m,k,s[maxn];
Poly a[maxn];
void init(){
    n=read();m=read();k=read();pre[1]=1ll;
    for(ll i=2;i<=n;i++)pre[i]=pre[i-1]*i%MOD;
    inv[n]=power(pre[n],MOD-2);
    for(ll i=n-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%MOD;
    for(int i=1;i<=m;i++)s[i]=read();
    for(int i=1;i<=m;i++){
        for(int j=0;j<=s[i];j++)a[i].push_back(C(s[i],j));
    }
}
// Divide and conquer: returns the product a[l] * a[l+1] * ... * a[r] by
// splitting the index range in half and multiplying the two partial products.
Poly solve(int l,int r){
    if(l!=r){
        const int mid=(l+r)/2;
        Poly left=solve(l,mid);
        Poly right=solve(mid+1,r);
        return mul(left,right);
    }
    return a[l];
}
int main(){
    init();printf("%lld",(solve(1,m)[k]%MOD+MOD)%MOD);
    return 0;
}

2.tokitsukaze and Another Protoss and Zerg
同上一题,对于第i轮
\(F_i(x)=(2^{b_i}-1)x^0+a_ix+{a_i \choose 2}x^2+···+{a_i \choose a_i}x^{a_i}\)
用分治ntt即可求出

点击查看代码
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
// Shared type alias and constants for the divide-and-conquer NTT program below.
#define ll long long 
using namespace std;
// Capacity of the global arrays (inv, pre, rev, F1, F2, sa, sb, a).
const int maxn=400000+10101;
// NTT prime 998244353 = 119*2^23+1; ntt() uses 3 as its primitive root.
const int MOD=998244353;
// Reads one signed decimal integer (64-bit) from stdin.
// Every non-digit character before the first digit is skipped; a '-' among
// those characters makes the result negative. If end-of-file is reached
// before a digit, the function returns 0 instead of looping forever.
inline long long read(){
    long long x=0,f=1;
    int ch=getchar();   // int, not char: it must be able to hold EOF
    for(;ch!=EOF&&!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;ch!=EOF&&isdigit(ch);ch=getchar())x=x*10+(ch-'0');
    return x*f;
}
// Polynomial stored as its coefficient list; index i holds the x^i coefficient.
typedef vector<ll> Poly;
// Returns x^y mod MOD using square-and-multiply; the loop ends when y hits 0.
ll power(ll x,ll y){
    ll result=1;
    for(;y!=0;y>>=1){
        if(y&1)result=result*x%MOD;
        x=x*x%MOD;
    }
    return result;
}
// pre[i] = i! and inv[i] = (i!)^{-1} mod MOD for i >= 1 (filled by init()).
ll inv[maxn],pre[maxn];
// Binomial coefficient C(n, m) mod MOD. It returns 0 for m == 0; init() only
// calls it with m >= 1 and supplies the constant term separately.
ll C(int n,int m){
    if(m!=0&&n!=m){
        const ll numer=pre[n];
        const ll denom=inv[m]*inv[n-m]%MOD;
        return numer*denom%MOD;
    }
    return m==0?0:1;
}
int rev[maxn];   // bit-reversal permutation filled by get()
// Fills rev[0 .. 2^bit) with the bit-reversal permutation on `bit` bits,
// the index order ntt() needs before its butterfly passes.
void get(int bit){
    rev[0]=0;
    if(bit<=0)return;   // only index 0 exists; (i&1)<<(bit-1) would shift by -1 (UB)
    for(int i=1;i<(1<<bit);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}
// In-place number-theoretic transform of a[0..n) modulo MOD with root 3.
// n must be a power of two.
//   f ==  1 : forward transform
//   f == -1 : inverse transform, including the final multiplication by n^{-1}
// Outputs are congruent to the exact values mod MOD but may be negative.
void ntt(ll *a,int n,int f){
    if(n<=1)return;   // a length-1 transform and its inverse are the identity
    int bit=0;
    while((1<<bit)<n)bit++;   // exact integer log2(n); no floating-point log2
    get(bit);
    for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int i=1;i<n;i<<=1){
        ll wn=power(3,(MOD-1)/(i<<1));   // root of unity of order 2i
        if(f==-1)wn=power(wn,MOD-2);     // inverse transform uses inverse roots
        for(int j=0;j<n;j+=i<<1){
            ll w=1;
            for(int k=0;k<i;k++,w=w*wn%MOD){
                const ll x=a[j+k],y=a[j+k+i]*w%MOD;
                a[j+k]=(x+y)%MOD;
                a[j+k+i]=(x-y)%MOD;
            }
        }
    }
    if(f==1)return;
    const ll nv=power(n,MOD-2);
    for(int i=0;i<n;i++)a[i]=a[i]*nv%MOD;
}
ll F1[maxn],F2[maxn];   // NTT scratch buffers
// Returns the product A*B (|A|+|B|-1 coefficients) using a cyclic NTT whose
// length is the smallest power of two >= |A|+|B|-1. Coefficients are
// congruent mod MOD but may be negative. An empty factor yields an empty result.
Poly mul(const Poly& A,const Poly& B){
    if(A.empty()||B.empty())return Poly();   // avoids a non-positive product length
    const int n=static_cast<int>(A.size()),m=static_cast<int>(B.size());
    const int need=n+m-1;
    int lens=1;
    while(lens<need)lens<<=1;   // integer round-up; avoids floating-point log2
    fill(F1,F1+lens,0);
    fill(F2,F2+lens,0);
    copy(A.begin(),A.end(),F1);
    copy(B.begin(),B.end(),F2);
    ntt(F1,lens,1);ntt(F2,lens,1);
    for(int i=0;i<lens;i++)F1[i]=F1[i]*F2[i]%MOD;
    ntt(F1,lens,-1);
    return Poly(F1,F1+need);
}
int n,sa[maxn],sb[maxn],sum;
Poly a[maxn];
void init(){
    n=read();pre[1]=1ll;
    for(ll i=2;i<=200000;i++)pre[i]=pre[i-1]*i%MOD;
    inv[200000]=power(pre[200000],MOD-2);
    for(ll i=199999;i>=1;i--)inv[i]=inv[i+1]*(i+1)%MOD;
    for(int i=1;i<=n;i++)sa[i]=read(),sum+=sa[i];
    for(int i=1;i<=n;i++)sb[i]=read();
    for(int i=1;i<=n;i++){
        a[i].push_back(power(2,sb[i])-1);
        for(int j=1;j<=sa[i];j++)a[i].push_back(C(sa[i],j));
    }
}
// Divide and conquer: returns the product a[l] * a[l+1] * ... * a[r] by
// splitting the index range in half and multiplying the two partial products.
Poly solve(int l,int r){
    if(l!=r){
        const int mid=(l+r)/2;
        Poly left=solve(l,mid);
        Poly right=solve(mid+1,r);
        return mul(left,right);
    }
    return a[l];
}
int main(){
    init();
    Poly ans=solve(1,n);
    for(int i=0;i<=sum;i++)printf("%lld ",(ans[i]%MOD+MOD)%MOD);
    return 0;
}

3.The Child and Binary Tree

4.卷积

posted @ 2022-01-04 20:07  I_N_V  阅读(782)  评论(0编辑  收藏  举报