Luogu 3233 [HNOI2014]世界树
BZOJ 3572
首先看出虚树,然后考虑如何$dp$。
我们先在处理出的虚树上$dp$一遍,处理出虚树上所有点距离最近的关键点(关键点一定在虚树上嘛)。
具体来说,先搜一遍处理出每一个点的父亲到它的可能产生贡献的答案,然后再搜一遍处理出所有儿子到它的可能产生贡献的答案。
接下来考虑一下如何处理出所有不在虚树上的点对答案的贡献,我们去枚举虚树上的每一条边,如果这条边的两边被同一个关键点管辖,那么直接累加答案就行了,否则一定可以在原来的树链上找到一个中间点$mid$使$mid$上方的点都受管辖父亲的点管辖,$mid$下方的点都被管辖深度大的点的点管辖。
这个过程可以用倍增实现。
然而嘴巴是一回事,实现是另一回事……我实现了好久没实现出来,最后对着 ljh2000 大神的代码研究了挺长时间才写出来了。
放个链接吧。 戳这里
由于我的偷懒直接用两倍的点导致常数写大了……在Luogu上需要一发$O2$。
时间复杂度$O(nlogn)$。
Code:
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int N = 3e5 + 5; const int Lg = 22; const int inf = 1 << 30; int n, qn, tot = 0, head[N], dfsc = 0, in[N], out[N]; int fa[N][Lg], dep[N], top, sta[N * 2], a[N * 2], id[N], ans[N]; int bel[N], g[N], siz[N]; bool vis[N]; struct Edge { int to, nxt; } e[N << 1]; inline void add(int from, int to) { e[++tot].to = to; e[tot].nxt = head[from]; head[from] = tot; } bool cmp(int x, int y) { int dfx = x > 0 ? in[x] : out[-x]; int dfy = y > 0 ? in[y] : out[-y]; return dfx < dfy; } inline void read(int &X) { X = 0; char ch = 0; int op = 1; for(; ch > '9' || ch < '0'; ch = getchar()) if(ch == '-') op = -1; for(; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } inline void swap(int &x, int &y) { int t = x; x = y; y = t; } inline int min(int x, int y) { return x > y ? y : x; } void dfs(int x, int fat, int depth) { fa[x][0] = fat, dep[x] = depth, siz[x] = 1; in[x] = ++dfsc; for(int i = 1; i <= 20; i++) fa[x][i] = fa[fa[x][i - 1]][i - 1]; for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(y == fat) continue; dfs(y, x, depth + 1); siz[x] += siz[y]; } out[x] = ++dfsc; } inline int getLca(int x, int y) { if(dep[x] < dep[y]) swap(x, y); for(int i = 20; i >= 0; i--) if(dep[fa[x][i]] >= dep[y]) x = fa[x][i]; if(x == y) return x; for(int i = 20; i >= 0; i--) if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i]; return fa[x][0]; } inline int getDis(int x, int y) { int z = getLca(x, y); return dep[x] + dep[y] - 2 * dep[z]; } void dfs1(int x, int fat) { g[x] = siz[x]; for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(y == fat) continue; dfs1(y, x); if(bel[y] == 0) continue; if(bel[x] == 0) { bel[x] = bel[y]; continue; } int d1 = getDis(bel[x], x), d2 = getDis(bel[y], x); if(d2 < d1 || (d2 == d1 && bel[y] < bel[x])) bel[x] = bel[y]; } } void dfs2(int x, int fat) { for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(y == fat) continue; if(bel[y] == 0) bel[y] = bel[x]; else { int d1 = getDis(bel[y], y), d2 = getDis(bel[x], y); if(d2 < d1 || (d2 == d1 && (bel[x] < bel[y]))) bel[y] = bel[x]; } dfs2(y, x); } } inline void work(int fat, int x) { int son = x, mid = x; for(int i = 20; i >= 0; i--) if(dep[son] - (1 << i) > dep[fat]) son = fa[son][i]; g[fat] -= siz[son]; if(bel[fat] == bel[x]) ans[bel[fat]] += siz[son] - siz[x]; else { for(int i = 20; i >= 0; i--) { int y = fa[mid][i]; if(dep[y] <= dep[fat]) continue; int d1 = getDis(y, bel[fat]), d2 = getDis(y, bel[x]); if(d1 > d2 || (d1 == d2 && bel[x] < bel[fat])) mid = y; } ans[bel[fat]] += siz[son] - siz[mid]; ans[bel[x]] += siz[mid] - siz[x]; } } void solve() { int K; read(K); for(int i = 1; i <= K; i++) { read(a[i]); if(!vis[a[i]]) { vis[a[i]] = 1; id[i] = a[i]; bel[a[i]] = a[i]; } } int cnt = K; sort(a + 1, a + 1 + K, cmp); for(int i = 1; i < cnt; i++) { int now = getLca(a[i], a[i + 1]); if(!vis[now]) { vis[now] = 1; a[++cnt] = now; } } for(int cur = cnt, i = 1; i <= cur; i++) a[++cnt] = -a[i]; if(!vis[1]) a[++cnt] = 1, a[++cnt] = -1; sort(a + 1, a + 1 + cnt, cmp); /* for(int i = 1; i <= cnt; i++) printf("%d ", a[i]); printf("\n"); */ top = 0; for(int i = 1; i <= cnt; i++) { if(a[i] > 0) sta[++top] = a[i]; else { int x = sta[top--], y = sta[top]; if(y) add(x, y), add(y, x); } } dfs1(1, 0), dfs2(1, 0); for(int i = 1; i <= cnt; i++) { if(a[i] > 0) sta[++top] = a[i]; else { int x = sta[top--], y = sta[top]; if(y) work(y, x); } } for(int i = 1; i <= cnt; i++) if(a[i] > 0) ans[bel[a[i]]] += g[a[i]]; for(int i = 1; i <= K; i++) printf("%d ", ans[id[i]]); printf("\n"); tot = 0; for(int i = 1; i <= cnt; i++) if(a[i] > 0) { ans[a[i]] = head[a[i]] = g[a[i]] = bel[a[i]] = 0; vis[a[i]] = 0; } } int main() { read(n); for(int x, y, i = 1; i < n; i++) { read(x), read(y); add(x, y), add(y, x); } dfs(1, 0, 1); tot = 0; memset(head, 0, sizeof(head)); for(read(qn); qn--; ) solve(); return 0; }