LOJ#6289. 花朵 树链剖分+分治NTT

本来以为这道题会非常难调,但是没想到调了不到 5 分钟就 A 了.  

由于基于多项式的运算都可以方便地进行封装,所以细节就不是很多(或者说几乎没有细节)   

题意:给定一棵树,每个点有点权,求对于所有大小为 $m$ 的独立集的点权之积的和.     

数据范围:$n,m \leqslant 8 \times 10^4$.  

先考虑一个十分显然的 $O(n^2)$ 暴力:

令 $f[x][i],g[x][i]$ 分别表示点 $x$ 选/不选的情况下独立集大小为 $i$ 的点积 之和.  

考虑将 $x$ 与 $x$ 的一个儿子 $y$ 合并:$f[x][i+j]=f[x][i] \times f[y][j]$,$g$ 同理.  

然后 $x$ 的初始值是:$f[x][1]=w[x],g[x][0]=1$.    

树形DP 卡一下上界复杂度是 $O(n^2)$ 的.  

不难发现,上述 $f[x][i+j] = f[x][i] \times f[y][j]$ 是一个卷积的形式.  

如果是菊花图或者链的话可以直接用 NTT/分治NTT 来做.   

正解的话考虑进行轻重路径剖分:   

对于一条重链来说,先求出该重链中每个点轻儿子为根的多项式 $f,g$,然后对于重链中每个点都将其轻儿子与该点合并.   

最后对于一条重链进行分治,求出该重链链顶为根的多项式.   

分析一下时间复杂度: 

考虑一条重链链顶为根的子树会被卷多少次:其祖先中每一条重链都会将其贡献一次.  

那么树链剖分中一个点有 $O(\log n)$ 个祖先,而每次卷积的时候对链分治的复杂度是 $O(n \log^2 n)$.  

总复杂度就是 $O(n \log^3 n)$,但是由于树链剖分的常数比较小,跑的并不慢.   

code:  

#include <queue>
#include <cstdio>   
#include <vector>
#include <cstring> 
#include <algorithm>  
#define N 1000009 
#define ll long long 
#define mod 998244353 
#define pb push_back
#define setIO(s) freopen(s".in","r",stdin)  
using namespace std;  
int m; 
int A[N<<2],B[N<<2];      
int tim,edges,n; 
int size[N],son[N],top[N],hd[N],to[N<<1],nex[N<<1],fa[N],dep[N]; 
int dfn[N],bu[N],si[N],val[N];   
void add(int u,int v) { 
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;  
}
int ADD(int x,int y) { 
    return (ll)(x+y)%mod; 
}  
int DEC(int x,int y) { 
    return (ll)(x-y+mod)%mod; 
}  
int MUL(int x,int y) { 
    return (ll)x*y%mod; 
}
int qpow(int x,int y) { 
    int tmp=1; 
    for(;y;y>>=1,x=(ll)x*x%mod) {   
        if(y&1) tmp=(ll)tmp*x%mod; 
    }  
    return tmp; 
}
int get_inv(int x) { 
    return qpow(x,mod-2); 
}
void NTT(int *a,int len,int op) { 
    for(int i=0,k=0;i<len;++i) { 
        if(i>k) { 
            swap(a[i],a[k]); 
        }  
        for(int j=len>>1;(k^=j)<j;j>>=1); 
    }  
    for(int l=1;l<len;l<<=1) { 
        int wn=qpow(3,(mod-1)/(l<<1));  
        if(op==-1) wn=get_inv(wn);  
        for(int i=0;i<len;i+=l<<1) { 
            int w=1;  
            for(int j=0;j<l;++j) { 
                int x=a[i+j],y=(ll)w*a[i+j+l]%mod;  
                a[i+j]=(ll)(x+y)%mod;  
                a[i+j+l]=(ll)(x-y+mod)%mod;  
                w=(ll)w*wn%mod; 
            }
        }
    }
    if(op==-1) { 
        int iv=get_inv(len); 
        for(int i=0;i<len;++i) { 
            a[i]=(ll)a[i]*iv%mod;   
        }
    }
}
struct poly { 
    int len;
    vector<int>a;  
    poly() { len=0,a.clear(); } 
    void push(int x) { 
        a.pb(x),++len;
    }
    void resize(int x) {
        a.resize(x),len=x;    
    }                       
    poly operator*(const poly &b) const { 
        int lim;
        for(lim=1;lim<len+b.len-1;lim<<=1); 
        for(int i=0;i<lim;++i) A[i]=B[i]=0;
        for(int i=0;i<len;++i) A[i]=a[i];
        for(int i=0;i<b.len;++i) B[i]=b.a[i];
        NTT(A,lim,1),NTT(B,lim,1);
        for(int i=0;i<lim;++i) {    
            A[i]=(ll)A[i]*B[i]%mod;
        }
        NTT(A,lim,-1);
        poly c;
        for(int i=0;i<len+b.len-1;++i) { 
            c.push(A[i]); 
        }
        if(c.len>m+1) c.resize(m+1);
        return c;   
    }
    poly operator+(const poly &b) const {
        poly c; 
        c.resize(max(len,b.len));  
        for(int i=0;i<c.len;++i) c.a[i]=0; 
        for(int i=0;i<c.len;++i) {    
            if(i<len) c.a[i]=ADD(c.a[i],a[i]); 
            if(i<b.len) c.a[i]=ADD(c.a[i],b.a[i]);  
        }
        return c;   
    }
    poly operator-(const poly &b) const {    
        poly c;  
        c.resize(max(len,b.len));    
        for(int i=0;i<c.len;++i) c.a[i]=0;
        for(int i=0;i<c.len;++i) { 
            if(i<len) c.a[i]=ADD(c.a[i],a[i]); 
            if(i<b.len) c.a[i]=DEC(c.a[i],b.a[i]);  
        }  
        return c;  
    }
}f0[N],f1[N],g[2][N];      
struct data {
    poly f00,f01,f10,f11;           
    data operator+(const data &b) const { 
        data c;   
        c.f00=(f01*b.f00)+(f00*(b.f00+b.f10));   
        c.f11=(f11*b.f01)+(f10*(b.f11+b.f01));    
        c.f01=(f01*b.f01)+(f00*(b.f01+b.f11));      
        c.f10=(f11*b.f00)+(f10*(b.f10+b.f00));    
        return c;  
    }
}tmp;  
void dfs1(int x,int ff) {  
    fa[x]=ff,dep[x]=dep[ff]+1,size[x]=1;  
    for(int i=hd[x];i;i=nex[i]) { 
        int y=to[i];  
        if(y==ff) continue;  
        dfs1(y,x);
        size[x]+=size[y];
        if(size[y]>size[son[x]]) son[x]=y;
    }
}
void dfs2(int x,int tp) { 
    top[x]=tp;  
    dfn[x]=++tim;
    bu[tim]=x;
    ++si[tp];  
    if(son[x]) {  
        dfs2(son[x],tp); 
    }
    for(int i=hd[x];i;i=nex[i]) {    
        if(to[i]!=fa[x]&&to[i]!=son[x]) { 
            dfs2(to[i],to[i]);  
        }
    }
}
poly calc(int l,int r,int d) {     
    if(l==r) {   
        return g[d][l];  
    }
    int mid=(l+r)>>1;  
    return calc(l,mid,d)*calc(mid+1,r,d);  
}
data solve(int l,int r) {   
    if(l==r) {      
        int u=bu[l];   
        data e;   
        e.f00=f0[u];  
        e.f11=f1[u];  
        return e;  
    }
    int mid=(l+r)>>1;       
    return solve(l,mid)+solve(mid+1,r);  
}
int main() { 
    // setIO("input");  
    int x,y,z; 
    scanf("%d%d",&n,&m);    
    for(int i=1;i<=n;++i) scanf("%d",&val[i]);
    for(int i=1;i<n;++i) {
        scanf("%d%d",&x,&y); 
        add(x,y),add(y,x); 
    }
    dfs1(1,0),dfs2(1,1);       
    for(int i=1;i<=n;++i) {
        f0[i].push(1);  
        f1[i].push(0);  
        f1[i].push(val[i]);    
    }        
    for(int i=n;i>=1;--i) {
        int p=bu[i]; 
        if(top[p]==p) {
            for(int j=dfn[p];j<=dfn[p]+si[p]-1;++j) { 
                x=bu[j];         
                int p0=0,p1=0;      
                for(int e=hd[x];e;e=nex[e]) {
                    y=to[e];  
                    if(y==son[x]||y==fa[x]) continue;            
                    g[0][++p0]=f0[y]+f1[y];   
                    g[1][++p1]=f0[y];  
                }     
                if(p0) f0[x]=calc(1,p0,0);  
                if(p1) f1[x]=f1[x]*calc(1,p1,1); 
            } 
            tmp=solve(dfn[p],dfn[p]+si[p]-1);       
            f0[p]=tmp.f01+tmp.f00;  
            f1[p]=tmp.f10+tmp.f11;             
        }
    }   
    f0[1].resize(m+1); 
    f1[1].resize(m+1);  
    printf("%d\n",(ll)(f0[1].a[m]+f1[1].a[m])%mod);  
    return 0; 
}

  

posted @ 2020-07-25 08:30  EM-LGH  阅读(342)  评论(0编辑  收藏  举报