dls的数据结构_树链剖分

树链剖分

给定一个树,将一棵树分成几个链,每个点只能在一个链上(每个节(非叶子)点选一个儿子连下去)
剖法:长链剖分(找深度最大儿子子树连接),重链剖分(找子树大小最大的子树连接)
性质:一个节点不断地向上走,如果遇到了一条轻边,这条轻边的父亲节点的子树大小一定是会翻倍的
      所以只会经过O(logn)条轻边,所以也有任意两点之间的路径经过的轻边也是不超过O(log)
      所以去维护重链的信息,查询每一段的时候在O(log),所以每条路径是O(log^2);
      
解决的问题
1. 路径修改查询
2. 子树修改查询
// 树链剖分可能比倍增求lca更快一点,因为lca每次是跑满for循环的,树链没被卡的话更快
#include<bits/stdc++.h>

using namespace std;
const int N = 1e5+10;

int h[N], ne[2 * N], e[2 * N], idx;
void add(int a, int b){
	e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

int l[N], r[N], id[N], tot, depth[N], sz[N], hs[N], fa[N], top[N];
// 第一遍DFS,先不求dfs序,求子树大小,重儿子, 父亲, 根绝题目可能需要额外求深度等信息
void dfs1(int u, int f){
	sz[u] = 1;
	hs[u] = -1;
	fa[u] = f;
	depth[u] = depth[f] + 1;
	for(int i = h[u]; i != -1; i = ne[i]){
		int j = e[i];
		if(j == f) continue;
		dfs1(j, u);
		sz[u] += sz[j];
		if(hs[u] == -1 || sz[j] > sz[hs[u]]) hs[u] = j;
	}
}

// 第二遍DFS,每个点的DFS序,重链上的头元素(深度最小的元素), t表示当前链的链头元素
void dfs2(int u, int t){
	top[u] = t;
	l[u] = ++tot;
	id[tot] = u;
	// 如果有重儿子的话,先dfs重儿子,这样可以保证在dfs序中重链是有序的
	if(hs[u] != -1) dfs2(hs[u], t);
	// 再遍历轻儿子,他们的链头会变化成当前的第一个节点
	for(int i = h[u]; i != -1; i = ne[i]){
		int j = e[i];
		// 如果j不是父亲,并且j不是重链的话,进行遍历
		if(j != fa[u] && j != hs[u]){
			dfs2(j, j);
		}
	}
	r[u] = tot;
}

int LCA(int u, int v){
	while(top[u] != top[v]){
		if(depth[top[u]] < depth[top[v]]) v = fa[top[v]];
		else u = fa[top[u]];
	}
	if(depth[u] < depth[v]) return u;
	else return v;
}


int main(){
	memset(h, -1, sizeof h);
	int n; scanf("%d", &n);
	for(int i = 1; i <= n - 1; i ++){
		int x, y; scanf("%d %d", &x, &y);
		add(x, y); add(y, x);
	}
	dfs1(1, 0); dfs2(1, 1);
	int m; scanf("%d", &m);
	for(int i = 1; i <= m; i ++){
		int x, y; scanf("%d %d", &x, &y);
		printf("%d\n", LCA(x, y));
	}

}

// 注意方向问题
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10;

int h[N], ne[2 * N], e[2 * N], idx;

void add(int a, int b){
	e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

// 一道维护树链点权和以及最大值的板子题

int l[N], r[N], id[N], tot, depth[N], sz[N], hs[N], fa[N], top[N];
// 第一遍DFS,先不求dfs序,求子树大小,重儿子, 父亲, 深度等信息(可能根据题目意思求一些额外的信息)
void dfs1(int u, int f){
	sz[u] = 1;
	hs[u] = -1;
	fa[u] = f;
	depth[u] = depth[f] + 1;
	for(int i = h[u]; i != -1; i = ne[i]){
		int j = e[i];
		if(j == f) continue;
		dfs1(j, u);
		sz[u] += sz[j];
		if(hs[u] == -1 || sz[j] > sz[hs[u]]) hs[u] = j;
	}
}

// 第二遍DFS,每个点的DFS序,重链上的头元素(深度最小的元素), t表示当前链的链头元素
void dfs2(int u, int t){
	top[u] = t;
	l[u] = ++tot;
	id[tot] = u;
	// 如果有重儿子的话,先dfs重儿子,这样可以保证在dfs序中重链是有序的
	if(hs[u] != -1) dfs2(hs[u], t);
	// 再遍历轻儿子,他们的链头会变化成当前的第一个节点
	for(int i = h[u]; i != -1; i = ne[i]){
		int j = e[i];
		// 如果j不是父亲,并且j不是重链的话,进行遍历
		if(j != fa[u] && j != hs[u]){
			dfs2(j, j);
		}
	}
	r[u] = tot;
}

// 线段树维护树链剖剖分的dfs序
int c[N];
struct Node{
	int l, r;
	int cnt, lc, rc;
	int lazy;
}tr[4*N];

void pushdown(Node &F, Node &L, Node &R){
	// 这里为了防止lazy默认标记和需要修改之间的冲突
	if(F.lazy == 0) return ;
	// 标签融合
	L.lazy = F.lazy;
	R.lazy = F.lazy;

	// 标签向下传递
	L.lc = L.rc = F.lazy;
	L.cnt = 1;
	R.lc = R.rc = F.lazy;
	R.cnt = 1;

	// 标签清空
	F.lazy = 0;
}

void pushdown(int u){
	pushdown(tr[u], tr[u<<1], tr[u<<1|1]);
}

void pushup(Node &F, Node L, Node R){
	F.cnt = L.cnt + R.cnt - (L.rc == R.lc);
	F.lc = L.lc;
	F.rc = R.rc;
}

void pushup(int u){
	pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void build(int u, int l, int r){
	if(l == r){
		// 这里注意要加个id,获得序列上l位置的点的是那个点,然后获得他的权值
		tr[u] = {l, r, 1, c[id[l]], c[id[l]], 0};
	}
	else{
		tr[u] = {l, r, 0, 0, 0, 0};
		int mid = l + r >> 1;
		build(u<<1, l, mid), build(u<<1|1, mid+1, r);
		pushup(u);
	}
}

void modify(int u, int l, int r, int v){
	
	if(tr[u].l >= l && tr[u].r <= r) {
		tr[u].cnt = 1;
		tr[u].lc = tr[u].rc = v;
		tr[u].lazy = v;
	}
	else{
		pushdown(u);
		int mid = tr[u].l + tr[u].r >> 1;
		if(l <= mid) modify(u<<1, l, r, v);
		if(r > mid) modify(u<<1|1, l, r, v);
		pushup(u);
	}
}


Node query(int u, int l, int r){
	
	if(tr[u].l >= l && tr[u].r <= r) return tr[u];
	else{
		pushdown(u);
		int mid = tr[u].l + tr[u].r >> 1;
		if(l > mid) return query(u<<1|1, l, r);
		else if(r <= mid) return query(u<<1, l, r);
		else{
			Node res;
			Node L = query(u << 1, l, r), R = query(u << 1 | 1, l, r);
			pushup(res, L, R);
			return res;
		}
	}
}

// 这部分是拼凑出整个链的结果或者整个子树的结果,这里需要注意顺序的问题
int get(int u, int v){
	Node L = {0, 0, 0, 0, 0, 0};
	Node R = {0, 0, 0, 0, 0, 0};
	int Lflag = 0, Rflag = 0;

	// 如果没跳到一条链上的话就一直跳
	while(top[u] != top[v]){
		// 链头深度深的向上跳,跳过的链合并到答案中间
		if(depth[top[u]] < depth[top[v]]){
			auto tmp = query(1, l[top[v]], l[v]);
			if(Lflag == 0){
				Lflag = 1;
				L = tmp;
			}
			else pushup(L, tmp, L);
			v = fa[top[v]];
		}
		else{
			auto tmp = query(1, l[top[u]], l[u]);
			if(Rflag == 0){
				Rflag = 1;
				R = tmp;
			}
			else pushup(R, tmp, R);
			u = fa[top[u]];
		}
	}
	// 根据u,v深度的比较确定
	// cout << L.cnt << ' ' << L.lc << ' ' << L.rc << endl;
	// cout << R.cnt << ' ' << R.lc << ' ' << R.rc << endl;
	if(depth[u] < depth[v]){
		auto tmp = query(1, l[u], l[v]);
		// cout << "_" << endl;
		// cout << L.cnt << ' ' << L.lc << ' ' << L.rc << endl;
		// cout << tmp.cnt << ' ' << tmp.lc << ' ' << tmp.rc << endl;
		pushup(L, tmp, L);
		// cout << L.cnt << ' ' << L.lc << ' ' << L.rc << endl;
	} 
	else{
		auto tmp = query(1, l[v], l[u]);
		pushup(R, tmp, R);	
	} 
	return L.cnt + R.cnt - (L.lc == R.lc);
}

void change(int u, int v, int x){
	
	// 如果没跳到一条链上的话就一直跳
	while(top[u] != top[v]){
		// 链头深度深的向上跳,跳过的链合并到答案中间
		if(depth[top[u]] < depth[top[v]]){
			modify(1, l[top[v]], l[v], x);
			v = fa[top[v]];
		}
		else{
			modify(1, l[top[u]], l[u], x);
			u = fa[top[u]];
		}
	}
	// 根据u,v深度的比较确定
	if(depth[u] < depth[v]){
		modify(1, l[u], l[v], x);
	} 
	else{
		modify(1, l[v], l[u], x);	
	} 
}

int main(){
	int n, m; scanf("%d %d", &n, &m);
	memset(h, -1, sizeof h);
	for(int i = 1; i <= n; i ++) scanf("%d", &c[i]);
	for(int i = 1; i <= n - 1; i ++){
		int x, y; scanf("%d %d", &x, &y);
		add(x, y); add(y, x);
	}
	dfs1(1, 0); dfs2(1, 1);
	build(1, 1, n);
	for(int i = 1; i <= m; i ++){
		char op[2];
		scanf("%s", op);
		if(op[0] == 'C'){
			int u, v, x; scanf("%d %d %d", &u, &v, &x);
			change(u, v, x);
		}
		else{
			int u, v; scanf("%d %d", &u, &v);
			auto res = get(u, v);
			printf("%d\n", res);
		}
	}
	return 0;
}
posted @ 2022-04-25 12:15  牛佳文  阅读(55)  评论(0编辑  收藏  举报