重链剖分

代码思路

主体部分:

初始化,剖分链,求LCA
(也就是dfs1,dfs2,LCA三个函数)

辅助部分:

struct Point{ // 节点信息多的时候会习惯开个结构体来存 
	int dep, siz, son, top, fath;
	// 点的深度 子树大小 重儿子 所在重链的链头 父亲节点 
	// 没有重儿子则son=0
	int id, l, r; // 求lca不会用到,但实际常用的东西
	// dfs序 子树dfs编号左右界 
	void getnum1(Point fa, int f){ // dfs1的初始化 
		dep = fa.dep+1; siz = 1; fath = f;
		// 求深度 初始化字数大小 记录父亲节点 
	}
	void getnum2(int tp){ // dfs2的初始化
		id = l = r = ++tot; top = tp;
		// 初始化dfs序 记录链头 
	}
}p[N];

个人习惯,在需要对节点统计很多信息的时候,开一个结构体储存。

简化部分:

暂无
(树链剖分代码其实挺短的,而且足够优美)
(我收回这句话,下面的代码敲了1h)

注意事项:

搜重儿子前要判断u是不是叶子结点
求LCA时比较的是链头深度


原理

对树进行分块(不过是分成链)

指路详细说明


代码注释

洛谷P3379为例

#include <bits/stdc++.h>
using namespace std;

const int N = 5e5+5; // 数据范围(节点个数) 
int n, m, s, head[N]; // 一些常规的东西 
int tot = 0; // 求lca不会用到,但实际常用的dfs序计数器 
struct Point{ // 节点信息多的时候会习惯开个结构体来存 
	int dep, siz, son, top, fath;
	// 点的深度 子树大小 重儿子 所在重链的链头 父亲节点 
	// 没有重儿子则son=0
	int id, l, r; // 求lca不会用到,但实际常用的东西
	// dfs序 子树dfs编号左右界 
	void getnum1(Point fa, int f){ // dfs1的初始化 
		dep = fa.dep+1; siz = 1; fath = f;
		// 求深度 初始化字数大小 记录父亲节点 
	}
	void getnum2(int tp){ // dfs2的初始化
		id = l = r = ++tot; top = tp;
		// 初始化dfs序 记录链头 
	}
}p[N];
struct Edge{ // 链表存边 
	int u, v;
}e[N*2]; // 双向边

void dfs1(int u, int fa){ // 第一遍深搜:求节点的各种信息 
	p[u].getnum1(p[fa], fa); // 初始化 
	int maxsiz = 0; // 最大的子·子树大小
	// ↑用于辅助求重儿子 
	for(int i = head[u]; i; i = e[i].v){ // 枚举所有边 
		int v = e[i].u; // 边的端点 
		if(v == fa) continue; // 辈分不能乱JPG 
		dfs1(v, u); // 继续往下深搜
		p[u].siz += p[v].siz; // 统计子树大小 
		if(p[v].siz > maxsiz){ // 如果这个儿子更重 
			maxsiz = p[v].siz; // 更新辅助值 
			p[u].son = v; // 更新重儿子 
		}
	}
}

void dfs2(int u, int fa, int top){ // 第二遍深搜:统计重链
	// top是u所在重链的链头 
	p[u].getnum2(top); // 初始化 
	if(!p[u].son) return ;
	// ↑注意要判断u是不是叶子结点
    p[u].l = min(p[u].l, p[p[u].son].l); // 更新子树的左边界
    p[u].r = max(p[u].r, p[p[u].son].r); // 更新子树的右边界
    dfs2(p[u].son, u, top); // 先搜重儿子以保证dfs序连续 
	for(int i = head[u]; i; i = e[i].v){ // 枚举所有边 
		int v = e[i].u; // 边的端点 
		if(v == fa || v == p[u].son) continue;
		// 不能再搜父亲和重儿子
		dfs2(v, u, v); // 轻儿子所在重链的链头是它自己 
        p[u].l = min(p[u].l, p[v].l); // 更新子树的左边界
        p[u].r = max(p[u].r, p[v].r); // 更新子树的右边界
	}
} 

int LCA(int u, int v){ // 求u和v的最近公共祖先 
	while(p[u].top != p[v].top){
		// 在同一重链则退出 
		if(p[p[u].top].dep < p[p[v].top].dep) swap(u, v); // 调整深度 
		// ↑注意这里比较的是链头的深度 
		u = p[p[u].top].fath; // 沿着轻边向上跳
	}
	// 这时候已经在同一个重链上了
	// 那么深度小的节点就是LCA
	if(p[u].dep < p[v].dep) swap(u, v);
	return v;  // 返回LCA 
}

int main(){ // 喜闻乐见的主函数 
	// ↓输入 
	scanf("%d%d%d", &n, &m, &s);
	for(int i = 1; i < n; i++){
		int u, v; scanf("%d%d", &u, &v);
		e[i] = (Edge){v, head[u]}; head[u] = i;
		e[i+n] = (Edge){u, head[v]}; head[v] = i+n;
	}
	// ↑输入 

	dfs1(s, 0); dfs2(s, 0, s); // 两遍dfs剖出重链 

	for(int i = 1; i <= m; i++){ // m次询问 
		int u, v; scanf("%d%d", &u, &v);
		printf("%d\n", LCA(u, v)); // 输出答案 
	}
	return 0;
} 

洛谷P3384

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ls (u<<1)
#define rs ((u<<1)|1)
#define mid ((l+r)>>1)

const int N = 1e5+5;
int n, m, r, p, head[N];
struct Edge{
	int u, v;
}e[2*N];
struct Point{
	int dep, siz, son, top, fath;
	int val;
	int id, l, r;
}pnt[N];

int t[N*4], tag[N*4], a[N];
void build(int u, int l, int r){
	if(l == r){
		t[u] = a[l]; return;
	} build(ls, l, mid);
	build(rs, mid+1, r);
	t[u] = (t[ls]+t[rs])%p;
}

void add_tag(int u, int l, int r, int d){
    tag[u] += d; t[u] += d*(r-l+1);
//	printf("%d %d %d -> %d\n", u, l, r, t[u]);
}

void push_down(int u, int l, int r){
    if(!tag[u]) return ;
    add_tag(ls, l, mid, tag[u]);
    add_tag(rs, mid+1, r, tag[u]);
    tag[u] = 0;
}


void update(int u, int l, int r, int L, int R, int d){
	if(l >= L && r <= R){
		add_tag(u, l, r, d);
		return ;
	} push_down(u, l, r);
	if(L <= mid) update(ls, l, mid, L, R, d);
	if(R > mid) update(rs, mid+1, r, L, R, d);
	t[u] = (t[ls]+t[rs])%p;
}

int query(int u, int l, int r, int L, int R){
	if(l >= L && r <= R) return t[u];
	push_down(u, l, r); int ret = 0;
	if(L <= mid) ret += query(ls, l, mid, L, R);
	if(R > mid) ret += query(rs, mid+1, r, L, R);
	return ret % p;
}

void dfs1(int u, int fa){
	pnt[u].dep = pnt[fa].dep+1;
	pnt[u].siz = 1; pnt[u].fath = fa;
	int maxsiz = 0;
	for(int i = head[u]; i; i = e[i].v){
		int v = e[i].u;
		if(v == fa) continue;
		dfs1(v, u); pnt[u].siz += pnt[v].siz;
		if(maxsiz < pnt[v].siz){
			maxsiz = pnt[v].siz;
			pnt[u].son = v;
		}
	}
}

int tot = 0;
void dfs2(int u, int fa, int top){
	pnt[u].top = top;
	pnt[u].id = pnt[u].l = pnt[u].r = ++tot;
	a[tot] = pnt[u].val;
	if(!pnt[u].son) return ;
	dfs2(pnt[u].son, u, top);
	pnt[u].l = min(pnt[u].l, pnt[pnt[u].son].l);
	pnt[u].r = max(pnt[u].r, pnt[pnt[u].son].r);
	for(int i = head[u]; i; i = e[i].v){
		int v = e[i].u;
		if(v == fa || v == pnt[u].son) continue;
		dfs2(v, u, v);
		pnt[u].l = min(pnt[u].l, pnt[v].l);
		pnt[u].r = max(pnt[u].r, pnt[v].r);
	}
}

int LCA(int u, int v, int z){
	int ret = 0;
	while(pnt[u].top != pnt[v].top){
//		printf("<%d %d>\n", u, v);
		if(pnt[pnt[u].top].dep < pnt[pnt[v].top].dep)
			swap(u, v);
		if(z>0) update(1, 1, n, pnt[pnt[u].top].id, pnt[u].id, z);
		else ret = (ret+query(1, 1, n, pnt[pnt[u].top].id, pnt[u].id))%p;
		u = pnt[pnt[u].top].fath;
	}
	if(pnt[u].id > pnt[v].id) swap(u, v);
	if(z>0) update(1, 1, n, pnt[u].id, pnt[v].id, z);
	else ret = (ret+query(1, 1, n, pnt[u].id, pnt[v].id))%p;
	return ret;
//	return u;
}

signed main(){
	scanf("%lld%lld%lld%lld", &n, &m, &r, &p);
	for(int i = 1; i <= n; i++)
		scanf("%lld", &pnt[i].val);
	for(int i = 1; i < n; i++){
		int u, v; scanf("%lld%lld", &u, &v);
		e[i] = (Edge){v, head[u]}; head[u] = i;
		e[i+n] = (Edge){u, head[v]}; head[v] = i+n;
	}

	dfs1(r, 0); dfs2(r, 0, r);
	build(1, 1, n);

	for(int i = 1; i <= m; i++){
		int opt; scanf("%lld", &opt);
		if(opt == 1){
			int x, y, z;
			scanf("%lld%lld%lld", &x, &y, &z);
			LCA(x, y, z);
		}
		if(opt == 2){
			int x, y; scanf("%lld%lld", &x, &y);
			printf("%lld\n", LCA(x, y, 0));
		}
		if(opt == 3){
			int x, z; scanf("%lld%lld", &x, &z);
//			printf("%d-%d\n", pnt[x].l, pnt[x].r);
			update(1, 1, n, pnt[x].l, pnt[x].r, z);
		}
		if(opt == 4){
			int x; scanf("%lld", &x);
//			printf("%d-%d\n", pnt[x].l, pnt[x].r);
			printf("%lld\n", query(1, 1, n, pnt[x].l, pnt[x].r));
		}
	}
	return 0;
}
posted @ 2023-10-12 21:54  _kilo-meteor  阅读(8)  评论(0编辑  收藏  举报