[CC-BLREDSET]Black and Red vertices of Tree

[CC-BLREDSET]Black and Red vertices of Tree

题目大意:

有一棵\(n(\sum n\le10^6)\)个结点的树,每个结点有一种颜色(红色、黑色、白色)。删去一个由红色点构成的连通块,使得存在一个黑点和一个白点,满足这两个点不连通。问有多少种删法。

思路:

设满足删掉这个点后,使得存在一个黑点和一个白点,满足这两个点不连通的红点为关键点。那么我们可以用两个\(\mathcal O(n)\)的树形DP求出所有的关键点。剩下的问题就变成了求有多少种全红连通块使得该连通块中至少有一个关键点,这显然又可以用一个\(\mathcal O(n)\)树形DP求出。

源代码:

#include<cstdio>
#include<cctype>
#include<vector>
inline int getint() {
	register char ch;
	while(!isdigit(ch=getchar()));
	register int x=ch^'0';
	while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0');
	return x;
}
const int N=1e5+1,mod=1e9+7;
bool mark[N];
int col[N],cnt1[N],cnt2[N],f[N][2];
std::vector<int> e[N];
inline void add_edge(const int &u,const int &v) {
	e[u].push_back(v);
	e[v].push_back(u);
}
void dfs(const int &x,const int &par) {
	cnt1[x]=cnt2[x]=0;
	if(col[x]==1) cnt1[x]=1;
	if(col[x]==2) cnt2[x]=1;
	for(unsigned i=0;i<e[x].size();i++) {
		const int &y=e[x][i];
		if(y==par) continue;
		dfs(y,x);
		cnt1[x]+=cnt1[y];
		cnt2[x]+=cnt2[y];
	}
}
void move(const int &x,const int &par) {
	bool g1=false,g2=false;
	if(x!=1) {
		g1=cnt1[par]-cnt1[x];
		g2=cnt2[par]-cnt2[x];
		cnt1[x]+=cnt1[par]-cnt1[x];
		cnt2[x]+=cnt2[par]-cnt2[x];
	}
	mark[x]=false;
	for(unsigned i=0;i<e[x].size();i++) {
		const int &y=e[x][i];
		if(y==par) continue;
		mark[x]|=cnt1[y]&&g2;
		mark[x]|=cnt2[y]&&g1;
		g1|=cnt1[y];
		g2|=cnt2[y];
		move(y,x);
	}
}
void dp(const int &x) {
	col[x]=-1;
	f[x][mark[x]]=1;
	f[x][!mark[x]]=0;
	for(unsigned i=0;i<e[x].size();i++) {
		const int &y=e[x][i];
		if(col[y]) continue;
		dp(y);
		f[x][1]=(1ll*f[x][1]*(f[y][0]+f[y][1]+1)%mod+1ll*f[x][0]*f[y][1]%mod)%mod;
		f[x][0]=1ll*f[x][0]*(f[y][0]+1)%mod;
	}
}
int main() {
	for(register int T=getint();T;T--) {
		const int n=getint();
		for(register int i=1;i<n;i++) {
			add_edge(getint(),getint());
		}
		for(register int i=1;i<=n;i++) {
			col[i]=getint();
		}
		dfs(1,0);
		move(1,0);
		for(register int i=1;i<=n;i++) {
			if(!col[i]) dp(i);
		}
		for(register int i=1;i<=n;i++) {
			e[i].clear();
		}
		int ans=0;
		for(register int i=1;i<=n;i++) {
			if(col[i]==-1) (ans+=f[i][1])%=mod;
		}
		printf("%d\n",ans);
	}
	return 0;
}
posted @ 2018-11-03 13:23  skylee03  阅读(250)  评论(0编辑  收藏  举报