P4721【模板】分治 FFT

瞎扯

虽然说是FFT但是还是写了一发NTT(笑)
然后忘了IDFT之后要除个n懵逼了好久
以及递归的时候忘了边界无限RE

思路

朴素算法

分治FFT
考虑到题目要求求这样的一个式子

\[F_x=\Sigma_{i=1}^{x}F_{x-i}G_{i} \]

我们可以按定义暴力,然后再松式卡常(不是)
我们可以发现它长得像一个卷积一样,但是因为后面的f值会依赖与前面的f值,所以没法一遍FFT直接求出结果,而对每个f都跑一遍FFT太慢了,我们使用分治优化这个过程就很优秀了,复杂度是\(O(n\log^2 n)\)

分治优化

我们能够想到cdq分治的思想,在统计一个区间时,确保对这个区间有影响的操作产生的贡献已经全被统计,就是先统计[l,mid]区间对[mid+1,r]区间的贡献
然后发现对于每个\(f_x\),它对后面的\(f_i\)产生的贡献是\(\Sigma_{j=l}^{mid} f_{i}g_{i-j}\)

然后分治就好

代码

#include <cstdio>
#include <algorithm>
#include <cstring>
#define int long long
using namespace std;

const int MOD = 998244353,G=3,invG=332748118;
int a[200000],b[200000],f[200000],g[200000],n;
int pow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)
            ans=(1LL*ans*a)%MOD;
        a=(1LL*a*a)%MOD;
        b>>=1;
    }
    return ans;
}
void FFT(int *a,int opt,int n){
    int lim=0;
    while((1<<lim)<n)
        lim++;
    for(int i=0;i<n;i++){
        int t=0;
        for(int j=0;j<lim;j++)
            if((i>>j)&1)
                t|=(1<<(lim-j-1));
        if(i<t)
            swap(a[i],a[t]);        
    }
    for(int i=2;i<=n;i<<=1){
        int len=i/2;
        int tmp=pow((opt)?G:invG,(MOD-1)/i);
        for(int j=0;j<n;j+=i){
            int arr=1;
            for(int k=j;k<len+j;k++){
                int t=arr*a[k+len];
                a[k+len]=((a[k]-t)%MOD+MOD)%MOD;
                a[k]=(a[k]+t)%MOD;
                arr=(arr*tmp)%MOD;         
            }
        }
    }
    if(opt==0){
        int invt=pow(n,MOD-2);
        for(int i=0;i<n;i++)
            a[i]=a[i]*invt%MOD;
    }
}
void solve(int l,int r){
    if(r-l==1)
        return;
    int t=pow(r-l,MOD-2);
    int mid=(l+r)>>1;
    solve(l,mid);
    memset(a+(r-l)/2,0,sizeof(int)*(r-l)/2);
    memcpy(a,f+l,sizeof(int)*(r-l)/2);
    memcpy(b,g,sizeof(int)*(r-l));
    FFT(a,1,r-l);
    FFT(b,1,r-l);
    for(int i=0;i<r-l;i++)
        a[i]=(a[i]*b[i])%MOD;
    FFT(a,0,r-l);
    for(int i=(r-l)/2;i<r-l;i++)
        f[l+i]=(f[l+i]+a[i])%MOD;
    solve(mid,r);
}
signed main(){
    int mid;
    scanf("%lld",&n);
    mid=n;
    for(int i=1;i<=n-1;i++)
        scanf("%lld",&g[i]);
    int t=1;
    while(t<n)
        t<<=1;
    n=t;
    f[0]=1;
    solve(0,n);
    for(int i=0;i<mid;i++)
        printf("%lld ",f[i]);
    return 0;
}
posted @ 2019-02-25 11:27  dreagonm  阅读(173)  评论(0编辑  收藏  举报