[BZOJ3611] [Heoi2014]大工程(DP + 虚树)

传送门

 

$dp[i][0]$表示节点i到子树中的所有点的距离之和

$dp[i][1]$表示节点i到子树中最近距离的点的距离

$dp[i][2]$表示节点i到子树中最远距离的点的距离

建好虚树后dp即可。

因为对于虚树掌握的还不是很熟,有些细节还是要注意。

虚树中可能会加入一些lca节点,这些节点在dp的时候是不应该统计的。

对于本题来说,别忘记考虑某一节点不同子树中点对的组合。

 

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 2000010
#define LL long long

using namespace std;

LL ans1, ans2, dp[N][3];
int n, cnt, rp, m, top, T;
int head[N], to[N], nex[N], val[N], dis[N], size[N], dfn[N], deep[N], f[N][21], q[N], s[N], flag[N];

inline int read()
{
	int x = 0, f = 1;
	char ch = getchar();
	for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = -1;
	for(; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0';
	return x * f;
}

inline void add(int x, int y)
{
	to[cnt] = y;
	nex[cnt] = head[x];
	head[x] = cnt++;
}

inline void dfs1(int u)
{
	int i, v;
	dfn[u] = ++rp;
	deep[u] = deep[f[u][0]] + 1;
	for(i = 0; f[u][i]; i++) f[u][i + 1] = f[f[u][i]][i];
	for(i = head[u]; ~i; i = nex[i])
	{
		v = to[i];
		if(!dfn[v])
		{
			f[v][0] = u;
			dis[v] = dis[u] + 1;
			dfs1(v);
		}
	}
	head[u] = -1;
}

inline int calc_lca(int x, int y)
{
	int i;
	if(deep[x] < deep[y]) swap(x, y);
	for(i = 20; i >= 0; i--)
		if(deep[f[x][i]] >= deep[y]) x = f[x][i];
	if(x == y) return x;
	for(i = 20; i >= 0; i--)
		if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
	return f[x][0];
}

inline bool cmp(int x, int y)
{
	return dfn[x] < dfn[y];
}

inline void dfs2(int u)
{
	int i, v;
	size[u] = flag[u];
	dp[u][1] = 1e9, dp[u][0] = dp[u][2] = 0;
	for(i = head[u]; ~i; i = nex[i])
	{
		v = to[i];
		dfs2(v);
		size[u] += size[v];
		ans1 = min(ans1, dp[u][1] + dp[v][1] + dis[v] - dis[u]);
		ans2 = max(ans2, dp[u][2] + dp[v][2] + dis[v] - dis[u]);
		dp[u][0] += dp[v][0] + 1ll * size[v] * (m - size[v]) * (dis[v] - dis[u]);
		dp[u][1] = min(dp[u][1], dis[v] - dis[u] + dp[v][1]);
		dp[u][2] = max(dp[u][2], dis[v] - dis[u] + dp[v][2]);
	}
	if(flag[u])
	{
		ans1 = min(ans1, dp[u][1]);
		ans2 = max(ans2, dp[u][2]);
		dp[u][1] = 0;
	}
	head[u] = -1;
}

inline void solve()
{
	int i, lca;
	m = read();
	top = cnt = 0;
	for(i = 1; i <= m; i++) q[i] = read(), flag[q[i]] = 1;
	sort(q + 1, q + m + 1, cmp);
	for(i = 1; i <= m; i++)
	{
		if(!top)
		{
			s[++top] = q[i];
			continue;
		}
		lca = calc_lca(s[top], q[i]);
		while(dfn[lca] < dfn[s[top]])
		{
			if(dfn[lca] >= dfn[s[top - 1]])
			{
				add(lca, s[top]);
				if(s[--top] != lca) s[++top] = lca;
				break;
			}
			add(s[top - 1], s[top]), top--;
		}
		s[++top] = q[i];
	}
	while(top > 1) add(s[top - 1], s[top]), top--;
	ans2 = 0;
	ans1 = 1ll * 1e9 * 1e9;
	dfs2(s[1]);
	printf("%lld %lld %lld\n", dp[s[1]][0], ans1, ans2);
	for(i = 1; i <= m; i++) flag[q[i]] = 0;
}

int main()
{
	int i, x, y;
	n = read();
	memset(head, -1, sizeof(head));
	for(i = 1; i < n; i++)
	{
		x = read();
		y = read();
		add(x, y);
		add(y, x);
	}
	dfs1(1);
	T = read();
	while(T--) solve();
	return 0;
}

  

posted @ 2018-01-08 19:08  zht467  阅读(109)  评论(0编辑  收藏  举报