【bzoj2314】士兵的放置 树形dp

题目描述

八中有N个房间和N-1双向通道,任意两个房间均可到达.现在出了一件极BT的事,就是八中开始闹鬼了。老大决定加强安保,现在如果在某个房间中放一个士兵,则这个房间以及所有与这个房间相连的房间都会被控制.现在

老大想知道至少要多少士兵可以控制所有房间.以及有多少种不同的方案数. 

 

 

输入

 

第一行一个数字N,代表有N个房间,房间编号从1开始到N.N<=500000,下面将有N-1行,每行两个数,代表这两个房间相连. 

 

 

输出

 

第一行输出至少有多少个士兵才可以控制所有房间第二行输出有多少种方案数,方案数会比较大,输出除以1032992941的余数吧. 

 

 

样例输入

6
1 2
1 3
1 5
1 4
5 6

样例输出

2
2


题解

树形dp

经典的最大支配集问题,不过要统计方案数。

发现自己以前求最大支配集的方法太sb了。

设f1表示被自己支配,f2表示被儿子支配,f3表示被父亲支配。

那么显然f1[x]=∑min(f1[son],f2[son],f3[son]),f3[x]=∑f2[son],判断一下能够从哪里转移即可轻松解决。

关键在于f2,要求儿子选择f1与f2中较小的那个,并且还应满足至少有一个儿子选择f1。

那么考虑,枚举到某一个儿子时,之前的儿子只有两种选择:存在选f1的、不存在选f1的。

对于存在,该儿子可能选f1或f2;对于不存在,该儿子只能选f1。同样判断一下即可。这里f3直接表示不存在选f1的,无需再开变量记录。

最后的答案为min(f1[1],f2[1]),同样需要判断一下。

代码略丑= =

#include <cstdio>
#include <algorithm>
#define N 500010
#define mod 1032992941ll
using namespace std;
typedef long long ll;
int head[N] , to[N << 1] , next[N << 1] , cnt;;
ll fa[N] , f1[N] , f2[N] , f3[N] , s1[N] , s2[N] , s3[N];
void add(int x , int y)
{
	to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt;
}
void dfs(int x)
{
	int i , ft;
	ll st;
	f1[x] = s1[x] = s2[x] = s3[x] = 1 , f2[x] = N;
	for(i = head[x] ; i ; i = next[i])
	{
		if(to[i] != fa[x])
		{
			fa[to[i]] = x , dfs(to[i]);
			ft = min(f1[to[i]] , min(f2[to[i]] , f3[to[i]])) , st = 0;
			if(f1[to[i]] == ft) st += s1[to[i]];
			if(f2[to[i]] == ft) st += s2[to[i]];
			if(f3[to[i]] == ft) st += s3[to[i]];
			f1[x] += ft , s1[x] = s1[x] * st % mod;
			ft = min(min(f2[x] + f1[to[i]] , f2[x] + f2[to[i]]) , f3[x] + f1[to[i]]) , st = 0;
			if(f2[x] + f1[to[i]] == ft) st += s2[x] * s1[to[i]] % mod;
			if(f2[x] + f2[to[i]] == ft) st += s2[x] * s2[to[i]] % mod;
			if(f3[x] + f1[to[i]] == ft) st += s3[x] * s1[to[i]] % mod;
			f2[x] = ft , s2[x] = st % mod;
			f3[x] += f2[to[i]] , s3[x] = s3[x] * s2[to[i]] % mod;
		}
	}
}
int main()
{
	int n , i , x , y;
	scanf("%d" , &n);
	for(i = 1 ; i < n ; i ++ ) scanf("%d%d" , &x , &y) , add(x , y) , add(y , x);
	dfs(1);
	if(f1[1] < f2[1]) printf("%lld\n%lld\n" , f1[1] , s1[1]);
	else if(f1[1] > f2[1]) printf("%lld\n%lld\n" , f2[1] , s2[1]);
	else printf("%lld\n%lld\n" , f1[1] , (s1[1] + s2[1]) % mod);
	return 0;
}

 

 

posted @ 2017-07-03 11:13  GXZlegend  阅读(639)  评论(0编辑  收藏  举报