洛谷 P3233 [HNOI2014]世界树(虚树+dp)

题面

luogu

题解

数据范围已经告诉我们是虚树了,考虑如何在虚树上面\(dp\)

以下摘自hzwer博客:

构建虚树以后,用两遍dp(一遍自底向上、一遍自顶向下)处理出虚树上每个点最近的议事处(距离相同时取编号较小者)

然后枚举虚树上每一条边,考虑其对两端点的答案贡献

可以用倍增二分出分界点

如果a,b的分界点为mid,a,b路径上a的第一个儿子为x

则对a的贡献是size[x]-size[mid]

对b的贡献是size[mid]-size[b]

还要算上没被考虑的点:对虚树上每个点a,用size[a]减去它每条虚树边对应的size[x],剩下的点(a本身以及不通向任何虚树儿子的子树)都归属于a最近的议事处

Code

// luogu-judger-enable-o2
#include<bits/stdc++.h>

#define LL long long
#define RG register

using namespace std;
/// Reads the next (optionally negative) decimal integer from stdin into x.
/// Any characters that are neither a digit nor '-' are skipped first.
/// If EOF is reached before a number starts, x is left as 0 instead of
/// spinning forever. (The previous version stored getchar() in a char, so
/// EOF could not be told apart reliably and the skip loop never ended.)
template<class T> inline void read(T &x) {
	x = 0;
	int c = getchar();  // int, not char: must be able to hold EOF distinctly
	bool neg = false;
	while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
	if (c == '-') {
		neg = true;
		c = getchar();
	}
	while (c >= '0' && c <= '9') {
		x = x*10 + (c - '0');
		c = getchar();
	}
	if (neg) x = -x;
}
/// Writes x to stdout in decimal (no separator or newline).
/// Digits are extracted from the signed value directly instead of negating
/// it first, so the most negative value of T does not overflow.
template<class T> inline void write(T x) {
	if (x == 0) { putchar('0'); return; }
	char buf[40];  // enough for any 128-bit integer
	int len = 0;
	const bool neg = x < 0;
	while (x != 0) {
		// For negative x, x % 10 is in [-9, 0] (truncating division).
		const int d = static_cast<int>(x % 10);
		buf[len++] = static_cast<char>('0' + (d < 0 ? -d : d));
		x /= 10;
	}
	if (neg) putchar('-');
	while (len > 0) putchar(buf[--len]);
}
int n;
const int N = 300010;
// Edge pool shared by the original tree and every per-query virtual tree.
struct node {
	int to, next;
}g[N<<1];
int last[N], gl;  // last[x]: most recent edge out of x (0 = none); gl: edges used
// Appends a directed edge x -> y to the front of x's adjacency list.
// Uses plain member assignment: the former `(node){...}` was a C99 compound
// literal, a compiler extension rather than ISO C++.
inline void add(int x, int y) {
	++gl;
	g[gl].to = y;
	g[gl].next = last[x];
	last[x] = gl;
}
int dfn[N], cnt, siz[N], dep[N], anc[N][21], rem[N], bel[N];
void init(int u, int fa) {
	dfn[u] = ++cnt; siz[u] = 1;
	anc[u][0] = fa;
	for (int i = 1; i <= 20; i++)
		anc[u][i] = anc[anc[u][i-1]][i-1];
	for (int i = last[u]; i; i = g[i].next) {
		int v = g[i].to;
		if (v == fa) continue;
		dep[v] = dep[u]+1;
		init(v, u);
		siz[u] += siz[v];
	}
	return ;
}
// Lowest common ancestor of x and y using the binary-lifting table anc[][].
int lca(int x, int y) {
	int lo = x, hi = y;  // lo ends up as the deeper of the two
	if (dep[lo] < dep[hi]) {
		const int t = lo;
		lo = hi;
		hi = t;
	}
	// Raise the deeper vertex to the depth of the shallower one.
	for (int k = 20; k >= 0; --k) {
		if (dep[lo] - (1 << k) < dep[hi]) continue;
		lo = anc[lo][k];
	}
	if (lo == hi) return lo;
	// Raise both while their 2^k-th ancestors differ; they stop just below the LCA.
	for (int k = 20; k >= 0; --k) {
		const int upLo = anc[lo][k];
		const int upHi = anc[hi][k];
		if (upLo == upHi) continue;
		lo = upLo;
		hi = upHi;
	}
	return anc[lo][0];
}
// Number of edges on the tree path between x and y.
int dis(int x, int y) {
	const int meet = lca(x, y);
	return (dep[x] - dep[meet]) + (dep[y] - dep[meet]);
}
// Per-query state:
//   top - size of the chain stack s[] used while building the virtual tree
//   len - number of virtual-tree vertices collected into c[] by dfs1
//   m   - number of query vertices
//   a[] - query vertices (sorted by dfn before building)
//   b[] - query vertices in input order (used for printing answers)
//   f[] - answer per query vertex
int top, len, m, a[N], b[N], s[N], c[N], f[N];
// Comparator for sort(): orders vertices by their DFS preorder index.
bool cmp(int lhs, int rhs) {
	const int l = dfn[lhs];
	const int r = dfn[rhs];
	return l < r;
}

// Adds query vertex x to the virtual tree under construction.
// s[1..top] holds a chain of the virtual tree being built. Vertices must be
// inserted in increasing dfn order. A virtual-tree edge add(parent, child) is
// emitted when a vertex is popped off (or replaced on) the chain.
inline void insert(int x) {
	// s[1] is vertex 1 (the root), an ancestor of every vertex, so the
	// second chain entry can be pushed without computing an LCA.
	// When vertex 1 is itself a query vertex, the first call arrives with
	// top == 0. Then lca(x, s[0]) uses s[0] == 0 (never written), the result
	// equals s[0], and x is simply pushed as s[1].
	if (top == 1) {s[++top] = x; return ;}
	int o = lca(x, s[top]);
	// Pop chain vertices that are not ancestors of x, linking each one to the
	// vertex below it, as long as that lower vertex is still at or below o.
	while (top > 1 && dfn[s[top-1]] >= dfn[o]) add(s[top-1], s[top]), top--;
	// o is a new branching vertex: hang the current top under it and let o take its place.
	if (o != s[top]) add(o, s[top]), s[top] = o;
	s[++top] = x;
	return ;
}

// Bottom-up pass over the virtual tree rooted at x.
// It appends x to c[], so c[1..len] lists virtual-tree vertices in preorder,
// and starts rem[x] at siz[x].
// It then sets bel[x] to the nearest candidate among the children's bel
// values, with ties going to the smaller id. A query vertex keeps itself.
void dfs1(int x) {
	c[++len] = x;
	rem[x] = siz[x];
	for (int e = last[x]; e != 0; e = g[e].next) {
		const int child = g[e].to;
		dfs1(child);
		const int cand = bel[child];
		if (cand == 0) continue;  // no candidate from this child
		const int dCand = dis(cand, x);
		const int dCur = dis(bel[x], x);
		const bool better = bel[x] == 0 || dCand < dCur || (dCand == dCur && cand < bel[x]);
		if (better) bel[x] = cand;
	}
}
// Top-down pass over the virtual tree. It offers bel[x] to each child, and
// the child adopts it when any of these holds:
//   - the child has no candidate yet
//   - bel[x] is strictly closer to the child
//   - the distances tie and bel[x] has the smaller id
// It then recurses into that child.
void dfs2(int x) {
	for (int e = last[x]; e != 0; e = g[e].next) {
		const int child = g[e].to;
		const int viaParent = dis(bel[x], child);
		const int own = dis(bel[child], child);
		const bool adopt = !bel[child] || viaParent < own || (viaParent == own && bel[x] < bel[child]);
		if (adopt) bel[child] = bel[x];
		dfs2(child);
	}
}

// Handles the virtual-tree edge a -> b, where a is b's parent in the virtual tree.
// Let x be the child of a on the original-tree path toward b. The vertices in
// subtree(x) but outside subtree(b) belong to this edge and are split between
// bel[a] and bel[b]. Vertices inside subtree(b) are handled when b's own
// edges and leftovers are processed.
void solve(int a, int b) {
	int x = b, mid = b;
	// Lift x to the highest ancestor of b that is still strictly deeper than a,
	// i.e. a's child on the path to b.
	for (int i = 20; i >= 0; i--)
		if (dep[anc[x][i]] > dep[a])
			x = anc[x][i];
	// subtree(x) is accounted for by this edge, so remove it from a's leftovers.
	rem[a] -= siz[x];
	// Both ends share the same nearest query vertex: it takes the whole branch.
	if (bel[a] == bel[b]) {
		f[bel[a]] += siz[x]-siz[b];
		return ;
	}
	// Binary-lift mid from b toward a (never reaching a). Move up while the
	// candidate ancestor is farther from bel[a] than from bel[b], or equally far
	// and bel[b] has the smaller id. mid ends as the highest vertex on the path
	// assigned to bel[b]. This relies on the assignment changing only once
	// along the path.
	for (int i = 20; i >= 0; i--) {
		int nxt = anc[mid][i];
		if (dep[nxt] <= dep[a]) continue;
		int t1 = dis(bel[a], nxt), t2 = dis(bel[b], nxt);
		if (t1 > t2 || (t1 == t2 && bel[b] < bel[a])) mid = nxt;
	}
	// The upper part of the branch goes to bel[a] and the lower part to bel[b]:
	//   subtree(x) minus subtree(mid) -> bel[a]
	//   subtree(mid) minus subtree(b) -> bel[b]
	f[bel[a]] += siz[x]-siz[mid];
	f[bel[b]] += siz[mid]-siz[b];
	return ;
}

// Answers one query. It reads m query vertices and builds their virtual tree.
// For every query vertex it counts the original-tree vertices whose nearest
// query vertex it is (ties go to the smaller id). Counts are printed in input
// order, then the per-query state this call touched is cleared.
void query() {
	top = len = gl = 0;  // gl = 0 reuses the edge pool g[] for the virtual tree
	read(m);
	for (int i = 1; i <= m; i++) read(a[i]), b[i] = a[i];  // b[] keeps input order for output
	for (int i = 1; i <= m; i++) bel[a[i]] = a[i];  // each query vertex is its own nearest
	sort(a+1, a+1+m, cmp);
	// Seed the chain with the root unless the root is itself a query vertex.
	// bel[1] == 1 exactly in that case, because bel[] is cleared for every
	// virtual-tree vertex (always including 1) at the end of the previous query.
	if (bel[1] != 1) s[++top] = 1;
	for (int i = 1; i <= m; i++) insert(a[i]);
	for (int i = 1; i < top; i++) add(s[i], s[i+1]);  // link the chain left on the stack
	dfs1(1); dfs2(1);
	// Split each virtual-tree edge's branch between the edge's two ends.
	for (int i = 1; i <= len; i++)
		for (int j = last[c[i]]; j; j = g[j].next)
			solve(c[i], g[j].to);
	// Leftover vertices of each virtual-tree vertex go to its nearest query vertex.
	// These are siz[] minus the branches handed to its edges.
	for (int i = 1; i <= len; i++) f[bel[c[i]]] += rem[c[i]];
	for (int i = 1; i <= m; i++) write(f[b[i]]), putchar(' ');
	putchar('\n');
	// Reset only what this query touched. f[] is indexed by query vertices,
	// which all appear in c[].
	for (int i = 1; i <= len; i++) f[c[i]] = bel[c[i]] = last[c[i]] = 0;
	return ;
}

int main() {
	read(n);
	// Original tree: n-1 undirected edges, each stored as two directed edges.
	for (int e = 1; e < n; ++e) {
		int u, v;
		read(u);
		read(v);
		add(u, v);
		add(v, u);
	}
	init(1, 0);
	// Forget the original adjacency heads. The edge pool g[] is reused by the
	// per-query virtual trees, and query() resets gl itself.
	memset(last, 0, sizeof(last));
	int q;
	read(q);
	for (; q != 0; --q) query();
	return 0;
}

posted @ 2019-01-08 11:49  zzy2005  阅读(168)  评论(0编辑  收藏  举报