Codeforces 1118 F2. Tree Cutting (Hard Version) 优先队列+树形dp
题目要求将树分为k个部分,并且每种颜色恰好在同一个部分内,问有多少种方案。
第一步显然我们需要知道哪些点一定是要在一个部分内的,也就是说要求每一个最小的将所有颜色i的点连通的子树。
这一步我们可以将所有有颜色的点丢入优先队列,然后另深度最深的点优先出队。
如果此时这个点的颜色有不只一个点在队列中,那么我们必须要考虑将它的父亲染色,这样才能与其他的该颜色的点连通。
此时有3种情况:
1.如果它的父亲已经被染色且颜色与该点不同,那么此时显然无解;
2.如果它的父亲与它颜色相同,那么此时不做任何操作。
3.如果它的父亲无色,那么将其染色并入队。
经过这样的一番操作后我们已经将必须染色的点染色,那么现在方案数就来自与现在仍然无色的点。
第二步,方案数可以用树形dp来求得。
我们将每个点分为两种状态,记dp[now][0]为点now已经确定颜色的方案数,dp[now][1]为未确定颜色的方案数。
接下来分类讨论如何求这两个状态的dp值:
1.如果这个点原本就有颜色
那么此时显然dp[now][1]=0,dp[now][0]=所有子节点i的(dp[i][0]+dp[i][1])的乘积,因为如果子节点已经染色,那显然状态可以继承,如果未染色,那么显然此时必须被点now染色。
2.如果这个点未被染色
此时的dp[now][1]就等于情况1的dp[now][0],而dp[now][0]则要在所有子节点中选择一个子节点,令点now被这个子节点i染色,那首先前提显然是i节点已经确定颜色,所以此时枚举每个子节点,
对dp[i][0]*dp[now][1]/(dp[i][0]+dp[i][1])求和。
以下为代码:
#include<bits/stdc++.h> using namespace std; const long long mod=998244353; int i,i0,n,m,k,col[300005],dep[300005],fa[300005],cnt[300005]; vector<int>mp[300005]; void dfs(int now,int d) { dep[now]=d; for(int i:mp[now])if(!dep[i])dfs(i,d+1),fa[i]=now; return; } struct node { int x,d; bool operator<(node a)const{return d<a.d;} }; priority_queue<node>q; long long dp[300005][2]; void extgcd(long long a,long long b,long long& d,long long& x,long long& y) { if(!b){d=a;x=1;y=0;} else{extgcd(b,a%b,d,y,x);y-=x*(a/b);} } long long inv(long long a,long long n) { long long d,x,y; extgcd(a,n,d,x,y); return d==1?(x+n)%n:-1; } void dfs0(int now) { dp[now][0]=dp[now][1]=1; for(auto i:mp[now]) { if(i==fa[now])continue; dfs0(i); dp[now][1]*=(dp[i][0]+dp[i][1]); dp[now][1]%=mod; } if(col[now]) { dp[now][0]=dp[now][1]; dp[now][1]=0; } if(!col[now]) { dp[now][0]=0; for(auto i:mp[now]) { if(i==fa[now])continue; dp[now][0]+=dp[now][1]*inv(dp[i][0]+dp[i][1],mod)%mod*dp[i][0]%mod; dp[now][0]%=mod; } } return; } int main() { scanf("%d %d",&n,&k); for(i=1;i<=n;i++)scanf("%d",&col[i]),cnt[col[i]]++; for(i=1;i<n;i++) { int x,y; scanf("%d %d",&x,&y); mp[x].push_back(y); mp[y].push_back(x); } dfs(1,1); for(i=1;i<=n;i++)if(col[i])q.push({i,dep[i]}); while(!q.empty()) { node tmp=q.top(); q.pop(); if(col[fa[tmp.x]]==col[tmp.x])cnt[col[tmp.x]]--; else { if(cnt[col[tmp.x]]!=1) { if(!col[fa[tmp.x]]) { col[fa[tmp.x]]=col[tmp.x]; q.push({fa[tmp.x],dep[fa[tmp.x]]}); } else { printf("0\n"); return 0; } } } } dfs0(1); printf("%lld\n",dp[1][0]); return 0; }