【模板】轻重链剖分
【模板】轻重链剖分
题目
讲解
总的来说,就是一个不难理解,码量爆炸的东西
推几篇题解,讲得不错
https://www.luogu.com.cn/blog/zengqinyi/solution-p3384
https://www.cnblogs.com/ivanovcraft/p/9019090.html
前置知识
线段树(必备),倍增LCA(可以帮助理解,不会应该也可以),链式前向星(存图,不会有人不会吧)
概念
重儿子(子树结点最多的儿子),重边(某个点到它的重儿子连成的边),重链(重边连成的链),轻儿子(除重儿子外的其它儿子),轻边,轻链
int n , m , rt , mod;//如题所述,为了习惯,p换位mod
int dat[nn];//输入的初始值
int dep[nn] , fa[nn] , siz[nn] , hvyson[nn];
//dep[i]:i结点深度,fa[i]:父节点,siz[i]:以i为根的子树大小,hvyson[i]:i结点的重儿子
int id[nn] , top[nn];
//id[i]:i结点的新标号,等下(第二轮dfs)会讲,top[i]:i所在重链的顶端结点
第一轮dfs
我们需要处理出这些:
int dep[nn] , fa[nn] , siz[nn] , hvyson[nn];
void dfs1(int x , int fa_ , int deep) {//x:当前结点,fa_:x的父结点,deep:当前深度
fa[x] = fa_;
hvyson[x] = 0;
int maxsiz = 0;
dep[x] = deep;
for(int i = head[x] ; i ; i = ed[i].nxt) {//链式前向星
if(ed[i].to == fa_)continue;
dfs1(ed[i].to , x , deep + 1);
siz[x] += siz[ed[i].to];
if(maxsiz < siz[ed[i].to])
maxsiz = siz[ed[i].to],
hvyson[x] = ed[i].to;
}
++siz[x];
}
第二轮dfs
我们需要处理出这些:
int id[nn] , top[nn];
void dfs2(int x , int chaintop) {//x:同上,chaintop:x所在的重链的顶端结点
static int cnt = 0;
id[x] = ++cnt;//按照dfs序赋值新编号,以便线段树操作
change(root , id[x] , id[x] , dat[x]);//线段树修改,上传x原权值到新编号上
top[x] = chaintop;
if(hvyson[x] == 0)return;//无儿子
dfs2(hvyson[x] , chaintop);//优先处理重儿子,下面讲
for(int i = head[x] ; i ; i = ed[i].nxt) {
int to = ed[i].to;
if(to == fa[x] || to == hvyson[x])continue;
dfs2(to , to);
}
}
核心-查询&修改
相信很多人都产生了疑问:为什么要给结点按照dfs序重新编号,并优先处理重儿子呢?
由于是dfs序,那么以任意一个结点为根,整颗子树的新编号都是连续的,这就是说,我们可以直接利用线段树修改或查询整颗子树的权值,这就把3,4操作的时间复杂度降到了log级别
由于我们优先处理重儿子,所以同一条重链上所有结点的编号都是连续的,这也为线段树操作提供了方便,以2操作为例(代码解释)
inline int path_query(int x , int y) {//询问x~y路径上的点权和
int res = 0;
while(top[x] != top[y]) {//x和y不在同一条重链上
if(dep[top[x]] < dep[top[y]])swap_(x , y);//强行让x所处的重链的顶端深度更大
res += query(root , id[top[x]] , id[x]);//答案累加上从x到top[x]的权值
res %= mod;
x = fa[top[x]];//x跳到所处重链顶端的父结点
}
if(dep[x] > dep[y])swap_(x , y);//此时x,y已经处于同一条重链上,强行让y结点深度更大
res += query(root , id[x] , id[y]); //答案累加上x~y的点权
return res % mod;
}
修改操作同理:
inline int path_add(int x , int y , int z) {
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]])swap_(x , y);
change(root , id[top[x]] , id[x] , z);
x = fa[top[x]];
}
if(dep[x] > dep[y])swap_(x , y);
change(root , id[x] , id[y] , z);
}
后记
树链剖分这个东西真的不难,就是很繁琐,如果想学就真的要沉下心来好好写代码,不要急
代码
#include <iostream>
#include <cstdio>
#define nn 100010
#define ll long long
using namespace std;
int read() {
int re = 0 , sig = 1;
char c;
do if((c = getchar()) == '-')sig = -1; while(c < '0' || c > '9');
while(c >= '0' && c <= '9')re = (re << 1) + (re << 3) + c - '0' , c = getchar();
return re * sig;
}
int n , m , rt , mod;
int dat[nn];
int dep[nn] , fa[nn] , siz[nn] , hvyson[nn];
int id[nn] , top[nn];
//SegmentTree-begin======================
struct SegmentTree{
int ls , rs , l , r;
ll dat , tag;
}tr[nn * 4];
int root;
int build(int l , int r) {
static int top = 1;
int p = top++;
tr[p].l = l , tr[p].r = r , tr[p].dat = 0 , tr[p].tag = 0;
if(l == r) tr[p].ls = tr[p].rs = 0;
else {
int mid = (l + r) / 2;
tr[p].ls = build(l , mid);
tr[p].rs = build(mid + 1 , r);
}
return p;
}
inline void spread(int p) {
if(tr[p].tag != 0) {
int ls = tr[p].ls , rs = tr[p].rs;
tr[ls].dat += (tr[ls].r - tr[ls].l + 1) * tr[p].tag;
tr[rs].dat += (tr[rs].r - tr[rs].l + 1) * tr[p].tag;
tr[ls].dat %= mod; tr[rs].dat %= mod;
tr[ls].tag += tr[p].tag;
tr[rs].tag += tr[p].tag;
tr[ls].tag %= mod; tr[rs].dat %= mod;
tr[p].tag = 0;
}
return;
}
void change(int p , int l , int r , ll dat) {
if(l <= tr[p].l && r >= tr[p].r) {
tr[p].dat += (tr[p].r - tr[p].l + 1) * dat;
tr[p].tag += dat;
return;
}
spread(p);
if(l > tr[p].r || r < tr[p].l)return;
change(tr[p].ls , l , r , dat);
change(tr[p].rs , l , r , dat);
tr[p].dat = (tr[tr[p].ls].dat + tr[tr[p].rs].dat) % mod;
}
int query(int p , int l , int r) {
if(l <= tr[p].l && r >= tr[p].r)return tr[p].dat;
if(l > tr[p].r || r < tr[p].l) return 0;
spread(p);
return query(tr[p].ls , l , r) + query(tr[p].rs , l , r);
}
//SegmentTree-end========================
//=======================================
struct ednode{
int nxt , to;
}ed[nn * 2];
int head[nn];
inline void addedge(int u , int v) {
static int top = 1;
ed[top].to = v , ed[top].nxt = head[u] , head[u] = top;
++top;
}
//=======================================
void dfs1(int x , int fa_ , int deep) {
fa[x] = fa_;
hvyson[x] = 0;
int maxsiz = 0;
dep[x] = deep;
for(int i = head[x] ; i ; i = ed[i].nxt) {
if(ed[i].to == fa_)continue;
dfs1(ed[i].to , x , deep + 1);
siz[x] += siz[ed[i].to];
if(maxsiz < siz[ed[i].to])
maxsiz = siz[ed[i].to],
hvyson[x] = ed[i].to;
}
++siz[x];
}
void dfs2(int x , int linktop) {
// cout << x <<' ';
static int cnt = 0;
id[x] = ++cnt;
change(root , id[x] , id[x] , dat[x]);
top[x] = linktop;
if(hvyson[x] == 0)return;
dfs2(hvyson[x] , linktop);
for(int i = head[x] ; i ; i = ed[i].nxt) {
int to = ed[i].to;
if(to == fa[x] || to == hvyson[x])continue;
dfs2(to , to);
}
}
//========================================
inline void tree_add(int x , int z) {
change(root , id[x] , id[x] + siz[x] - 1 , z);
}
inline int tree_query(int x) {
return query(root , id[x] , id[x] + siz[x] - 1);
}
inline void swap_(int &a , int &b){int tmp=a;a=b;b=tmp;}
inline int path_add(int x , int y , int z) {
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]])swap_(x , y);
change(root , id[top[x]] , id[x] , z);
x = fa[top[x]];
}
if(dep[x] > dep[y])swap_(x , y);
change(root , id[x] , id[y] , z);
}
inline int path_query(int x , int y) {
int res = 0;
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]])swap_(x , y);
res += query(root , id[top[x]] , id[x]);
res %= mod;
x = fa[top[x]];
}
if(dep[x] > dep[y])swap_(x , y);
res += query(root , id[x] , id[y]);
return res % mod;
}
int main() {
n = read() , m = read() , rt = read() , mod = read();
for(int i = 1 ; i <= n ; i++)
dat[i] = read();
for(int i = 1 ; i < n ; i++) {
int u = read() , v = read();
addedge(u , v);
addedge(v , u);
}
root = build(1 , n);
dfs1(rt , 0 , 0);
dfs2(rt , rt);
while(m--) {
int op = read();
int x , y , z;
switch(op) {
case 1 :
x = read() , y = read() , z = read();
path_add(x , y , z);
break;
case 2 :
x = read() , y = read();
printf("%d\n" , path_query(x , y) % mod);
break;
case 3 :
x = read() , z = read();
tree_add(x , z);
break;
case 4 :
x = read();
printf("%d\n" , tree_query(x) % mod);
break;
}
}
return 0;
}