Luogu4721 【模板】分治 FFT

https://www.luogu.com.cn/problem/P4721

分治\(FFT\)

\[f_{i}=\sum_{j=1}^{i-1} f_{i-j} g_j \]

等式左右两边均存在函数\(f\),无法直接\(NTT\)

考虑到对于每个\(f_i\),都是有满足\(i<j\)\(j\)转移而来,因此考虑分治

对于一段区间\([l,r]\)假设\([l,mid]\)已经被处理完,我们将\([l,mid]\)\([mid+1,r]\)产生的贡献直接增加到\([mid+1,r]\)

然后我们发现\(mid+1\)上的值也被处理完了,然后\([mid+1,mid+1]\rightarrow [mid+1,mid+2] \rightarrow [mid+1,mid+4]\cdots\),最终成功覆盖了\([l,r]\)

注意不要忘记清空进行\(NTT\)\(a,b\)数组的多余部分,否则必然凉凉

时间复杂度:\(O(n\log^2n)\)

\(C++ Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#define N 200005
#define ll long long
#define p 998244353
using namespace std;
int r[N],lg2[N];
ll inv3,g[N],f[N],a[N],b[N];
int n,s,l;
ll ksm(ll x,ll y)
{
    ll ans=1;
    while (y)
    {
        if (y & 1)
            ans=ans*x%p;
        x=x*x%p;
        y >>=1;
    }
    return ans;
}
#define inv(x) (ksm(x,p-2))
void NTT(ll *a,int t)
{
    for (int i=0;i<s;i++)
        if (i<r[i])
            swap(a[i],a[r[i]]);
    for (int mid=1,w=2;mid<s;mid <<=1,w++)
    {
        ll gn=ksm((t==0)?3:inv3,(p-1)/(mid << 1));
        for (int j=0;j<s;j+=(mid << 1))
        {
            ll g=1;
            for (int k=0;k<mid;k++,g=g*gn%p)
            {
                ll x=a[j+k],y=g*a[j+k+mid]%p;
                a[j+k]=(x+y)%p;
                a[j+k+mid]=(x-y)%p;
            }
        }
    }
}
void FZ_NTT(int L,int R)
{
    if (L==R)
        return;
    int mid=(L+R) >> 1;
    ll t=inv(R-L+1);
    FZ_NTT(L,mid);
    s=1 << lg2[R-L+1];
    for (int i=L;i<=mid;i++)
        a[i-L]=f[i];
    for (int i=mid-L+1;i<s;i++)
        a[i]=0;
    for (int i=L;i<=R;i++)
        b[i-L]=g[i-L];
    for (int i=R-L+1;i<s;i++)
        b[i]=0;
    for (int i=0;i<s;i++)
        r[i]=(r[i >> 1] >> 1) | ((i & 1) << (lg2[R-L+1]-1));
    NTT(a,0);
    NTT(b,0);
    for (int i=0;i<s;i++)
        a[i]=a[i]*b[i]%p;
    NTT(a,1);
    for (int i=0;i<s;i++)
        a[i]=a[i]*t%p;
    for (int i=mid+1;i<=R;i++)
        f[i]=(f[i]+a[i-L])%p;
    FZ_NTT(mid+1,R);
}
int main()
{
    inv3=ksm(3,p-2);
    scanf("%d",&n);
    n--;
    for (int i=1;i<=n;i++)
        scanf("%lld",&g[i]);
    f[0]=1;
    lg2[0]=0;
    l=0;
    for (int i=1;i<=n;i++)
    {
        if ((1 << l)<i)
            l++;
        lg2[i]=l;
    }
    for (int i=n+1;i<=(1 << lg2[n]);i++)
        lg2[i]=lg2[n];
    FZ_NTT(0,(1 << lg2[n])-1);
    for (int i=0;i<=n;i++)
    {
        f[i]=(f[i]%p+p)%p;
        printf("%lld ",f[i]);
    }
    putchar('\n');
    return 0;
}
posted @ 2020-07-30 20:17  GK0328  阅读(74)  评论(0编辑  收藏  举报