【LG1600】[NOIP2016]天天爱跑步

【LG1600】[NOIP2016]天天爱跑步

题面

洛谷

题解

考虑一条路径\(S\rightarrow T\)是如何给一个观测点\(x\)造成贡献的,

一种是从\(x\)的子树内出来,另外一种是从\(x\)的子树外进去。

\(S,T\)的最近公共祖先为\(lca\),那么这条路径可表示为\(S\rightarrow lca\rightarrow T\)(如果\(lca=S\;or\;T\)可以特判)。

考虑两种情况如何贡献,

首先在\(S\rightarrow lca\)上的点,需要满足\(dep_S-dep_x=w_x\Leftrightarrow dep_S=w_x+dep_x\)

而对于\(lca\rightarrow T\)上的点,需要满足\((dep_S-dep_{lca})+(dep_x-dep_{lca})=w_x\Leftrightarrow dep_S-2dep_{lca}=w_x-dep_x\)

这样的话,对于一条路径,我们可以把它拆成向上、向下两段,分别做树上差分:再用一棵动态开点线段树在对应位置上\(\pm 1\),然后自底向上做线段树合并,在每个点对应的位置上查询即可。

具体实现细节详见代码。

代码

#include <iostream> 
#include <cstdio> 
#include <cstdlib> 
#include <cstring> 
#include <cmath> 
#include <algorithm> 
#include <vector> 
using namespace std; 
/**
 * Reads the next decimal integer from stdin.
 *
 * Skips every character up to the first digit or '-', then parses an
 * optional leading '-' followed by a run of digits.
 *
 * @return the parsed value; 0 if end of input is reached before any digit.
 */
inline int gi() { 
    // Keep the character as int so EOF stays distinguishable from real bytes;
    // a plain char would loop forever once the input is exhausted.
    int ch = getchar(); 
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar(); 
    int w = 1; 
    if (ch == '-') w = -1, ch = getchar(); 
    int data = 0; 
    while (ch >= '0' && ch <= '9') data = 10 * data + (ch - '0'), ch = getchar(); 
    return w * data; 
} 
// Upper bound on the number of tree nodes (plus a little slack).
const int MAX_N = 3e5 + 5; 
// Forward-star adjacency list: e[i] is a directed edge to `to`; `next` is the
// index of the previous edge leaving the same source, or -1 at the end of the list.
struct Graph { int to, next; } e[MAX_N << 1]; 
// fir[u]: index in e of the most recently added edge leaving u (-1 if none);
// e_cnt: number of edges stored so far.
int fir[MAX_N], e_cnt; 
/** Resets the adjacency list to "no edges". */
void clearGraph() { 
    std::memset(fir, -1, sizeof(fir));  // all-ones bytes make every entry -1
    e_cnt = 0; 
} 
/** Appends the directed edge u -> v at the head of u's edge list. */
void Add_Edge(int u, int v) { 
    // Standard aggregate initialisation instead of the GNU-only compound literal `(Graph){...}`.
    e[e_cnt] = Graph{v, fir[u]}; 
    fir[u] = e_cnt++; 
}
int fa[MAX_N], top[MAX_N], dep[MAX_N], size[MAX_N], son[MAX_N]; 
void dfs1(int x) { 
	size[x] = 1, dep[x] = dep[fa[x]] + 1; 
	for (int i = fir[x]; ~i; i = e[i].next) { 
		int v = e[i].to; if (v == fa[x]) continue; 
		fa[v] = x, dfs1(v), size[x] += size[v]; 
		if (size[son[x]] < size[v]) son[x] = v; 
	} 
} 
/**
 * Second HLD pass: assigns top[] for the subtree rooted at x.
 * @param x  current node
 * @param tp head of the heavy chain that x belongs to
 */
void dfs2(int x, int tp) {
	top[x] = tp;
	const int heavy = son[x];
	// The heavy child continues the current chain.
	if (heavy != 0) dfs2(heavy, tp);
	for (int i = fir[x]; i != -1; i = e[i].next) {
		const int child = e[i].to;
		// Every other child starts a new chain headed by itself.
		if (child != heavy && child != fa[x]) dfs2(child, child);
	}
}
/**
 * Lowest common ancestor of x and y via heavy-chain jumping.
 * Requires dfs1/dfs2 to have filled fa, dep and top.
 */
int LCA(int x, int y) {
	while (top[x] != top[y]) {
		// Lift whichever node sits on the chain whose head is deeper.
		if (dep[top[x]] >= dep[top[y]]) {
			x = fa[top[x]];
		} else {
			y = fa[top[y]];
		}
	}
	// Both nodes are now on the same chain; the shallower one is the answer.
	if (dep[x] < dep[y]) return x;
	return y;
}
// Record of one path (endpoints and their LCA).
// NOTE(review): `p` is not referenced by any code visible in this file.
struct Path { int s, t, lca; } p[MAX_N]; 
// N: number of tree nodes, M: number of paths (both read in main);
// w[x]: value read for node x and used as the query offset in Dfs; ans[x]: result printed for node x.
int N, M, w[MAX_N], ans[MAX_N]; 
// Per-node event lists filled in main and consumed in Dfs:
// Add1/Del1 hold keys applied with +1/-1 to the rt1 tree, Add2/Del2 likewise for rt2.
vector<int> Add1[MAX_N], Del1[MAX_N], Add2[MAX_N], Del2[MAX_N]; 
// Node pool for the dynamically allocated segment trees: ls/rs are child indices
// into t (0 means "no child"), v is the sum stored for the node's key range.
// NOTE(review): insert() takes nodes from this pool without checking tot against its size.
struct Node { int ls, rs, v; } t[MAX_N << 6]; 
// rt1[x] / rt2[x]: root indices of node x's two trees (0 = empty); tot: number of pool nodes handed out.
int rt1[MAX_N], rt2[MAX_N], tot; 
void insert(int &o, int l, int r, int pos, int op) { 
	if (!o) o = ++tot; 
	t[o].v += op; 
	if (l == r) return ; 
	int mid = (l + r) >> 1; 
	if (pos <= mid) insert(t[o].ls, l, mid, pos, op); 
	else insert(t[o].rs, mid + 1, r, pos, op); 
} 
int merge(int x, int y, int l, int r) { 
	if (!x || !y) return x | y; 
	if (l == r) return t[x].v += t[y].v, x; 
	int mid = (l + r) >> 1; 
	t[x].ls = merge(t[x].ls, t[y].ls, l, mid); 
	t[x].rs = merge(t[x].rs, t[y].rs, mid + 1, r); 
	return t[x].v = t[t[x].ls].v + t[t[x].rs].v, x; 
} 
/**
 * Returns the value stored at key `pos` in the tree rooted at `o`
 * (covering keys [l, r]); 0 if no node exists for that key.
 */
int query(int o, int l, int r, int pos) {
	while (o != 0 && l != r) {
		const int mid = (l + r) >> 1;
		if (pos <= mid) {
			o = t[o].ls;
			r = mid;
		} else {
			o = t[o].rs;
			l = mid + 1;
		}
	}
	if (o == 0) return 0;
	return t[o].v;
}
/**
 * Post-order traversal that computes ans[x] for every node in x's subtree.
 * Each node's two trees absorb its children's trees, then the node's own
 * +1 / -1 events are applied, and finally the two keys for x are looked up.
 */
void Dfs(int x) {
	// Key range shared by every tree operation.
	const int lo = -N;
	const int hi = N << 1;
	for (int i = fir[x]; i != -1; i = e[i].next) {
		const int child = e[i].to;
		if (child != fa[x]) {
			Dfs(child);
			rt1[x] = merge(rt1[x], rt1[child], lo, hi);
			rt2[x] = merge(rt2[x], rt2[child], lo, hi);
		}
	}
	for (int key : Add1[x]) insert(rt1[x], lo, hi, key, 1);
	for (int key : Add2[x]) insert(rt2[x], lo, hi, key, 1);
	for (int key : Del1[x]) insert(rt1[x], lo, hi, key, -1);
	for (int key : Del2[x]) insert(rt2[x], lo, hi, key, -1);
	const int upward = query(rt1[x], lo, hi, w[x] + dep[x]);
	const int downward = query(rt2[x], lo, hi, w[x] - dep[x]);
	ans[x] = upward + downward;
}
int main () { 
#ifndef ONLINE_JUDGE 
    freopen("cpp.in", "r", stdin); 
#endif 
	clearGraph(); 
	N = gi(), M = gi(); 
	for (int i = 1; i < N; i++) { 
		int u = gi(), v = gi(); 
		Add_Edge(u, v), Add_Edge(v, u); 
	} 
	dfs1(1), dfs2(1, 1); 
	for (int i = 1; i <= N; i++) w[i] = gi(); 
	for (int i = 1; i <= M; i++) { 
		int s = gi(), t = gi(), lca = LCA(s, t); 
		int d1 = dep[s], d2 = -dep[s]; 
		if (lca == t) { Add1[s].push_back(d1), Del1[fa[lca]].push_back(d1); continue; } 
		if (lca == s) { Add2[t].push_back(d2), Del2[fa[lca]].push_back(d2); continue; } 
		d2 = dep[s] - 2 * dep[lca]; 
		Add1[s].push_back(d1), Del1[fa[lca]].push_back(d1); 
		Add2[t].push_back(d2), Del2[lca].push_back(d2); 
	} 
	Dfs(1); 
	for (int i = 1; i <= N; i++) printf("%d ", ans[i]); 
	putchar('\n'); 
    return 0; 
} 
posted @ 2019-11-06 15:19  heyujun  阅读(170)  评论(0)  编辑  收藏  举报