luogu P4183 [USACO18JAN]Cow at Large P
https://www.luogu.com.cn/problem/P4183
学到许多
显然要求出来一个\(f[u]\)表示离\(u\)最近的叶子距离
考虑对于一个节点的情况,把它设为根
求出来以它为根的每个点的深度,记为\(dep[u]\)
手玩一下容易发现对于每个节点,叶子是否需要的放判断是
\[f[u]<=dep[u],f[fa]>dep[fa]
\]
那么这个\(u\)节点(对应的子树)就可以对根产生1的贡献
但是直接这么做非常不好计算,发现难处理的是\(fa\)
因为产生贡献的一定是子树,所以考虑
对于一棵子树 \(\sum in[x]=2m-1\)
可以得到\(1=\sum (2-in[x])\)
这要就能保证子树加起来最后的贡献是\(1\)
然后考虑计算
那么根据上面那条性质,可以轻易得到
\[ans=\sum [f[u]<=dep[u]] (2-in[x])
\]
这样就能保证子树的贡献是\(1\)
然后我们考虑点分治,以\(rt\)为分治中心
\(dep[x]+dep[i]>=f[i]\)那么\(i\)就能对\(x\)产生\(2-in[i]\)的贡献
用树状数组维护\(f[i]-dep[i]\)即可
code:
#include<bits/stdc++.h>
#define N 200050
using namespace std;
int f[N], in[N];
vector<int> g[N];
void dfs(int u, int fa) {
f[u] = 114514;
if(in[u] == 1) f[u] = 0;
for(int v : g[u]) {
if(v == fa) continue;
dfs(v, u);
f[u] = min(f[u], f[v] + 1);
}
}
void dfss(int u, int fa) {
for(int v : g[u]) {
if(v == fa) continue;
f[v] = min(f[v], f[u] + 1);
dfss(v, u);
}
}
int t[N], n;
#define lowbit(x) (x & -x)
void update(int x, int y) { x += n;
for(; x <= 2 * n; x += lowbit(x)) t[x] += y;
}
int query(int x) { x += n;
int ret = 0;
for(; x; x -= lowbit(x)) ret += t[x];
return ret;
}
int S, siz[N], vis[N], gs, ls[N], msiz[N], dep[N], ans[N];
void find(int u, int fa) {
ls[++ gs] = u;
siz[u] = 1; msiz[u] = 0;
for(int v : g[u]) {
if(v == fa || vis[v]) continue;
find(v, u);
siz[u] += siz[v];
msiz[u] = max(msiz[u], siz[v]);
}
msiz[u] = max(msiz[u], S - siz[u]);
}
void get(int u, int fa) {
ls[++ gs] = u; siz[u] = 1;
for(int v : g[u]) {
if(v == fa || vis[v]) continue;
dep[v] = dep[u] + 1;
get(v, u); siz[u] += siz[v];
}
}
void calc(int o) {
//for(int i = 1; i <= gs; i ++) printf("%d ", ls[i]); printf("\n");
for(int i = 1; i <= gs; i ++) {
int x = ls[i];
update(f[x] - dep[x], 2 - in[x]);
// if(x == 2) printf("%d ", dep[x]);
}
for(int i = 1; i <= gs; i ++) {
int x = ls[i];
ans[x] += o * query(dep[x]);
}
for(int i = 1; i <= gs; i ++) {
int x = ls[i];
update(f[x] - dep[x], -(2 - in[x]));
}
}
void solve(int u) {
gs = 0;
find(u, u);
for(int i = 1; i <= gs; i ++) if(msiz[ls[i]] < msiz[u]) u = ls[i];
//printf("* %d %d\n", u, siz[u]);
gs = 0;
dep[u] = 0;
get(u, u);
calc(1);
vis[u] = 1;
for(int v : g[u]) {
if(vis[v]) continue;
gs = 0; get(v, u);
calc(-1);
S = siz[v];
solve(v);
}
}
int main() {
// freopen("a.in","r",stdin);
// freopen("a.out","w",stdout);
scanf("%d", &n);
for(int i = 1; i < n; i ++) {
int u, v;
scanf("%d%d", &u, &v);
g[u].push_back(v), g[v].push_back(u);
in[u] ++, in[v] ++;
}
dfs(1, 0), dfss(1, 0);
S = n; solve(1);
for(int i = 1; i <= n; i ++) if(in[i] == 1) ans[i] = 1;
for(int i = 1; i <= n; i ++) printf("%d\n", ans[i]);
return 0;
}