一遍过,纪念一下
题目描述
给定一颗N个节点组成的树,3种颜色,其中K个节点已染色,要求任意两相邻节点颜色不同,求合法染色方案数。\(1\)\(\le\)\(n\)\(\le\)\(100000\) \(0\)\(\le\)\(k\)\(\le\)\(n\)
思路
dp[i][j]表示i这棵子树,染j这个颜色
转移注意相邻点颜色不同
code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100100,mod=1e9+7;
int n,k,fir[N],cnt,col[N],flag;
struct node{
int nxt,to;
}e[N<<1];
void add(int u,int v){
e[++cnt].nxt=fir[u];fir[u]=cnt;e[cnt].to=v;
}
ll dp[N][4];
void dfs(int u,int fa){
if(col[u]){
for(int i=1;i<=3;++i)dp[u][i]=0;
dp[u][col[u]]=1;
}
else for(int i=1;i<=3;++i)dp[u][i]=1;
for(int v,i=fir[u];i;i=e[i].nxt){
v=e[i].to;
if(v==fa)continue;
dfs(v,u);
ll sum=0;
if(!col[u]){
if(col[v]){
for(int j=1;j<=3;++j)
if(j!=col[v])dp[u][j]=dp[u][j]*dp[v][col[v]]%mod;
dp[u][col[v]]=0;
}
else{
for(int j=1;j<=3;++j){
sum=0;
for(int k=1;k<=3;++k)
if(j!=k)sum=(sum+dp[v][k])%mod;
dp[u][j]=dp[u][j]*sum%mod;
}
}
}
else{
if(col[v]){
if(col[u]==col[v]){
flag=1;
}
dp[u][col[u]]=dp[u][col[u]]*dp[v][col[v]]%mod;
}
else{
for(int j=1;j<=3;++j)
if(j!=col[u])sum=(sum+dp[v][j])%mod;
dp[u][col[u]]=dp[u][col[u]]*sum%mod;
}
}
}
}
int main(){
// freopen("t.in","r",stdin);
scanf("%d%d",&n,&k);
int u,v;
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
for(int i=1;i<=k;++i){
scanf("%d%d",&u,&v);
col[u]=v;
}
dfs(1,0);
if(flag){
puts("0");
}
else{
ll ans=0;
for(int i=1;i<=3;++i)ans=(ans+dp[1][i])%mod;
printf("%lld\n",ans);
}
return 0;
}