bzoj3572 世界树(虚树,倍增)

题目链接

解题思路

  题目的问法很明显是虚树题的问法。根据询问点建虚树,因为虚树不止有询问的点,还有他们的lca,所以先对建出来的虚树预处理出来虚树上每个点离他最近的议事处的距离和点的编号(自己就是议事处肯定就是自己,主要是对增加的lca进行处理)。因为最近的点可能来自下面的点也可能是上面的点,所以跑两次dfs从两个方向更新一下即可。
  对于虚树的一条边来说,两个点都有其对应的议事处,它们之间的边代表原树上的一些点(或者说子树)。对于每条边我们都计算一下两个点对应的议事处的分界点(注意不是这个点本身,以为它可能不是议事处而是新增的lca)即可。方法是用倍增从深度最大的那个点往上跳,计算跳到的那个点当分界点时到两个议事处的距离,注意还要考虑距离相等编号选最小的议事处。
  还有一个细节需要注意,对于点u和点v(u是v的祖先),我们计算分界点d分开的点的数量不能直接加上sz[u]-sz[d],因为u有好几个子树,这样加肯定会重复,可以算出u的儿子son,满足son是v的祖先,这样每次减去sz[son],加上sz[son]-sz[d],最后再加上sz[u]就行了。

代码

const int maxn = 6e5+10;
struct E {
	int to, nxt;
} e[maxn];
int h[maxn], tot;
void add(int u, int v) {
	e[++tot] = {v, h[u]};
	h[u] = tot;
}
int ldfn[maxn], rdfn[maxn], tim;
int dep[maxn], f[maxn][20], sz[maxn];
void dfs(int u, int p) {
    //cout << "!" << u << endl;
	ldfn[u] = ++tim; sz[u] = 1;
	for (int i = h[u]; i; i=e[i].nxt) {
		int v = e[i].to;
		if (v==p) continue;
		dep[v] = dep[u]+1;
		f[v][0] = u;
		for (int j = 1; j<20; ++j) f[v][j] = f[f[v][j-1]][j-1];
		dfs(v, u);
		sz[u] += sz[v];
	}
	rdfn[u] = tim;
}
int lca(int u, int v) {
	if (dep[u]<dep[v]) swap(u, v);
	for (int i = 19; i>=0; --i)
		if (dep[f[u][i]]>=dep[v]) u = f[u][i];
	if (u==v) return u;
	for (int i = 19; i>=0; --i)
		if (f[u][i]!=f[v][i]) u = f[u][i], v = f[v][i];
	return f[u][0];
}
int n, m, k, d[maxn], sk[maxn], vis[maxn];
void build() {
    tot = 0; //!!!!
	sort(d+1, d+k+1, [](int a, int b) {return ldfn[a]<ldfn[b];});
    //for (int i = 1; i<=k; ++i) cout << d[i] << (i==k ? '\n':' ');
	int keynum = k;
	for (int i = 1; i<keynum; ++i) d[++k] = lca(d[i], d[i+1]);
	sort(d+1, d+k+1, [](int a, int b) {return ldfn[a]<ldfn[b];});
	k = unique(d+1, d+k+1)-d-1;
	int tp = 1; sk[tp] = d[1];
	for (int i = 2; i<=k; ++i) {
		while(tp && rdfn[sk[tp]]<ldfn[d[i]]) --tp; 
		if (tp) add(sk[tp], d[i]);
		sk[++tp] = d[i];
	}
}
int dp[maxn], g[maxn];
void dfs1(int u, int p) {
	dp[u] = INF;
	for (int i = h[u]; i; i=e[i].nxt) {
		int v = e[i].to;
		if (v==p) continue;
        dfs1(v, u);
		int dis = dep[v]-dep[u]; //cout << dis << endl; //注意相同距离取编号小的
		if (dp[v]+dis<dp[u] || (dp[v]+dis==dp[u]&&g[v]<g[u])) dp[u] = dp[v]+dis, g[u] = g[v];
        //cout << "!" << u << ' ' << v << ' ' << g[u] << ' ' << dp[u] << endl;
	}
	if (vis[u]) dp[u] = 0, g[u] = u;
}
void dfs2(int u, int p) {
	for (int i = h[u]; i; i=e[i].nxt) {
		int v = e[i].to;
		if (v==p) continue;
		int dis = dep[v]-dep[u]; //cout << dis << endl; //注意相同距离取编号小的
		if (dp[u]+dis<dp[v] || (dp[u]+dis==dp[v]&&g[u]<g[v])) dp[v] = dp[u]+dis, g[v] = g[u];
        dfs2(v, u);
	}
}
int ans[maxn], tmp[maxn];
void calc(int u, int v) {
	int son = v;
	for (int i = 19; i>=0; --i)
		if (dep[f[son][i]]>dep[u]) son = f[son][i];
	ans[g[u]] -= sz[son]; //先减去儿子的贡献
    //cout << ans[g[u]] << endl;
	int d = v;
	for (int i = 19; i>=0; --i) {
		int t = f[d][i];
		if (dep[t]<dep[u]) continue; //倍增往上找分界点
		int l = dp[v]+dep[v]-dep[t], r = dp[u]+dep[t]-dep[u];
		if (l<r || (l==r && g[v]<g[u])) d = t; //相同距离取最小
	}
    //cout << u << ' ' << g[u] << "|||" << v << ' ' << g[v] << "!!" << d << endl;
	ans[g[u]] += sz[son]-sz[d]; ans[g[v]] += sz[d]-sz[v];
    //cout << ans[g[u]] << ' ' << ans[g[v]] << endl;
}
void dfs3(int u, int p) {
	for (int i = h[u]; i; i=e[i].nxt) {
		int v = e[i].to;
		if (v==p) continue;
		calc(u, v);
        dfs3(v, u);
	}
	ans[g[u]] += sz[u];
}
void init() {
    tot = 0;
    for (int i = 1; i<=k; ++i) h[d[i]] = ans[d[i]] = g[d[i]] = vis[d[i]] = 0, dp[d[i]]= INF;
}
int main() {
	cin >> n;
	for (int i = 1, a, b; i<n; ++i) {
		scanf("%d%d", &a, &b);
		add(a, b);
		add(b, a);
	}
    dep[1] = 1; dfs(1, 0);
    //for (int i = 1; i<=n; ++i) cout << ldfn[i] << ' ' << rdfn[i] << endl;
    clr(h, 0); tot = 0; //清空原树
	cin >> m;
	while(m--) {
		scanf("%d", &k); int tk = k;
		for (int i = 1; i<=k; ++i) scanf("%d", &d[i]), tmp[i] = d[i];
		for (int i = 1; i<=k; ++i) vis[d[i]] = 1;
		if (!vis[1]) d[++k] = 1;
		build();
		//for (int i = 1; i<=k; ++i) printf(i==k ? "%d\n":"%d ", d[i]);
		dfs1(1, 0); 
        //for (int i = 1; i<=k; ++i) cout << d[i] << ' ' << g[d[i]] << endl;
        dfs2(1, 0); dfs3(1, 0);
        //for (int i = 1; i<=k; ++i) cout << d[i] << ' ' << g[d[i]] << endl;
        for (int i = 1; i<=tk; ++i) printf(i==tk ? "%d\n":"%d ", ans[tmp[i]]);
        init();
	}
	return 0;
}
posted @ 2021-10-19 21:22  shuitiangong  阅读(43)  评论(0编辑  收藏  举报