51nod1514 美妙的序列 分治NTT

显然,不合法的情况要存在序列被分成值域为 $[1,i]$ 与 $[i+1,r]$ 两部分.  

不妨采用容斥的方法来减去所有不合法的情况.    

令 $f[i]$ 表示 $1$ ~ $i$ 构成的合法序列数目.  

那么不合法的情况一定可以表示为 $f[j] \times (i-j)!$ 即前 $j$ 个数组成的连通块合法,然后第一个不合法位点为 $(j,j+1)$  

由于每一次第一个不合法位点不同,所以不会减多.   

$f[n]=n!-\sum_{j=1}^{i-1} f[j] \times (n-j)!$ 这个式子用分治 NTT 加速就好了.    

code:   

#include <cstdio>  
#include <cstring>
#include <algorithm>   
#define N 100007
#define ll long long 
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std; 
int n;  
int fac[N],f[N],g[N],A[N<<2],B[N<<2];   
void init() {  
    fac[0]=1;  
    for(int i=1;i<N;++i) { 
        fac[i]=(ll)fac[i-1]*i%mod;  
    }
}
int qpow(int x,int y) {
    int tmp=1;
    for(;y;y>>=1,x=(ll)x*x%mod) {
        if(y&1) tmp=(ll)tmp*x%mod;
    }
    return tmp;
} 
int get_inv(int x) {
    return qpow(x,mod-2);
} 
void NTT(int *a,int len,int op) {
    for(int i=0,k=0;i<len;++i) {
        if(i>k) swap(a[i],a[k]);
        for(int j=len>>1;(k^=j)<j;j>>=1); 
    }     
    for(int l=1;l<len;l<<=1) {
        int wn=qpow(3,(mod-1)/(l<<1)); 
        if(op==-1) {
            wn=get_inv(wn);
        } 
        for(int i=0;i<len;i+=l<<1) {
            int w=1,x,y; 
            for(int j=0;j<l;++j) {
                x=a[i+j],y=(ll)a[i+j+l]*w%mod; 
                a[i+j]=(ll)(x+y)%mod; 
                a[i+j+l]=(ll)(x-y+mod)%mod; 
                w=(ll)w*wn%mod; 
            }
        }
    }   
    if(op==-1) {  
        int in=get_inv(len);
        for(int i=0;i<len;++i) {
            a[i]=(ll)a[i]*in%mod; 
        }
    }
} 
void solve(int l,int r) {
    if(l==r) {
        return;
    }  
    int mid=(l+r)>>1,lim,s1=0,s2=0;
    solve(l,mid); 
    for(int i=l;i<=mid;++i) A[s1++]=f[i];  
    for(int i=0;i<=r-l;++i) B[s2++]=g[i]; 
    for(lim=1;lim<(s1+s1);lim<<=1); 
    for(int i=s1;i<lim;++i) A[i]=0;
    for(int i=s2;i<lim;++i) B[i]=0;  
    NTT(A,lim,1),NTT(B,lim,1);
    for(int i=0;i<lim;++i) A[i]=(ll)A[i]*B[i]%mod; 
    NTT(A,lim,-1);
    for(int i=mid+1;i<=r;++i) {
        f[i]=(ll)(f[i]-A[i-l]+mod)%mod; 
    }  
    for(int i=0;i<lim;++i) A[i]=B[i]=0; 
    solve(mid+1,r);
}
char *p1,*p2,buf[100000];  
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)    
int rd() { 
    int x=0;char c;   
    do { c=nc();}while(c<48);  
    while(c>47) { 
        x=(((x<<2)+x)<<1)+(c^48); 
        c=nc();  
    }   
    return x;  
}
int main() { 
    // setIO("input");     
    init(),n=100000;  
    for(int i=1;i<=n;++i) {
        f[i]=g[i]=fac[i];   
    }           
    solve(1,n);        
    int T=rd(),x,y;
    while(T--) {    
        x=rd();
        printf("%d\n",f[x]);  
    }
    return 0;  
}

  

posted @ 2020-07-21 16:19  EM-LGH  阅读(149)  评论(0编辑  收藏  举报