P1600 [NOIP2016 提高组] 天天爱跑步
P1600 NOIP2016 提高组 天天爱跑步
LCA + 桶
分为上行和下行
上行: u->v 被i看到: u在i的子树且 dep[u]-dep[i]=w[i], 用桶维护dep[st(=u)]=x的u有多少个
下行: u->v 被i看到: v在i的子树且(u不在) 到lca(u或f[u])的 + dep[i]-dep[u] = w[i],用桶维护dep[st]-2dep[u]=x的u有多少个
点击查看代码
//
/*
考虑上行的情况
(u, v) 中 u 被 i 看到
<=> 1. u ∈ {i的子树}
2. lca(u, v) 不属于 {i的子树}
3. dep[u] = w[i] + dep[i]
bucket1[x]: dep[u] = x 的 u 有多少个
考虑下行的情况
(u, v) 中 v 被 i 看到
<=> 1. v ∈ {i的子树}
2. lca(u, v) 不属于 {i的子树}
3. dep[u] - 2 * dep[lca(u, v)] = w[i] - dep[i]
bucket2[x]: dep[u] - 2 * dep[lca(u, v)] = x 的 u 有多少个
*/
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <utility>
#include <array>
#include <queue>
using namespace std;
const int N = 3e5 + 5, M = N << 1, logN = 25;
int n, m;
int h[N], e[M], nxt[M], idx;
int w[N], s[N], t[N]; // w 为观察者出现的时间, s 为玩家的开始节点, t 为玩家的终止节点
int ans[N]; // ans 为每个观察者能看到的人
int f[N][logN], dep[N];
struct bucket_t { // 桶
int val[N * 2];
bucket_t() { memset(val, 0, sizeof(val)); }
inline int &operator [] (const int &i) { return val[i + N]; }
} bucket1, bucket2;
struct Operation {
int val, t; // t ∈ {1, -1} 为树上差分的操作
};
vector<Operation> oper1[N], oper2[N]; // 分别为上行的操作和下行的操作
void add(int a, int b) {
e[++ idx] = b, nxt[idx] = h[a], h[a] = idx;
}
void dfs(int u) {
for(int i = 1; i < logN; i ++)
if(f[u][i - 1]) f[u][i] = f[f[u][i - 1]][i - 1];
else break;
for(int i = h[u]; i; i = nxt[i]) {
int v = e[i];
if(v == f[u][0]) continue;
f[v][0] = u, dep[v] = dep[u] + 1;
dfs(v);
}
}
void dfs1(int u) { // 处理上行
int old = bucket1[w[u] + dep[u]];
for(int i = h[u]; i; i = nxt[i]) {
int v = e[i];
if(v != f[u][0]) dfs1(v);
}
for(auto &o : oper1[u]) bucket1[o.val] += o.t;
ans[u] += bucket1[w[u] + dep[u]] - old; // 新的
}
void dfs2(int u) { // 处理下行
int old = bucket2[w[u] - dep[u]];
for(int i = h[u]; i; i = nxt[i]) {
int v = e[i];
if(v != f[u][0]) dfs2(v);
}
for(auto &o : oper2[u]) bucket2[o.val] += o.t;
ans[u] += bucket2[w[u] - dep[u]] - old;
}
int LCA(int x, int y) {
if(dep[x] < dep[y]) swap(x, y);
for(int i = logN - 1; i >= 0; i --)
if(dep[f[x][i]] >= dep[y]) x = f[x][i];
if(x == y) return x;
for(int i = logN - 1; i >= 0; i --)
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
int main() {
scanf("%d%d", &n, &m);
for(int i = 1, a, b; i < n; i ++)
scanf("%d%d", &a, &b), add(a, b), add(b, a);
for(int i = 1; i <= n; i ++) scanf("%d", w + i);
for(int i = 1; i <= m; i ++) scanf("%d%d", s + i, t + i);
dep[1] = 1, dfs(1);
for(int i = 1; i <= m; i ++) {
int &a = s[i], &b = t[i];
int lca = LCA(a, b);
oper1[a].push_back({dep[a], 1});
oper1[lca].push_back({dep[a], -1});
oper2[b].push_back({dep[a] - dep[lca] * 2, 1});
if(f[lca][0]) oper2[f[lca][0]].push_back({dep[a] - dep[lca] * 2, -1});
}
dfs1(1), dfs2(1);
for(int i = 1; i <= n; i ++) printf("%d ", ans[i]);
return 0;
}