P1600 [NOIP2016 提高组] 天天爱跑步

P1600 NOIP2016 提高组 天天爱跑步
LCA + 桶
分为上行和下行
上行: u->v 被i看到: u在i的子树且 dep[u]-dep[i]=w[i], 用桶维护dep[st(=u)]=x的u有多少个
下行: u->v 被i看到: v在i的子树且(u不在) 到lca(u或f[u])的 + dep[i]-dep[u] = w[i],用桶维护dep[st]-2dep[u]=x的u有多少个

点击查看代码
// 
/*
考虑上行的情况
(u, v) 中 u 被 i 看到
<=> 1. u ∈ {i的子树}
	2. lca(u, v) 不属于 {i的子树}
	3. dep[u] = w[i] + dep[i]
bucket1[x]: dep[u] = x 的 u 有多少个
考虑下行的情况
(u, v) 中 v 被 i 看到
<=> 1. v ∈ {i的子树}
    2. lca(u, v) 不属于 {i的子树}
    3. dep[u] - 2 * dep[lca(u, v)] = w[i] - dep[i]
bucket2[x]: dep[u] - 2 * dep[lca(u, v)] = x 的 u 有多少个
*/
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <utility>
#include <array>
#include <queue>

using namespace std;

const int N = 3e5 + 5, M = N << 1, logN = 25;

int n, m;
int h[N], e[M], nxt[M], idx;
int w[N], s[N], t[N]; // w 为观察者出现的时间, s 为玩家的开始节点, t 为玩家的终止节点
int ans[N]; // ans 为每个观察者能看到的人
int f[N][logN], dep[N];

struct bucket_t { // 桶
	int val[N * 2];
	bucket_t() { memset(val, 0, sizeof(val)); }
	inline int &operator [] (const int &i) { return val[i + N]; }
} bucket1, bucket2;

struct Operation {
	int val, t; // t ∈ {1, -1} 为树上差分的操作
};
vector<Operation> oper1[N], oper2[N]; // 分别为上行的操作和下行的操作

void add(int a, int b) {
	e[++ idx] = b, nxt[idx] = h[a], h[a] = idx;
}

void dfs(int u) {
	for(int i = 1; i < logN; i ++)
		if(f[u][i - 1]) f[u][i] = f[f[u][i - 1]][i - 1];
		else break;
	for(int i = h[u]; i; i = nxt[i]) {
		int v = e[i];
		if(v == f[u][0]) continue;
		f[v][0] = u, dep[v] = dep[u] + 1;
		dfs(v);
	}
}

void dfs1(int u) { // 处理上行
	int old = bucket1[w[u] + dep[u]];
	for(int i = h[u]; i; i = nxt[i]) {
		int v = e[i];
		if(v != f[u][0]) dfs1(v);
	}
	for(auto &o : oper1[u]) bucket1[o.val] += o.t;
	ans[u] += bucket1[w[u] + dep[u]] - old; // 新的
}

void dfs2(int u) { // 处理下行
	int old = bucket2[w[u] - dep[u]];
	for(int i = h[u]; i; i = nxt[i]) {
		int v = e[i];
		if(v != f[u][0]) dfs2(v);
	}
	for(auto &o : oper2[u]) bucket2[o.val] += o.t;
	ans[u] += bucket2[w[u] - dep[u]] - old;
}

int LCA(int x, int y) {
	if(dep[x] < dep[y]) swap(x, y);
	for(int i = logN - 1; i >= 0; i --)
		if(dep[f[x][i]] >= dep[y]) x = f[x][i];
	if(x == y) return x;
	for(int i = logN - 1; i >= 0; i --)
		if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
	return f[x][0];
}

int main() {
	scanf("%d%d", &n, &m);
	for(int i = 1, a, b; i < n; i ++)
		scanf("%d%d", &a, &b), add(a, b), add(b, a);
	for(int i = 1; i <= n; i ++) scanf("%d", w + i);
	for(int i = 1; i <= m; i ++) scanf("%d%d", s + i, t + i);
	dep[1] = 1, dfs(1);
	for(int i = 1; i <= m; i ++) {
		int &a = s[i], &b = t[i];
		int lca = LCA(a, b);
		oper1[a].push_back({dep[a], 1});
		oper1[lca].push_back({dep[a], -1});
		oper2[b].push_back({dep[a] - dep[lca] * 2, 1});
		if(f[lca][0]) oper2[f[lca][0]].push_back({dep[a] - dep[lca] * 2, -1});
	}
	dfs1(1), dfs2(1);
	for(int i = 1; i <= n; i ++) printf("%d ", ans[i]);
	return 0;
}
posted @ 2022-09-28 10:15  azzc  阅读(41)  评论(0编辑  收藏  举报