【数据结构、图论/线段树合并】AcWing 354. 天天爱跑步

这里提供一个线段树合并的做法,个人感觉思维难度和代码难度都不大,我一发过了这题,nice

分析

把每个人对应的路径看成是询问,考虑如何处理这些询问:

可以发现对于一条路径 \(u\to v\),相当于是 \(u\to v\) 依次打上 \(0, 1, 2,\dots\) 这样的时间戳,然后最暴力的做法就是沿着路径直接打时间戳,路过的每个点都维护一个桶(记为 \(b\))来存这个时间戳,处理总询问次。记点 \(u\) 观察员出现时间为 \(q[u]\),那么点 \(u\) 的答案就是 \(b[q[u]]\) 了。

这样做时间、空间复杂度都是 \(O(n^2)\) 级别的,故考虑优化。

记路径 \(u\to v\)\(LCA(u, v) = a\),那么对于一条路径就可以表示为两条链(例如下图)

image

可以发现 \(u\to a\)内向的,故沿根向时间戳值递增(\(+1\));而 \(a\to v\)外向的,故沿根向时间戳值递减。也就是说我们可以将问题拆成分别考虑内向、外向的链进行统计。

下以内向的情况为例:

直接维护时间戳的值很困难,但是我们可以利用 \(+1\) 的性质:

维护 \(val = u~时间戳的值 + dep[u]\),那么一条链上所有的 \(val\) 都是相等的,我们就可以使用树上差分的思想,在 \(u\) 点打一个插入标记 \(ins\),代表向桶里插入 \(val\),再于 \(fa[a]\) 点打一个删除标记 \(del\),代表 \(a\) 点父节点以及上面的点的桶删掉 \(val\)

而对于外向的情况也是完全类似的,只需要维护 \(val = u~时间戳的值 - dep[u]\) 就可以了,不过因为这里 \(val\) 可能是负值,我们可以加上一个偏移量 \(n\)(题目的点数)以方便下面的线段树操作。

我们使用值域线段树来维护桶(值域维护上面所说的 \(val\)),然后使用线段树合并的技巧,在进行完上面所有的标记(可以发现,上面的 \(ins,del\) 对应着线段树的单点修改)后,使用一次 \(\texttt{dfs}\) 来将桶里的信息自下而上地合并起来,复杂度是 \(O(n{\rm log}n)。\)

实现

// Problem: 天天爱跑步
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/description/356/
// Memory Limit: 512 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
using namespace std;

#define debug(x) cerr << #x << ": " << (x) << endl
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define dwn(i,a,b) for(int i=(a);i>=(b);i--)
#define pb push_back
#define all(x) (x).begin(), (x).end()
 
#define x first
#define y second
using pii = pair<int, int>;
using ll = long long;
 
inline void read(int &x){
    int s=0; x=1;
    char ch=getchar();
    while(ch<'0' || ch>'9') {if(ch=='-')x=-1;ch=getchar();}
    while(ch>='0' && ch<='9') s=(s<<3)+(s<<1)+ch-'0',ch=getchar();
    x*=s;
}

const int N=3e5+5, M=N<<1;

struct Edge{
	int to, next;
}e[M];

int h[N], tot;

void add(int u, int v){
	e[tot].to=v, e[tot].next=h[u], h[u]=tot++;
}

int dep[N], sz[N], son[N], fa[N];
int top[N];

void dfs1(int u, int f, int d=0){
	dep[u]=d, sz[u]=1, fa[u]=f;
	for(int i=h[u]; ~i; i=e[i].next){
		int go=e[i].to;
		if(go==f) continue;
		dfs1(go, u, d+1);
		if(sz[go]>sz[son[u]]) son[u]=go;
		sz[u]+=sz[go];
	}
}

void dfs2(int u, int t){
	top[u]=t;
	if(!son[u]) return;
	dfs2(son[u], t);
	for(int i=h[u]; ~i; i=e[i].next){
		int go=e[i].to;
		if(go==fa[u] || go==son[u]) continue;
		dfs2(go, go);
	}
}

int lca(int u, int v){
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]]) swap(u, v);
		u=fa[top[u]];
	}
	return dep[u]<dep[v]? u: v;
}

int dis(int u, int v){
	return dep[u]+dep[v]-(dep[lca(u, v)]<<1);
}

int n, m;
int q[N], res[N];

vector<int> ins1[N], del1[N], ins2[N], del2[N];

struct Segtree{
	struct Node{
		int l, r;
		int cnt;
		
		#define ls tr[u].l
		#define rs tr[u].r
	}tr[N*20];
	
	int idx, root[N];
	
	Segtree(){
		idx=0;
		memset(root, 0, sizeof root);
	}
	
	void upd(int &u, int l, int r, int x, int k){
		if(!u) u=++idx;
		if(l==r){
			tr[u].cnt+=k;
			return;
		}
		int mid=l+r>>1;
		if(x<=mid) upd(ls, l, mid, x, k);
		else upd(rs, mid+1, r, x, k);
	}
	
	int query(int u, int l, int r, int x){
		if(x<l || x>r) return 0;
		if(l==r) return tr[u].cnt;
		int mid=l+r>>1;
		if(x<=mid) return query(ls, l, mid, x);
		return query(rs, mid+1, r, x);
	}
	
	void merge(int &p, int q, int l, int r){
		if(!p || !q) return p=(p|q), void();
		if(l==r) return tr[p].cnt+=tr[q].cnt, void();
		int mid=l+r>>1;
		merge(tr[p].l, tr[q].l, l, mid);
		merge(tr[p].r, tr[q].r, mid+1, r);
	}
}T1, T2;

void dfs(int u){
	for(int i=h[u]; ~i; i=e[i].next){
		int go=e[i].to;
		if(go==fa[u]) continue;
		dfs(go);
		T1.merge(T1.root[u], T1.root[go], 0, n);
		T2.merge(T2.root[u], T2.root[go], 0, n<<1);
	}
	for(auto i: ins1[u]) T1.upd(T1.root[u], 0, n, i, 1);
	for(auto i: del1[u]) T1.upd(T1.root[u], 0, n, i, -1);
	for(auto i: ins2[u]) T2.upd(T2.root[u], 0, n<<1, i+n, 1);
	for(auto i: del2[u]) T2.upd(T2.root[u], 0, n<<1, i+n, -1);
	
	res[u]=T1.query(T1.root[u], 0, n, q[u]+dep[u])+T2.query(T2.root[u], 0, n<<1, q[u]-dep[u]+n); // check bound
}

int main(){
	memset(h, -1, sizeof h);
	cin>>n>>m;
	rep(i,1,n-1){
		int u, v; read(u), read(v);
		add(u, v), add(v, u);
	}
	
	rep(i,1,n) read(q[i]);
	
	dfs1(1, -1), dfs2(1, 1);
	
	rep(i,1,m){
		int u, v; read(u), read(v);
		int a=lca(u, v);
		ins1[u].pb(dep[u]), del1[fa[a]].pb(dep[u]); // 0+dep
		ins2[v].pb(-dep[v]+dis(u, v)), del2[a].pb(-dep[v]+dis(u, v));
	}
	
	dfs(1);
	
	rep(i,1,n) cout<<res[i]<<' ';
	
	return 0;
}
posted @ 2022-07-14 15:37  HinanawiTenshi  阅读(30)  评论(0编辑  收藏  举报