Codeforces 1118 F2. Tree Cutting (Hard Version) 优先队列+树形dp

题目要求将树分为k个部分,并且每种颜色恰好在同一个部分内,问有多少种方案。

 

第一步显然我们需要知道哪些点一定是要在一个部分内的,也就是说要求每一个最小的将所有颜色i的点连通的子树。

这一步我们可以将所有有颜色的点丢入优先队列,然后另深度最深的点优先出队。

如果此时这个点的颜色有不只一个点在队列中,那么我们必须要考虑将它的父亲染色,这样才能与其他的该颜色的点连通。

此时有3种情况:

1.如果它的父亲已经被染色且颜色与该点不同,那么此时显然无解;

2.如果它的父亲与它颜色相同,那么此时不做任何操作。

3.如果它的父亲无色,那么将其染色并入队。

经过这样的一番操作后我们已经将必须染色的点染色,那么现在方案数就来自与现在仍然无色的点。

 

第二步,方案数可以用树形dp来求得。

我们将每个点分为两种状态,记dp[now][0]为点now已经确定颜色的方案数,dp[now][1]为未确定颜色的方案数。

接下来分类讨论如何求这两个状态的dp值:

1.如果这个点原本就有颜色

  那么此时显然dp[now][1]=0,dp[now][0]=所有子节点i的(dp[i][0]+dp[i][1])的乘积,因为如果子节点已经染色,那显然状态可以继承,如果未染色,那么显然此时必须被点now染色。

2.如果这个点未被染色

  此时的dp[now][1]就等于情况1的dp[now][0],而dp[now][0]则要在所有子节点中选择一个子节点,令点now被这个子节点i染色,那首先前提显然是i节点已经确定颜色,所以此时枚举每个子节点,

对dp[i][0]*dp[now][1]/(dp[i][0]+dp[i][1])求和。

以下为代码:

#include<bits/stdc++.h>
using namespace std;
const long long mod=998244353;
int i,i0,n,m,k,col[300005],dep[300005],fa[300005],cnt[300005];
vector<int>mp[300005];
void dfs(int now,int d)
{
    dep[now]=d;
    for(int i:mp[now])if(!dep[i])dfs(i,d+1),fa[i]=now;
    return;
}
struct node
{
    int x,d;
    bool operator<(node a)const{return d<a.d;}
};
priority_queue<node>q;
long long dp[300005][2];
void extgcd(long long a,long long b,long long& d,long long& x,long long& y)
{
    if(!b){d=a;x=1;y=0;}
    else{extgcd(b,a%b,d,y,x);y-=x*(a/b);}
}
long long inv(long long a,long long n)
{
    long long d,x,y;
    extgcd(a,n,d,x,y);
    return d==1?(x+n)%n:-1;
}
void dfs0(int now)
{
    dp[now][0]=dp[now][1]=1;
    for(auto i:mp[now])
    {
        if(i==fa[now])continue;
        dfs0(i);
        dp[now][1]*=(dp[i][0]+dp[i][1]);
        dp[now][1]%=mod;
    }
    if(col[now])
    {
        dp[now][0]=dp[now][1];
        dp[now][1]=0;
    }
    if(!col[now])
    {
        dp[now][0]=0;
        for(auto i:mp[now])
        {
            if(i==fa[now])continue;
            dp[now][0]+=dp[now][1]*inv(dp[i][0]+dp[i][1],mod)%mod*dp[i][0]%mod;
            dp[now][0]%=mod;
        }
    }
    return;
}
int main()
{
    scanf("%d %d",&n,&k);
    for(i=1;i<=n;i++)scanf("%d",&col[i]),cnt[col[i]]++;
    for(i=1;i<n;i++)
    {
        int x,y;
        scanf("%d %d",&x,&y);
        mp[x].push_back(y);
        mp[y].push_back(x);
    }
    dfs(1,1);
    for(i=1;i<=n;i++)if(col[i])q.push({i,dep[i]});
    while(!q.empty())
    {
        node tmp=q.top();
        q.pop();
        if(col[fa[tmp.x]]==col[tmp.x])cnt[col[tmp.x]]--;
        else
        {
            if(cnt[col[tmp.x]]!=1)
            {
                if(!col[fa[tmp.x]])
                {
                    col[fa[tmp.x]]=col[tmp.x];
                    q.push({fa[tmp.x],dep[fa[tmp.x]]});
                }
                else
                {
                    printf("0\n");
                    return 0;
                }
            }
        }
    }
    dfs0(1);
    printf("%lld\n",dp[1][0]);
    return 0;
}

 

posted @ 2019-02-27 19:34  BiteTheDDDDt  阅读(412)  评论(0编辑  收藏  举报