1116考试T1

1116考试T1

​ 题目大意:

​ 哺噜国里有 n 个城市,有的城市之间有高速公路相连。在最开始时,哺噜国里有 n-1条高速公路,且任意两座城市之间都存在一条由高速公路组成的通路。由于高速公路的维护成本很高,为了减少哺噜国的财政支出,将更多的钱用来哺育小哺噜,秀秀女王决定关闭一些高速公路。但是为了保证哺噜国居民的正常生活,不能关闭太多的高速公路,要保证每个城市可以通过高速公路与至少 k 个城市(包括自己)相连。
​ 在得到了秀秀女王的指令后,交通部长华华决定先进行预调研。华华想知道在满足每个城市都可以与至少 k 个城市相连的前提下,有多少种关闭高速公路的方案(可以一条也不关)。两种方案不同,当且仅当存在一条高速公路在一个方案中被关闭,而在另外一个方案中没有被关闭。
​ 由于方案数可能很大,你只需输出不同方案数对 786433 取模后的结果即可。 \(n <= 5000, k <= n\)

​ 树形DP.

\(f[i][j]\)表示以\(i\)为根的子树内, \(i\)所在的联通块大小为\(j\)的方案数.

​ 假设当前在以\(x\)为根的子树内, 现在要把子树\(y\)的答案合并上, 那么转移一下:

\(f[x][i + j] = f[x][i] * f[y][j]\) 表示链接\((x, y)\)这条边.

\(f[x][i] = f[x][i] * f[y][j] (j >= k)\) 表示断开\((x, y)\)这条边.

​ 初值\(f[x][1] = 1\).

​ 这么枚举总的复杂度是\(O(n^ 3)\)的, 我们需要优化.

​ 我们发现\(i, j\)不必枚举到\(n\), \(i\)只需枚举到\(\displaystyle \sum_{d = 1}^{p - 1} siz[d]\), \(j\)只需枚举到\(siz[p]\), \(p\)是当前要合并的子树, 那么就有\(p - 1\)个已经合并过的子树.

​ 经分析复杂度是\(O(n ^ 2)\)的.

​ 分析(题解原话) : 对于点u,枚举i,j总次数等于以点u为根的子树中选取无序点对,使得它们的lca为u的无序点对数。树上每个点对只会在lca处算一次。

#include <bits/stdc++.h>

using namespace std;

inline long long read() {
	long long s = 0, f = 1; char ch;
	while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
	for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
	return s * f;
}

const int N = 5005, mod = 786433;
int n, K, cnt, ans;
int f[N][N], siz[N], head[N];
struct edge { int to, nxt; } e[N << 1];

void add(int x, int y) {
	e[++ cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y;
}

void get_tree(int x, int fa) {
	siz[x] = 1;
	for(int i = head[x]; i ; i = e[i].nxt) {
		int y = e[i].to; if(y == fa) continue;
		get_tree(y, x); siz[x] += siz[y];
	}
}

void get_f(int x, int fa) {
	f[x][1] = 1; 
	int lim = 1, h[siz[x] + 1];
	for(int i = head[x]; i ; i = e[i].nxt) {
		int y = e[i].to; if(y == fa) continue;
		get_f(y, x); 
		for(int j = 1;j <= lim + siz[y]; j++) h[j] = 0;
		for(int j = 1;j <= lim; j++) 
			for(int k = 1;k <= siz[y]; k++) {
				h[j + k] += 1ll * f[x][j] * f[y][k] % mod;
				if(k >= K) h[j] += 1ll * f[x][j] * f[y][k] % mod;
			}
		lim += siz[y];
		for(int j = 1;j <= lim; j++) f[x][j] = h[j];
	}
}

int main() {
	
	n = read(); K = read();
	for(int i = 1, x, y;i < n; i++)
		x = read(), y = read(), add(x, y), add(y, x);
	get_tree(1, 0);	get_f(1, 0);
	for(int i = K;i <= siz[1]; i++) ans = (ans + f[1][i]) % mod;
	printf("%d", ans);
	
	return 0;
}
posted @ 2020-11-20 12:21  C锥  阅读(107)  评论(0编辑  收藏  举报