树链剖分--P3384 【模板】轻重链剖分
经过了一系列的前置知识,终于学会了树链剖分!!
重链剖分的思想:
重链剖分可以将树上的任意一条路径划分成不超过\(O(logn)\)条连续的链,每条链上的点深度互不相同(即是自底向上的一条链,链上所有点的\(LCA\)$为链的一个端点)。
重链剖分还能保证划分出的每条链上的节点\(DFS\)序连续,因此可以方便地用一些维护序列的数据结构(如线段树)来维护树上路径的信息。
如:
-
修改 树上两点之间的路径上 所有点的值。
-
查询 树上两点之间的路径上 节点权值的 和/极值/其它(在序列上可以用数据结构维护,便于合并的信息)。
我们给出一些定义:
-
重子节点 :表示其子节点中子树最大的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。
-
轻子节点 :表示剩余的所有子结点。
-
重边 :从这个结点到重子节点的边为重边 。
-
轻边 :到其他轻子节点的边为 轻边 。
-
重链 :若干条首尾衔接的重边构成 重链 。
实现:
树剖的实现分两个\(DFS\)的过程
第一个\(DFS\)记录每个结点的父节点、深度、子树大小、重子节点。
-
\(siz_x\),表示子树\(x\)的大小
-
\(dep_x\),表示点\(x\)的深度
-
\(fa_x\),表示点\(x\)的父亲
-
\(son_x\),表示点\(x\)的重儿子
void dfs1(int x){
siz[x] = 1;dep[x] = dep[fa[x]]+1;
for (int i = head[x];i;i = ed[i].nxt){
int to = ed[i].to;
if (to == fa[x]) continue;
fa[to] = x;
dfs1(to);
if (siz[to] > siz[son[x]]) son[x] = to;
siz[x] += siz[to];
}
}
第二个\(DFS\)记录所在链的链顶(\(root\),应初始化为结点本身)、重边优先遍历时的\(DFS\)序\((dfn)\)、\(DFS\)序对应的节点编号\((pos)\)。
-
\(str_x\),表示\(x\)所在重链的链顶
-
\(dfn_x\),表示点\(x\)的\(dfs\)序
-
\(pos_x\),表示\(dfs\)序为\(x\)的点
void dfs2(int x,int root){
str[x] = root;
dfn[x] = ++cnt;pos[cnt] = x;
if (son[x]) dfs2(son[x],root);
for (int i = head[x];i;i = ed[i].nxt){
int to = ed[i].to;
if (to == fa[x]||to == son[x]) continue;
dfs2(to,to);
}
}
路径上修改和查询:
链上的\(DFS\)序是连续的,可以使用线段树、树状数组维护,每次选择深度较大的链往上跳,直到两点在同一条链上。
void fix1(){
int a = read(),b = read(),x = read();
while (str[a] != str[b]){
if (dep[str[a]] < dep[str[b]]) swap(a,b);
modify(1,1,n,dfn[str[a]],dfn[a],x);
a = fa[str[a]];
}
if (dep[a] > dep[b]) swap(a,b);
modify(1,1,n,dfn[a],dfn[b],x);
}
void fix2(){
int a = read(),b = read();
int res = 0;
while (str[a] != str[b]){
if (dep[str[a]] < dep[str[b]]) swap(a,b);
(res += query(1,1,n,dfn[str[a]],dfn[a]))%=mod;
a = fa[str[a]];
}
if (dep[a] > dep[b]) swap(a,b);
(res += query(1,1,n,dfn[a],dfn[b]))%=mod;
printf("%lld\n",res);
}
子树修改和查询:
在\(DFS\)搜索的时候,子树中的结点的\(DFS\)序是连续的,每一个结点到子树末端的结点的\(dfs\)序就为他本身的\(dfs\)序+子树大小-1。
void fix3(){
int a = read(),x = read();
modify(1,1,n,dfn[a],dfn[a]+siz[a]-1,x);
}
void fix4(){
int a = read();
printf("%lld\n",query(1,1,n,dfn[a],dfn[a]+siz[a]-1));
}
例题的完整代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#define int long long
using namespace std;
int read(){
int x = 1,a = 0;char ch = getchar();
while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();}
while (ch >= '0'&&ch <= '9'){a = a*10+ch-'0';ch = getchar();}
return x*a;
}
const int maxn = 1e5+10;
int n,m,r,mod,a[maxn];
struct node{
int to,nxt;
}ed[maxn*2];
int head[maxn*2],tot;
void add(int u,int to){
ed[++tot].to = to;
ed[tot].nxt = head[u];
head[u] = tot;
}
int fa[maxn],siz[maxn],son[maxn],dep[maxn];
void dfs1(int x){
siz[x] = 1;dep[x] = dep[fa[x]]+1;
for (int i = head[x];i;i = ed[i].nxt){
int to = ed[i].to;
if (to == fa[x]) continue;
fa[to] = x;
dfs1(to);
if (siz[to] > siz[son[x]]) son[x] = to;
siz[x] += siz[to];
}
}
int cnt,str[maxn],dfn[maxn],pos[maxn];
void dfs2(int x,int root){
str[x] = root;
dfn[x] = ++cnt;pos[cnt] = x;
if (son[x]) dfs2(son[x],root);
for (int i = head[x];i;i = ed[i].nxt){
int to = ed[i].to;
if (to == fa[x]||to == son[x]) continue;
dfs2(to,to);
}
}
int tree[maxn*4],lazy[maxn*4];
int ls(int x){return x<<1;}
int rs(int x){return x<<1|1;}
void pushup(int x){
tree[x] = tree[ls(x)] + tree[rs(x)];
}
void build(int x,int l,int r){
if (l == r){tree[x] = a[pos[l]];return;}
int mid = (l+r)>>1;
build(ls(x),l,mid);build(rs(x),mid+1,r);
pushup(x);
}
void tag(int x,int l,int r,int k){
lazy[x] += k;
tree[x] += (r-l+1)*k;
}
void pushdown(int x,int l,int r){
int mid = (l+r)>>1;
tag(ls(x),l,mid,lazy[x]);
tag(rs(x),mid+1,r,lazy[x]);
lazy[x] = 0;
}
void modify(int x,int l,int r,int nl,int nr,int k){
if (nl <= l&&r <= nr){tag(x,l,r,k);return;}
int mid = (l+r)>>1;
pushdown(x,l,r);
if (nl <= mid) modify(ls(x),l,mid,nl,nr,k);
if (nr > mid) modify(rs(x),mid+1,r,nl,nr,k);
pushup(x);
}
int query(int x,int l,int r,int nl,int nr){
int res = 0;
if (nl <= l&&r <= nr) return tree[x];
int mid = (l+r)>>1;
pushdown(x,l,r);
if (nl <= mid) (res+=query(ls(x),l,mid,nl,nr))%=mod;
if (nr > mid) (res+=query(rs(x),mid+1,r,nl,nr))%=mod;
return res;
}
void fix1(){
int a = read(),b = read(),x = read();
while (str[a] != str[b]){
if (dep[str[a]] < dep[str[b]]) swap(a,b);
modify(1,1,n,dfn[str[a]],dfn[a],x);
a = fa[str[a]];
}
if (dep[a] > dep[b]) swap(a,b);
modify(1,1,n,dfn[a],dfn[b],x);
}
void fix2(){
int a = read(),b = read();
int res = 0;
while (str[a] != str[b]){
if (dep[str[a]] < dep[str[b]]) swap(a,b);
(res += query(1,1,n,dfn[str[a]],dfn[a]))%=mod;
a = fa[str[a]];
}
if (dep[a] > dep[b]) swap(a,b);
(res += query(1,1,n,dfn[a],dfn[b]))%=mod;
printf("%lld\n",res);
}
void fix3(){
int a = read(),x = read();
modify(1,1,n,dfn[a],dfn[a]+siz[a]-1,x);
}
void fix4(){
int a = read();
printf("%lld\n",query(1,1,n,dfn[a],dfn[a]+siz[a]-1));
}
signed main(){
n = read(),m = read(),r = read(),mod = read();
for (int i = 1;i <= n;i++) a[i] = read();
for (int i = 1;i <= n-1;i++){
int x = read(),y = read();
add(x,y),add(y,x);
}
dfs1(r);dfs2(r,r);
build(1,1,n);
for (int i = 1;i <= m;i++){
int op = read();
if (op == 1) fix1();
if (op == 2) fix2();
if (op == 3) fix3();
if (op == 4) fix4();
}
return 0;
}