#3541. 花朵(flowers)

题目描述
小 F 的生日还有一个多月,大 F 早早地准备起了礼物。

> “你想要什么礼物呀?嗯...要不要好吃的?”

> “才不要呢,我想要好看的花,永远不会凋谢的花。”

小 F 和大 F 一起生活的国家—— Fairy 国,可以抽象成一棵 $N$ 个节点的树,每个节点就是一个城市,编号为 $1\ldots N$。

大 F 要游历各个城市,为心爱的小 F 寻找好看的花。

Fairy 国的每个城市都有一座山,山上有恰好一朵永远不会凋谢的花,编号为 $i$ 的城市的花的美丽值为 $B_i$。大 F 要在 $N$ 个城市中选出恰好 $M$ 个,并摘来这 $M$ 个城市中的 $M$ 朵花送给小F。可是呢,如果树上的一条边连接的两个城市的花都被摘去,这条边就会塌陷,Fairy 国就会陷入分裂,大 F 作为一个善良的人,不希望这样的情况发生。所以,**一种摘法合法,当且仅当对于每条边,这条边相连的两个节点的花不被同时摘去**。

大 F 希望小 F 快乐,小 F 的快乐程度将是摘来的 $M$ 朵花的美丽程度的积。大 F 今天闲着没事,想要求出对于所有合法的摘法,小 F 的快乐程度之和对
$998244353$ 取模的结果。

数据范围
保证 $1 \le M \le N \le 8 \times 10^4$,$0 \le B_i < 998244353$。

题解
考虑暴力dp,设 $f_{i,0/1,j}$ 表示 $i$ 子树内取了 $j$ 个点, $i$ 取/不取的价值之和

如果是一条链的话,可以分治+ $Ntt$实现

设 $F_{l,r,0/1,0/1}(x)$ 表示 $[l,r]$ 区间, $l$ 取/不取, $r$ 取/不取,摘了 $i$ 朵的答案在 $x^i$ 的系数上

考虑一棵树,将其树链剖分,一条重链上的每个点要先将其轻儿子的信息也用分治+ $Ntt$ 的方法存到每个点上,再将这条重链用上述方法操作即可

效率不会证,貌似 $O(nlog^2n\ /\ nlog^3n)$

code

#include <bits/stdc++.h>
#define E vector<int>
#define mid ((l+r)>>1)
using namespace std;
const int N=2e5+5,P=998244353;
int n,m,w[N],sz[N],son[N],fa[N],t,tt,hd[N],V[N*2];
int nx[N*2],T[N],b[N],B,re[N*8],S[2]={3,(P+1)/3},p,G[N*8],H[N*8];
E f[2][N],g[2][N],I,M,W;struct O{E a[2][2];}tmp,rs;
int K(int x,int y){
    int z=1;
    for (;y;y>>=1,x=1ll*x*x%P)
        if (y&1) z=1ll*z*x%P;
    return z;
}
void add(int u,int v){
    nx[++tt]=hd[u];V[hd[u]=tt]=v;
}
void pre(int l){
    for (t=1,p=0;t<l;t<<=1,p++);
    for (int i=0;i<t;i++)
        re[i]=(re[i>>1]>>1)|((i&1)<<(p-1));
}
void Ntt(int *s,bool o){
    for (int i=0;i<t;i++)
        if (i<re[i]) swap(s[i],s[re[i]]);
    for (int wn,i=1;i<t;i<<=1){
        wn=K(S[o],(P-1)/(i<<1));
        for (int x,y,j=0;j<t;j+=(i<<1))
            for (int w=1,k=0;k<i;k++,w=1ll*w*wn%P)
                x=s[j+k],y=1ll*w*s[i+j+k]%P,
                s[j+k]=(x+y)%P,s[i+j+k]=(x-y+P)%P;
    }
    if (o)
        for (int i=0,v=K(t,P-2);i<t;i++)
            s[i]=1ll*v*s[i]%P;
}
E by(E a,E b){
    int la=a.size(),lb=b.size();pre(la+lb);
    for (int i=0;i<la;i++) G[i]=a[i];
    for (int i=0;i<lb;i++) H[i]=b[i];
    Ntt(G,0);Ntt(H,0);
    for (int i=0;i<t;i++) G[i]=1ll*G[i]*H[i]%P;
    Ntt(G,1);W.clear();
    for (int i=0;i<la+lb-1;i++) W.push_back(G[i]);
    for (int i=0;i<t;i++) G[i]=H[i]=0;
    return W;
}
E ad(E a,E b){
    int la=a.size(),lb=b.size(),lc=max(la,lb);
    for (int i=0;i<lc;i++)
        G[i]=((i<la?a[i]:0)+(i<lb?b[i]:0))%P;
    W.clear();
    for (int i=0;i<lc;i++) W.push_back(G[i]),G[i]=0;
    return W;
}
E div(int l,int r,int o){
    if (l==r) return f[o][b[l]];
    return by(div(l,mid,o),div(mid+1,r,o));
}
O solve(int l,int r){
    if (l==r){
        for (int i=0;i<2;i++)
            tmp.a[i][i]=g[i][T[l]],
            tmp.a[i][!i]={0};
        return tmp;
    }
    O L=solve(l,mid),R=solve(mid+1,r);
    for (int i=0;i<2;i++)
        for (int j=0;j<2;j++)
            tmp.a[i][j]=ad(by(L.a[i][0],ad(R.a[0][j],R.a[1][j])),by(L.a[i][1],R.a[0][j]));
    return tmp;
}
void work(int x){
    tt=0;int u=x;
    for (;x;x=son[x]){
        T[++tt]=x;B=0;M[1]=w[x];
        for (int i=hd[x];i;i=nx[i])
            if (V[i]!=fa[x] && V[i]!=son[x])
                b[++B]=V[i];
        g[1][x]=by(B?div(1,B,0):I,M);
        g[0][x]=B?div(1,B,1):I;
    }
    rs=solve(1,tt);
    f[0][u]=ad(rs.a[0][0],rs.a[0][1]);
    f[1][u]=ad(f[0][u],ad(rs.a[1][0],rs.a[1][1]));
}
void dfs(int u,int fr){
    sz[u]=1;fa[u]=fr;
    for (int i=hd[u];i;i=nx[i])
        if (V[i]!=fr){
            dfs(V[i],u),sz[u]+=sz[V[i]];
            if (sz[V[i]]>sz[son[u]]) son[u]=V[i];
        }
    for (int i=hd[u];i;i=nx[i])
        if (V[i]!=fr && V[i]!=son[u]) work(V[i]);
}
int main(){
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++)
        scanf("%d",&w[i]);
    for (int x,y,i=1;i<n;i++)
        scanf("%d%d",&x,&y),
        add(x,y),add(y,x);
    I.push_back(1);M.push_back(0);
    M.push_back(0);dfs(1,0);work(1);
    if (m<(int)f[1][1].size()) printf("%d\n",f[1][1][m]);
    else puts("0");return 0;
}

 

posted @ 2019-07-22 10:45  xjqxjq  阅读(442)  评论(0编辑  收藏  举报