两个多项式的卷积【NTT】

题意

2020计蒜之道决赛:

给出两个多项式 \(A(x)\)\(B(x)\)

\[A(x)=a_0+a_1x+a_2x^2+a_3x^3+\cdots +a_nx^n\\ B(x)=b_0+b_1x+b_2x^2+b_3x^3+\cdots +b_nx^n \]

\(C(x)\) 为上述两个多项式的卷积:

\[C(x)=A(x)B(x)=c_0+c_1x+c_2x^2\cdots+c_{2n}x^{2n} \]

现有 \(m\) 次操作,每次操作可能查询 \(\sum_{i=l}^{r}{c_i}\) ,也可能修改 \(A(x)\) 中的某个系数。

具体如下:

1 l r:代表查询 \(\sum_{i=l}^{r}{c_i}\)

2 p q:表示把 \(A(x)\)\(x^p\) 的系数增加 \(q\)

\(1\leq n \leq 5000,-10^5\leq a_i \leq 10^5,-10^5\leq b_i\leq 10^5,0\leq p\leq n,-10^5\leq q \leq 10^5\)

输出结果对 \(998244353\) 取模。

分析

对于某段区间的查询,可以转化为对两个前缀和的查询。但对于第二种操作,如果每次修改之后直接做 \(NTT\) ,复杂度为 \(O(mn\log n)\)

对于这两个操作而言,一种是“单次优,但数量太大”,一种是“单次劣,但是不限于操作次数”。我们可以减少 \(NTT\) 的次数,同时控制第一种算法记录的量不要太大。

维护一个大小为 \(S\) 的集合,每次来一个修改就把这次的修改信息记录到集合中,当集合大小增长到一定的阀值,就做一次 \(NTT\) ,同时把集合清空。而一次查询的答案,等于上次做完 \(NTT\) 的时候的答案,再加上集合中记录的修改操作对当前查询的影响,通过合理控制 \(S\) 的大小,可以做到既可以不让集合太大,又可以减小 \(NTT\) 的次数。

复杂度为:\(O(mS+\frac{m}{S}n\log n)\),通过均值不等式可知:\(S=\sqrt{n\log n}\) 时,复杂度最优秀。

代码

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int mod=998244353;
const int N=5100;
const int g=3;
ll A[N<<2],B[N<<2],C[N<<2],pre[N<<2],a[N<<2];
int rev[N<<2],S[N];
int pos[N],w[N],num;
ll power(ll x,ll y)
{
    ll res=1;
    x%=mod;
    while(y)
    {
        if(y&1)
            res=res*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return res;
}
void NTT(ll *pn,int len,int f)
{
    for(int i=0;i<len;i++)
        if(i<rev[i]) swap(pn[i],pn[rev[i]]);
    for(int i=1;i<len;i<<=1)
    {
        ll wn=power(g,1LL*(mod-1)/(2LL*i));
        if(f==-1) wn=power(wn,mod-2);
        for(int j=0,d=(i<<1);j<len;j+=d)
        {
            ll w=1;
            for(int k=0;k<i;k++)
            {
                ll u=pn[j+k],v=w*pn[j+k+i]%mod;
                pn[j+k]=(u+v)%mod,pn[j+k+i]=((u-v)%mod+mod)%mod;
                w=wn*w%mod;
            }
        }
    }
    if(f==-1)
    {
        ll inv=power(1LL*len,mod-2);
        for(int i=0;i<len;i++)
            pn[i]=pn[i]*inv%mod;
    }
}
void dontt(int len)
{
    for(int i=0;i<len;i++)
        A[i]=a[i];//因为NTT会改变原有系数的位置和值
    NTT(A,len,1);
    for(int i=0;i<len;i++)
        C[i]=A[i]*B[i]%mod;
    NTT(C,len,-1);
    for(int i=1;i<len;i++)
        C[i]=(C[i-1]+C[i])%mod;
}
ll cal(int x)
{
    if(x<0) return 0;
    ll res=C[x];
    for(int i=1;i<=num;i++)
    {
        int t=x-pos[i];
        if(t>=0)//确定B需要的前缀和最大下标
            res=(res+(pre[t]*w[i]%mod)+mod)%mod;
    }
    return res;
}
int main()
{
    int n,m,op,x,y;
    scanf("%d",&n);
    for(int i=0;i<=n;i++)
    {
        scanf("%lld",&a[i]);
        a[i]=(a[i]+mod)%mod;
    }
    for(int i=0;i<=n;i++)
    {
        scanf("%lld",&B[i]);
        B[i]=(B[i]+mod)%mod;
        if(i==0)
            pre[i]=B[i];
        else
            pre[i]=(pre[i-1]+B[i])%mod;
    }
    int cnt=0,len=1;
    while(len<=2*n)//注意len不要开小了
    {
        len<<=1;
        cnt++;
    }
    for(int i=n+1;i<len;i++)
        pre[i]=(pre[i-1]+B[i])%mod;
    for(int i=0;i<len;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(cnt-1));
    NTT(B,len,1);
    dontt(len);
    int mx=(int)sqrt(1.0*n*log2(1.0*n));
    num=0;
    scanf("%d",&m);
    while(m--)
    {
        scanf("%d%d%d",&op,&x,&y);
        if(op==1)
            printf("%lld\n",(cal(y)-cal(x-1)+mod)%mod);
        else
        {
            pos[++num]=x;//记录修改的位置
            a[x]=(a[x]+y+mod)%mod;
            w[num]=y;//
            if(num>=mx)
            {
                dontt(len);
                num=0;
            }
        }
    }
    return 0;
}
posted @ 2020-11-16 11:05  xzx9  阅读(388)  评论(0编辑  收藏  举报