多叉分治半在线卷积

唔,把以前一直口胡着的多叉分治半在线卷积实现了一下啊!

这个算法的思想是,我们分治计算贡献时,设目前区间长度为 \(n\),分成 \(B\) 个部分,并且用 cdq 的方式计算贡献,其中 \(B\) 是关于 \(n\) 的一个函数。

考虑平常 \(B=2\) 时最优,是因为计算两两支之间的贡献均为 \(O(n/B\log (n/B))\) 的,也即要进行 \(B^2\) 轮长度为 \(n/B\) 的 FFT,多叉没有优化效果。

此时,我们注意到转移时的乘法中本质不同的 poly 只有 \(O(B)\) 个,我们可以提前 DFT 算出各个多项式的点值,每次把点值对应做贡献,最后分别一轮 IDFT 插值计算即可。这样只用做大约 \(3B-3\) 轮长度为 \(2n/B\) 的 FFT(使用循环卷积!),以及计算点值间乘法贡献的 \(\Theta(B^2)\) 对。

从而单层的复杂度为 \(\Theta(nB+n\log(n/B))\),取 \(B=\log n\),即得单层复杂度为 \(\Theta(n\log n)\) 的。

可以证明递归层数是 \(\Theta(\log n/\log\log n)\) 级别的,总复杂度即为 \(\Theta(n\log^2n/\log\log n)\)

按照 EI 论文,目前理论界存在一个神秘的 \(O(n\log^{1+\epsilon}n)\) 的半在线卷积算法(\(O(n\log ne^{2\sqrt{\log2\log\log n}})\)),感觉估计在 OI 圈内没啥用,除了指出在线算法往往和迭代法有相近的复杂度以及更小的常数外。

总不会这个东西是对那个 \(O(B^2)\) 对点值转移的过程继续调用当前算法吧?反正我是不会的。。。

2023.4.13 upd:这个东西的复杂度是 \(O(n\log^{\log_23}n)\) 的。另一个更优的做法可以参见 EI 博客。(似乎就是把平衡的部分又细化了一下?)

放一下不等关系的代码。

封装版:

const ullt Mod=998244353,g=3;
typedef ConstMod::mod_ullt<Mod>modint;
typedef std::vector<modint>modvec;
typedef NTT_POLY::poly_NTT<Mod,g>poly;
typedef NTT_POLY::poly_eval<Mod,g>eval;
typedef NTT_POLY::poly_inter<Mod,g>inter;
typedef NTT_POLY::poly_cpd<Mod,g>cpd;
typedef NTT_POLY::poly_nums<Mod,g>nums;
typedef FWT_MODINT::FWT_Mod<Mod>FWT;
modint P[200005],Q[200005],Dp[100005];
bol A[100005];
voi dfs(uint l,uint r){
    if(r-l<=128){
        for(uint i=l;i<r;i++){
            Dp[i]+=Q[i+1];if(A[i])for(uint j=1;i+j<r;j++)Dp[i+j]-=Dp[i]*Q[j];
        }
        return;
    }
    uint B=1;
    if(r-l<=256)B=32;
    else if(r-l<=512)B=64;
    else if(r-l<=1024)B=128;
    else while(B*8<r-l)B<<=1;
    uint Cnt=(r-l+B-1)/B;
    poly::DIFDIT s(B<<1);
    modvec A[15],C[15];
    for(uint i=0;i<Cnt;i++){
        uint L=l+i*B,R=std::min(r,l+(i+1)*B);
        modvec User(B<<1);
        for(uint j=0;j<i;j++)for(uint k=0;k<(B<<1);k++)User[k]+=A[j][k]*C[i-j-1][k];
        s.dit(User);
        for(uint j=0;j<R-L;j++)Dp[L+j]-=User[j+B];
        dfs(L,R);
        if(i!=Cnt-1){
            A[i].resize(B<<1),C[i].resize(B<<1);
            for(uint j=0;j<(B<<1);j++)A[i][j]=Q[i*B+j];
            for(uint j=L;j<R;j++)if(::A[j])C[i][j-L]=Dp[j];
            s.dif(A[i]),s.dif(C[i]);
        }
    }
}
chr C[100005];
int main()
{
#ifdef MYEE
    freopen("QAQ.in","r",stdin);
    // freopen("QAQ.out","w",stdout);
#endif
    P[0]=1;for(uint i=1;i<=100001;i++)P[i]=P[i-1]*i;
    Q[100001]=P[100001].inv();for(uint i=100001;i;i--)Q[i-1]=Q[i]*i;
    uint n=0;scanf("%s",C);while(C[n])n++;
    uint cnt=0;
    for(uint i=0;i<n;i++)Dp[i]=0,A[i]=C[i]=='>',cnt+=A[i];
    dfs(0,n+1);
    (cnt&1?-Dp[n]*P[n+1]:Dp[n]*P[n+1]).println();
    return 0;
}

散装版:

const uint Mod=998244353,g=3;
inline uint chg(uint v){return v<Mod?v:v-Mod;}
class DIFDIT
{
    private:
        uint n;uint*G;
    public:
        DIFDIT():n(0),G(NULL){}
        DIFDIT(uint len)
        {
            n=1;while(n<len)n<<=1;
            G=new uint[n],G[0]=1;
            uint w=power<ullt>(g,(Mod-1)/n,Mod),*End=G+n;
            for(uint*_G=G+1;_G<End;_G++)*_G=(ullt)_G[-1]*w%Mod;
        }
        ~DIFDIT(){if(G!=NULL)delete[]G,G=NULL;}
        inline uint size(){return n;}
        voi dif(std::vector<uint>&x)
        {
            if(x.size()<n)x.resize(n);
            for(uint i=n>>1;i;i>>=1)for(uint j=0;j<n;j+=i<<1)
            {
                uint*w=G;
                for(uint k=0;k<i;k++,w+=n/(2*i))
                {
                    uint u=x[j|k],t=x[i|j|k];
                    x[j|k]=chg(u+t),x[i|j|k]=(ullt)(u+Mod-t)*(*w)%Mod;
                }
            }
        }
        voi dit(std::vector<uint>&x)
        {
            if(x.size()<n)x.resize(n);
            for(uint i=1;i<n;i<<=1)for(uint j=0;j<n;j+=i<<1)
            {
                uint*w=G;
                for(uint k=0;k<i;k++,w+=n/(2*i))
                {
                    uint t=(ullt)x[i|j|k]*(*w)%Mod;
                    x[i|j|k]=chg(x[j|k]+Mod-t),x[j|k]=chg(x[j|k]+t);
                }
            }
            for(uint i=1;i*2<n;i++)std::swap(x[i],x[n-i]);
            uint t=power<ullt>(n,Mod-2,Mod);for(uint i=0;i<n;i++)x[i]=(ullt)x[i]*t%Mod;
        }
};
uint Q[200005],Dp[100005];
bol A[100005];
voi dfs(uint l,uint r){
    if(r-l<=64){
        for(uint i=l;i<r;i++){
            Dp[i]=chg(Dp[i]+Q[i+1]);
            if(A[i])for(uint j=1;i+j<r;j++)
                Dp[i+j]=chg(Dp[i+j]+Mod-(ullt)Dp[i]*Q[j]%Mod);
        }
        return;
    }
    uint B=1;
    if(r-l<=128)B=16;
    else if(r-l<=256)B=32;
    else if(r-l<=512)B=64;
    else if(r-l<=1024)B=128;
    else while(B*8<r-l)B<<=1;
    uint Cnt=(r-l+B-1)/B;
    DIFDIT s(B<<1);
    std::vector<uint>A[7],C[7];
    for(uint i=0;i<Cnt;i++){
        uint L=l+i*B,R=std::min(r,l+(i+1)*B);
        std::vector<uint>User(B<<1);
        for(uint j=0;j<i;j++)for(uint k=0;k<(B<<1);k++)
            User[k]=(User[k]+(ullt)A[j][k]*C[i-j-1][k])%Mod;
        s.dit(User);
        for(uint j=0;j<R-L;j++)Dp[L+j]=chg(Dp[L+j]+Mod-User[j+B]);
        dfs(L,R);
        if(i!=Cnt-1){
            A[i].resize(B<<1),C[i].resize(B<<1);
            for(uint j=0;j<(B<<1);j++)A[i][j]=Q[i*B+j];
            for(uint j=0;j<B;j++)if(::A[j+L])C[i][j]=Dp[j+L];
            s.dif(A[i]),s.dif(C[i]);
        }
    }
}
chr C[100005];
int main()
{
#ifdef MYEE
    freopen("QAQ.in","r",stdin);
    // freopen("QAQ.out","w",stdout);
#endif
    uint n=0,v=1,cnt=0;scanf("%s",C);while(C[n])n++;
    for(uint i=2;i<=n+1;i++)v=(ullt)v*i%Mod;
    Q[n+1]=power<ullt>(v,Mod-2,Mod);
    for(uint i=n+1;i;i--)Q[i-1]=(ullt)Q[i]*i%Mod;
    for(uint i=0;i<n;i++)cnt+=A[i]=C[i]=='>';
    dfs(0,n+1);
    printf("%llu\n",cnt&1?chg(Mod-(ullt)Dp[n]*v%Mod):(ullt)Dp[n]*v%Mod);
    return 0;
}

目前 loj 最优解(301ms):链接

Update:被 Alpha 重测无了。/kel

posted @ 2023-02-24 22:32  myee  阅读(160)  评论(0编辑  收藏  举报