NTT(快速数论变换)
假设质数p满足\(p=r\cdot 2^l +1\),g是p的原根
使用\(g_n=g^{\frac{p-1}{n}}\)代替FFT中的\(\omega_n\)
同理\(g_n\)有以下性质
- \(g_{2n}^{2k}\equiv g_n^k (mod \: p), (2n\leq 2^l)\)
- \(g_{2n}^n \equiv -1 (mod \: p),(2n\leq 2^l)\)
因为\((g_{2n}^{n})^2=(g^{\frac{p-1}{2n}\cdot n})^2=(g^{\frac{p-1}{2}})^2=g^{p-1}\equiv 1 (mod \: p)\)
所以\(g^{\frac{p-1}{2}}\equiv \pm 1\);又因为g是原根,\(g^0,g^1,\dots,g^{p-2}\)模p两两不同,而\(g^0\equiv 1\),所以\(g_{2n}^n=g^{\frac{p-1}{2}}\equiv -1 (mod \: p)\)
NTT
把FFT中的\(\omega_n\)换成\(g_n\),关于DFT,IDFT的推导过程依然成立(除了从\(\mathbb{C}\)中的运算变成了模p意义下\(\mathbb{Z}_p\)中的运算)
- NTT的优点:快、精确
- NTT的限制:模数需要是满足\(p=r\cdot 2^l+1\)的质数p
常见模数
- \(65537=2^{16}+1,g=3\)
- \(998244353=119\cdot 2^{23}+1,g=3\)
- \(1004535809=479\cdot 2^{21}+1,g=3\)
- \(4179340454199820289=29\cdot 2^{57}+1,g=3\)
点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
// `ll` abbreviates `long long`; all modular arithmetic below is done in this type.
#define ll long long
using namespace std;
// Capacity of the global buffers (rev, a, b, c, d).
// NOTE(review): main rounds the transform length up to a power of two, so the
// buffers must hold that rounded length — confirm maxn covers the largest input.
const int maxn=2e5+10101;
// NTT modulus: 998244353 = 119 * 2^23 + 1; dft() uses 3 as its primitive root.
const int MOD=998244353;
// Largest 32-bit signed value; not referenced by the code shown here.
const int inf=2147483647;
// Pi; not referenced by the code shown here.
const double pi=acos(-1);
// Reads one decimal integer from stdin.
// Skips every non-digit character first; if any '-' is seen while skipping,
// the result is negated. Returns 0 when end of input is reached before a digit.
int read(){
    // Keep getchar()'s result in an int: a char cannot represent EOF reliably
    // (the skip loop would never terminate at end of input) and passing a
    // negative char to isdigit is undefined behaviour.
    int ch=getchar();
    int sign=1;
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    int value=0;
    for(;isdigit(ch);ch=getchar())value=value*10+(ch-'0');
    return value*sign;
}
// Complex alias; not referenced by the code shown here.
typedef complex<double> cd;
// n   : first operand's digit count, later the power-of-two transform length
// m   : second operand's digit count, later the sum of both digit counts
// rev : bit-reversal permutation table filled by get()
// len : number of bits of the transform length (n == 1 << len)
int n,m,rev[maxn],len;
// Fills rev[0 .. 2^bit) with the bit-reversal permutation on `bit` bits.
// Each entry is derived from rev[i>>1]: shift it right by one and, if i is
// odd, set the top bit.
void get(int bit){
    if(bit<0)return;
    const int size=1<<bit;
    for(int i=0;i<size;i++){
        // Only shift when the low bit is set, so bit==0 never evaluates a
        // shift by -1 (undefined behaviour in the original expression).
        const int top=(i&1)?(1<<(bit-1)):0;
        rev[i]=(rev[i>>1]>>1)|top;
    }
}
// Modular exponentiation: returns x^y mod MOD as a value in [0, MOD).
// x may be any 64-bit value (it is reduced first, so the squarings cannot
// overflow). A negative y is mapped to the equivalent non-negative exponent
// modulo MOD-1 (valid by Fermat's little theorem because MOD is prime and x
// is not a multiple of MOD); for x == 0 mod MOD that case returns 0.
ll power(ll x,ll y){
    x%=MOD;
    if(x<0)x+=MOD;
    if(y<0){
        // Result lies in (0, MOD-1]; the original loop never terminated here.
        y%=(MOD-1);
        y+=MOD-1;
    }
    ll ans=1;
    for(;y>0;y>>=1){
        if(y&1)ans=ans*x%MOD;
        x=x*x%MOD;
    }
    return ans;
}
// In-place number-theoretic transform of u[0 .. n) modulo MOD, where n is the
// global transform length and rev[] must already hold the matching
// bit-reversal table (see get()).
//   val == -1 : inverse transform, WITHOUT the final division by n
//               (the caller multiplies by n^{-1} itself);
//   otherwise : forward transform.
// All outputs are canonical residues in [0, MOD).
void dft(ll *u,int val){
    // Bring every input into [0, MOD) so the butterflies can stay canonical.
    for(int i=0;i<n;i++){
        u[i]%=MOD;
        if(u[i]<0)u[i]+=MOD;
    }
    for(int i=0;i<n;i++)if(i<rev[i])swap(u[i],u[rev[i]]);
    for(int half=1;half<n;half<<=1){
        // Primitive (2*half)-th root of unity, inverted for the inverse pass.
        ll wn=power(3,(MOD-1)/(half<<1));
        if(val==-1)wn=power(wn,MOD-2);
        for(int start=0;start<n;start+=(half<<1)){
            ll w=1;
            for(int k=0;k<half;k++,w=w*wn%MOD){
                const ll x=u[start+k];
                const ll y=w*u[start+k+half]%MOD;
                u[start+k]=(x+y)%MOD;
                // +MOD keeps the difference non-negative.
                u[start+k+half]=(x-y+MOD)%MOD;
            }
        }
    }
    return ;
}
ll c[maxn],d[maxn];
ll a[maxn],b[maxn];
int main(){
char ch[maxn];cin>>ch;n=strlen(ch);
for(int i=0;i<n;i++)a[i]=(ch[n-i-1]-'0');
cin>>ch;m=strlen(ch);
for(int i=0;i<m;i++)b[i]=(ch[m-i-1]-'0');
m+=n;for(n=1;n<=m;n<<=1)len++;
get(len);dft(a,1);dft(b,1);
for(int i=0;i<(1<<len);i++)(a[i]*=b[i])%=MOD;
dft(a,-1);
for(int i=0;i<m;i++)c[i]=(a[i]*power(n,MOD-2)%MOD+MOD)%MOD;
int jin=0,tot=0;
for(int i=0;i<m;i++){
d[i]=(c[i]+jin)%10;
if(c[i]+jin>=10)jin=(c[i]+jin)/10;
else jin=0;
}
bool fa=false;
for(int i=m-1;i>=0;i--){
if(d[i]!=0)fa=true;
if(fa)printf("%lld",(d[i]%MOD+MOD)%MOD);
}
return 0;
}
多项式求逆
求A模\(x^n\)的逆元,假设先求出了模\(x^{\lceil \frac{n}{2} \rceil }\)的逆元
设A模\(x^n\)的逆元为\(B\),模\(x^{\lceil \frac{n}{2} \rceil }\)的逆元为\(B'\)
则
\(A*B'\equiv 1 (mod \:x^{\lceil \frac{n}{2} \rceil })\)
\(A*B \equiv 1 (mod \: x^n)\)
所以\(B'-B \equiv 0 (mod \: x^{\lceil \frac{n}{2} \rceil } )\)
两边平方得\(B'^2-2B'B+B^2 \equiv 0(mod\: x^n)\)(因为\(2\lceil \frac{n}{2} \rceil \geq n\))
左右同乘A得\(AB'^2-2B'+B \equiv 0(mod\: x^n)\)
因此得到\(B\equiv 2B'-AB'^2(mod \: x^n)\)的递推式,由下向上,从\(x^1\)开始推至\(x^{2^z}(n\leq 2^z)\)即可
初值\(B=A(0)^{-1}\)
利用NTT可将时间复杂度优化至\(O(nlogn)\)
另外注意:\(F(x)\)存在逆元当且仅当\([x^0]F(x)\not =0\)
点击查看代码
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<complex>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
// `ll` abbreviates `long long`; all modular arithmetic below is done in this type.
#define ll long long
using namespace std;
// Capacity of the global work buffers (rev, F1, F2, a, g).
const int maxn=1000000+10101;
// Prime modulus; ntt() uses 3 as its primitive root.
const int MOD=998244353;
// Reads one decimal integer from stdin.
// Skips every non-digit character first; if any '-' is seen while skipping,
// the result is negated. Returns 0 when end of input is reached before a digit.
inline ll read(){
    // getchar()'s result stays an int so EOF is detected (the original char
    // version looped forever at end of input) and isdigit never receives a
    // negative char.
    int ch=getchar();
    ll sign=1;
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    ll value=0;
    for(;isdigit(ch);ch=getchar())value=value*10+(ch-'0');
    return value*sign;
}
// Modular exponentiation: returns x^y mod MOD as a value in [0, MOD).
// x is reduced first so the squarings cannot overflow. A negative y is mapped
// to the equivalent non-negative exponent modulo MOD-1 (Fermat's little
// theorem; MOD is prime, x not a multiple of MOD); for x == 0 mod MOD that
// case returns 0.
ll power(ll x,ll y){
    x%=MOD;
    if(x<0)x+=MOD;
    if(y<0){
        // Result lies in (0, MOD-1]; the original loop never terminated here.
        y%=(MOD-1);
        y+=MOD-1;
    }
    ll ans=1;
    for(;y>0;y>>=1){
        if(y&1)ans=ans*x%MOD;
        x=x*x%MOD;
    }
    return ans;
}
// Coefficient vector: element i holds the coefficient of x^i.
typedef vector<ll> Poly;
// Bit-reversal permutation table filled by get() and read by ntt().
int rev[maxn];
// Fills rev[0 .. 2^bit) with the bit-reversal permutation on `bit` bits.
void get(int bit){
    if(bit<0)return;
    const int size=1<<bit;
    for(int i=0;i<size;i++){
        // Only shift when the low bit is set, so bit==0 never evaluates a
        // shift by -1 (undefined behaviour in the original expression).
        const int top=(i&1)?(1<<(bit-1)):0;
        rev[i]=(rev[i>>1]>>1)|top;
    }
    return ;
}
// In-place number-theoretic transform of a[0 .. n) modulo MOD.
//   a : coefficient array, overwritten with the result
//   n : transform length; the butterfly loops assume a power of two
//   f : 1 for the forward transform; -1 for the inverse transform. Any value
//       other than 1 also multiplies the result by n^{-1}.
// Outputs are canonical residues in [0, MOD).
void ntt(ll *a,int n,int f){
    if(n<=0)return;
    // floor(log2(n)) with integer arithmetic instead of a double round-trip.
    int bit=0;
    while((n>>(bit+1))>0)bit++;
    // Bring every input into [0, MOD) so the butterflies can stay canonical.
    for(int i=0;i<n;i++){
        a[i]%=MOD;
        if(a[i]<0)a[i]+=MOD;
    }
    if(n>1){
        // A length-1 transform needs no permutation; skipping it also keeps
        // get() from being called with 0 bits.
        get(bit);
        for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    }
    for(int half=1;half<n;half<<=1){
        // Primitive (2*half)-th root of unity, inverted for the inverse pass.
        ll wn=power(3,(MOD-1)/(half<<1));
        if(f==-1)wn=power(wn,MOD-2);
        for(int start=0;start<n;start+=half<<1){
            ll w=1;
            for(int k=0;k<half;k++,w=wn*w%MOD){
                const ll x=a[start+k];
                const ll y=a[start+k+half]*w%MOD;
                a[start+k]=(x+y)%MOD;
                a[start+k+half]=(x-y+MOD)%MOD;
            }
        }
    }
    if(f==1)return ;
    const ll nv=power(n,MOD-2);
    for(int i=0;i<n;i++)a[i]=a[i]*nv%MOD;
    return ;
}
// Scratch buffers for the transforms performed by mul().
ll F1[maxn],F2[maxn];
// Newton step for the polynomial inverse: despite its name this returns
// A*(2 - A*B), not the plain product A*B. The product is cyclic with period
// equal to `lens` rounded up to a power of two, and the returned vector has
// exactly that many coefficients.
//   A    : current approximation of the inverse
//   B    : the polynomial being inverted (truncated by the caller)
//   lens : requested transform length; must not exceed maxn after rounding
Poly mul(Poly A,Poly B,ll lens){
    // Round up with integers; the floating ceil(log2()) was undefined for
    // lens <= 0.
    ll size=1;
    while(size<lens)size<<=1;
    lens=size;
    // Coefficients beyond the transform length are never read by ntt().
    const ll n=min<ll>(A.size(),lens),m=min<ll>(B.size(),lens);
    for(ll i=0;i<lens;i++)F1[i]=F2[i]=0;
    for(ll i=0;i<n;i++)F1[i]=A[i];
    for(ll i=0;i<m;i++)F2[i]=B[i];
    ntt(F1,lens,1);ntt(F2,lens,1);
    // Pointwise A*(2 - A*B) in the transformed domain.
    for(ll i=0;i<lens;i++)F1[i]=((2-F1[i]*F2[i]%MOD)%MOD*F1[i])%MOD;
    ntt(F1,lens,-1);
    return Poly(F1,F1+lens);
}
// n: number of input coefficients; a: the polynomial to invert;
// g: not referenced by the code shown here.
ll n,a[maxn],g[maxn];
// Returns the inverse of the global polynomial a modulo x^len, where len is
// `limit` rounded up to a power of two (the result vector may be longer than
// len; coefficients from the last processed length upward are zero).
// Starts from the inverse of the constant term and doubles the precision
// each round with b <- b*(2 - a*b).
// NOTE(review): a[0] must be non-zero mod MOD for the inverse to exist; this
// is not checked here.
Poly getinv(ll limit){
    // Integer rounding; ceil(log2(limit)) was undefined for limit <= 0.
    ll len=1;
    while(len<limit)len<<=1;
    Poly b(1);
    b[0]=power(a[0],MOD-2);
    for(ll s=2;s<=len;s<<=1){
        // First s coefficients of a: all that matter modulo x^s.
        Poly f(a,a+s);
        b=mul(b,f,s<<1);
        // Truncate to x^s: discard the upper half of the length-2s result.
        for(ll i=s;i<(s<<1);i++)b[i]=0;
    }
    return b;
}
int main(){
n=read();
for(int i=0;i<n;i++)a[i]=read();
Poly b=getinv(n);
for(int i=0;i<n;i++)printf("%lld ",(b[i]%MOD+MOD)%MOD);
return 0;
}
分治NTT
求多个多项式相乘,采用分治思想
对于一个区间分成左右两半,分别求出左右区间的多项式乘积再相乘就是当前区间的多项式乘积
应用
1.挑选队友
对于第i群可以写成表达式\(F_i(x)=0x^0+s_ix^1+{s_i \choose 2}x^2+···+{s_i \choose s_i}x^{s_i}\)
令\(F(x)=\prod_{i=1}^mF_i(x)\),用分治NTT即可求
则\(ans=[x^k]F(x)\)
点击查看代码
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
// `ll` abbreviates `long long`; all modular arithmetic below is done in this type.
#define ll long long
using namespace std;
// Capacity of the global tables and work buffers (inv, pre, rev, F1, F2, s, a).
const int maxn=400000+10101;
// Prime modulus; ntt() uses 3 as its primitive root.
const int MOD=998244353;
// Reads one decimal integer from stdin.
// Skips every non-digit character first; if any '-' is seen while skipping,
// the result is negated. Returns 0 when end of input is reached before a digit.
inline ll read(){
    // getchar()'s result stays an int so EOF is detected (the original char
    // version looped forever at end of input) and isdigit never receives a
    // negative char.
    int ch=getchar();
    ll sign=1;
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    ll value=0;
    for(;isdigit(ch);ch=getchar())value=value*10+(ch-'0');
    return value*sign;
}
// Coefficient vector: element i holds the coefficient of x^i.
typedef vector<ll> Poly;
// Modular exponentiation: returns x^y mod MOD as a value in [0, MOD).
// x is reduced first so the squarings cannot overflow. A negative y is mapped
// to the equivalent non-negative exponent modulo MOD-1 (Fermat's little
// theorem; MOD is prime, x not a multiple of MOD); for x == 0 mod MOD that
// case returns 0.
ll power(ll x,ll y){
    x%=MOD;
    if(x<0)x+=MOD;
    if(y<0){
        // Result lies in (0, MOD-1]; the original loop never terminated here.
        y%=(MOD-1);
        y+=MOD-1;
    }
    ll ans=1;
    for(;y>0;y>>=1){
        if(y&1)ans=ans*x%MOD;
        x=x*x%MOD;
    }
    return ans;
}
// pre[i]: i! mod MOD; inv[i]: modular inverse of i!. Both are filled by init().
ll inv[maxn],pre[maxn];
// Coefficient used when building the per-group polynomials.
// Returns 0 when m == 0 (deliberately NOT the mathematical C(n,0) = 1: the
// group polynomial has a zero constant term), 1 when n == m, and otherwise
// n! / (m! (n-m)!) mod MOD from the factorial tables.
ll C(int n,int m){
    if(m!=0&&n!=m){
        ll acc=pre[n]*inv[m]%MOD;
        acc=acc*inv[n-m]%MOD;
        return acc%MOD;
    }
    return m==0?0:1;
}
// Bit-reversal permutation table filled by get() and read by ntt().
int rev[maxn];
// Fills rev[0 .. 2^bit) with the bit-reversal permutation on `bit` bits.
void get(int bit){
    if(bit<0)return;
    const int size=1<<bit;
    for(int i=0;i<size;i++){
        // Only shift when the low bit is set, so bit==0 never evaluates a
        // shift by -1 (undefined behaviour in the original expression).
        const int top=(i&1)?(1<<(bit-1)):0;
        rev[i]=(rev[i>>1]>>1)|top;
    }
    return ;
}
// In-place number-theoretic transform of a[0 .. n) modulo MOD.
//   a : coefficient array, overwritten with the result
//   n : transform length; the butterfly loops assume a power of two
//   f : 1 for the forward transform; -1 for the inverse transform. Any value
//       other than 1 also multiplies the result by n^{-1}.
// Outputs are canonical residues in [0, MOD).
void ntt(ll *a,int n,int f){
    if(n<=0)return;
    // floor(log2(n)) with integer arithmetic instead of a double round-trip.
    int bit=0;
    while((n>>(bit+1))>0)bit++;
    // Bring every input into [0, MOD) so the butterflies can stay canonical.
    for(int i=0;i<n;i++){
        a[i]%=MOD;
        if(a[i]<0)a[i]+=MOD;
    }
    if(n>1){
        // A length-1 transform needs no permutation; skipping it also keeps
        // get() from being called with 0 bits.
        get(bit);
        for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    }
    for(int half=1;half<n;half<<=1){
        // Primitive (2*half)-th root of unity, inverted for the inverse pass.
        ll wn=power(3,(MOD-1)/(half<<1));
        if(f==-1)wn=power(wn,MOD-2);
        for(int start=0;start<n;start+=half<<1){
            ll w=1;
            for(int k=0;k<half;k++,w=wn*w%MOD){
                const ll x=a[start+k];
                const ll y=a[start+k+half]*w%MOD;
                a[start+k]=(x+y)%MOD;
                a[start+k+half]=(x-y+MOD)%MOD;
            }
        }
    }
    if(f==1)return ;
    const ll nv=power(n,MOD-2);
    for(int i=0;i<n;i++)a[i]=a[i]*nv%MOD;
    return ;
}
// Scratch buffers for the transforms performed by mul().
ll F1[maxn],F2[maxn];
// Returns the product A*B modulo MOD with A.size()+B.size()-1 coefficients.
// An empty operand is treated as the zero polynomial. The padded transform
// length (next power of two) must not exceed maxn.
Poly mul(Poly A,Poly B){
    const int n=A.size(),m=B.size();
    if(n==0||m==0){
        // The original fed 0 or -1 to log2() here; return the zero result
        // of the length the original loop bounds imply.
        const size_t zeros=(n+m>0)?(n+m-1):0;
        return Poly(zeros,0);
    }
    const int outLen=n+m-1;
    // Two constants: multiply directly instead of running a length-1 transform.
    if(outLen==1)return Poly(1,A[0]*B[0]%MOD);
    // Smallest power of two >= outLen, computed with integers.
    int lens=1;
    while(lens<outLen)lens<<=1;
    for(int i=0;i<lens;i++)F1[i]=F2[i]=0;
    for(int i=0;i<n;i++)F1[i]=A[i];
    for(int i=0;i<m;i++)F2[i]=B[i];
    ntt(F1,lens,1);ntt(F2,lens,1);
    for(int i=0;i<lens;i++)F1[i]=F1[i]*F2[i]%MOD;
    ntt(F1,lens,-1);
    return Poly(F1,F1+outLen);
}
int n,m,k,s[maxn];
Poly a[maxn];
void init(){
n=read();m=read();k=read();pre[1]=1ll;
for(ll i=2;i<=n;i++)pre[i]=pre[i-1]*i%MOD;
inv[n]=power(pre[n],MOD-2);
for(ll i=n-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%MOD;
for(int i=1;i<=m;i++)s[i]=read();
for(int i=1;i<=m;i++){
for(int j=0;j<=s[i];j++)a[i].push_back(C(s[i],j));
}
}
// Divide and conquer: returns the product of a[l] .. a[r] (inclusive) by
// multiplying the products of the two halves. Requires l <= r.
Poly solve(int l,int r){
    if(l!=r){
        const int mid=(l+r)>>1;
        Poly left=solve(l,mid);
        Poly right=solve(mid+1,r);
        return mul(std::move(left),std::move(right));
    }
    return a[l];
}
int main(){
init();printf("%lld",(solve(1,m)[k]%MOD+MOD)%MOD);
return 0;
}
2.tokitsukaze and Another Protoss and Zerg
同上一题,对于第i轮
设\(F_i(x)=(2^{b_i}-1)x^0+a_ix+{a_i \choose 2}x^2+···+{a_i \choose a_i}x^{a_i}\)
用分治ntt即可求出
点击查看代码
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
// `ll` abbreviates `long long`; all modular arithmetic below is done in this type.
#define ll long long
using namespace std;
// Capacity of the global tables and work buffers (inv, pre, rev, F1, F2, sa, sb, a).
const int maxn=400000+10101;
// Prime modulus; ntt() uses 3 as its primitive root.
const int MOD=998244353;
// Reads one decimal integer from stdin.
// Skips every non-digit character first; if any '-' is seen while skipping,
// the result is negated. Returns 0 when end of input is reached before a digit.
inline ll read(){
    // getchar()'s result stays an int so EOF is detected (the original char
    // version looped forever at end of input) and isdigit never receives a
    // negative char.
    int ch=getchar();
    ll sign=1;
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    ll value=0;
    for(;isdigit(ch);ch=getchar())value=value*10+(ch-'0');
    return value*sign;
}
// Coefficient vector: element i holds the coefficient of x^i.
typedef vector<ll> Poly;
// Modular exponentiation: returns x^y mod MOD as a value in [0, MOD).
// x is reduced first so the squarings cannot overflow. A negative y is mapped
// to the equivalent non-negative exponent modulo MOD-1 (Fermat's little
// theorem; MOD is prime, x not a multiple of MOD); for x == 0 mod MOD that
// case returns 0.
ll power(ll x,ll y){
    x%=MOD;
    if(x<0)x+=MOD;
    if(y<0){
        // Result lies in (0, MOD-1]; the original loop never terminated here.
        y%=(MOD-1);
        y+=MOD-1;
    }
    ll ans=1;
    for(;y>0;y>>=1){
        if(y&1)ans=ans*x%MOD;
        x=x*x%MOD;
    }
    return ans;
}
// pre[i]: i! mod MOD; inv[i]: modular inverse of i!. Both are filled by init().
ll inv[maxn],pre[maxn];
// Binomial-style coefficient from the factorial tables.
// Returns 0 when m == 0 (NOT the mathematical C(n,0) = 1; init() only calls
// this with m >= 1 and supplies the constant term itself), 1 when n == m,
// and otherwise n! / (m! (n-m)!) mod MOD.
ll C(int n,int m){
    if(m!=0&&n!=m){
        ll acc=pre[n]*inv[m]%MOD;
        acc=acc*inv[n-m]%MOD;
        return acc%MOD;
    }
    return m==0?0:1;
}
// Bit-reversal permutation table filled by get() and read by ntt().
int rev[maxn];
// Fills rev[0 .. 2^bit) with the bit-reversal permutation on `bit` bits.
void get(int bit){
    if(bit<0)return;
    const int size=1<<bit;
    for(int i=0;i<size;i++){
        // Only shift when the low bit is set, so bit==0 never evaluates a
        // shift by -1 (undefined behaviour in the original expression).
        const int top=(i&1)?(1<<(bit-1)):0;
        rev[i]=(rev[i>>1]>>1)|top;
    }
    return ;
}
// In-place number-theoretic transform of a[0 .. n) modulo MOD.
//   a : coefficient array, overwritten with the result
//   n : transform length; the butterfly loops assume a power of two
//   f : 1 for the forward transform; -1 for the inverse transform. Any value
//       other than 1 also multiplies the result by n^{-1}.
// Outputs are canonical residues in [0, MOD).
void ntt(ll *a,int n,int f){
    if(n<=0)return;
    // floor(log2(n)) with integer arithmetic instead of a double round-trip.
    int bit=0;
    while((n>>(bit+1))>0)bit++;
    // Bring every input into [0, MOD) so the butterflies can stay canonical.
    for(int i=0;i<n;i++){
        a[i]%=MOD;
        if(a[i]<0)a[i]+=MOD;
    }
    if(n>1){
        // A length-1 transform needs no permutation; skipping it also keeps
        // get() from being called with 0 bits.
        get(bit);
        for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    }
    for(int half=1;half<n;half<<=1){
        // Primitive (2*half)-th root of unity, inverted for the inverse pass.
        ll wn=power(3,(MOD-1)/(half<<1));
        if(f==-1)wn=power(wn,MOD-2);
        for(int start=0;start<n;start+=half<<1){
            ll w=1;
            for(int k=0;k<half;k++,w=wn*w%MOD){
                const ll x=a[start+k];
                const ll y=a[start+k+half]*w%MOD;
                a[start+k]=(x+y)%MOD;
                a[start+k+half]=(x-y+MOD)%MOD;
            }
        }
    }
    if(f==1)return ;
    const ll nv=power(n,MOD-2);
    for(int i=0;i<n;i++)a[i]=a[i]*nv%MOD;
    return ;
}
// Scratch buffers for the transforms performed by mul().
ll F1[maxn],F2[maxn];
// Returns the product A*B modulo MOD with A.size()+B.size()-1 coefficients.
// An empty operand is treated as the zero polynomial. The padded transform
// length (next power of two) must not exceed maxn.
Poly mul(Poly A,Poly B){
    const int n=A.size(),m=B.size();
    if(n==0||m==0){
        // The original fed 0 or -1 to log2() here; return the zero result
        // of the length the original loop bounds imply.
        const size_t zeros=(n+m>0)?(n+m-1):0;
        return Poly(zeros,0);
    }
    const int outLen=n+m-1;
    // Two constants: multiply directly instead of running a length-1 transform.
    if(outLen==1)return Poly(1,A[0]*B[0]%MOD);
    // Smallest power of two >= outLen, computed with integers.
    int lens=1;
    while(lens<outLen)lens<<=1;
    for(int i=0;i<lens;i++)F1[i]=F2[i]=0;
    for(int i=0;i<n;i++)F1[i]=A[i];
    for(int i=0;i<m;i++)F2[i]=B[i];
    ntt(F1,lens,1);ntt(F2,lens,1);
    for(int i=0;i<lens;i++)F1[i]=F1[i]*F2[i]%MOD;
    ntt(F1,lens,-1);
    return Poly(F1,F1+outLen);
}
int n,sa[maxn],sb[maxn],sum;
Poly a[maxn];
void init(){
n=read();pre[1]=1ll;
for(ll i=2;i<=200000;i++)pre[i]=pre[i-1]*i%MOD;
inv[200000]=power(pre[200000],MOD-2);
for(ll i=199999;i>=1;i--)inv[i]=inv[i+1]*(i+1)%MOD;
for(int i=1;i<=n;i++)sa[i]=read(),sum+=sa[i];
for(int i=1;i<=n;i++)sb[i]=read();
for(int i=1;i<=n;i++){
a[i].push_back(power(2,sb[i])-1);
for(int j=1;j<=sa[i];j++)a[i].push_back(C(sa[i],j));
}
}
// Divide and conquer: returns the product of a[l] .. a[r] (inclusive) by
// multiplying the products of the two halves. Requires l <= r.
Poly solve(int l,int r){
    if(l!=r){
        const int mid=(l+r)>>1;
        Poly left=solve(l,mid);
        Poly right=solve(mid+1,r);
        return mul(std::move(left),std::move(right));
    }
    return a[l];
}
int main(){
init();
Poly ans=solve(1,n);
for(int i=0;i<=sum;i++)printf("%lld ",(ans[i]%MOD+MOD)%MOD);
return 0;
}
4.卷积