树链剖分详解
1,将树从x到y结点最短路径上所有节点的值都加上z
2,求树从x到y结点最短路径上所有节点的值之和
3,将x为根节点的子树内所有节点的值加上z
4,求x为根节点的子树内所有节点值之和
(以下都基于这个题目展开讲解)
如果没有操作3和4,这题可以用树上差分和lca解决,也是模板题
树上差分指路:[https://www.cnblogs.com/gzh-red/p/11185914.html]
求lca指路:[https://www.cnblogs.com/lsdsjy/p/4071041.html]
好,进入正题
树链剖分,顾名思义,就是通过轻重边的划分将树分割成多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、BST、SPLAY、线段树等)来维护每一条链。(摘自百度百科)
几个基本概念:
- 重儿子:当前节点所有子结点中子树节点数(size)最多的结点,为当前节点的重儿子
- 轻儿子:除了重节点的儿子
- 重边:连接父亲节点和重儿子的边
- 轻边:连接父亲节点和轻儿子的边
- 重链:多条重边连接成的路径(也就是说有多条重边在一条连续的路径上)
- 轻链:多条轻边组成的路径
比如这样一棵树
①的子节点中,④的size最大,所以④是①的重儿子,连接①④的边为重边。
最终dfs一次这张图应该有这些信息
![ZqqA1S](C:\Users\wawa\Pictures\Camera Roll\ZqqA1S.jpg)
红圈表示重儿子,加粗边为重边
这个过程的代码实现
void dfs1(int u, int fa)
{
deg[u] = deg[fa] + 1;//deg为节点深度
f[u] = fa; //f[u]为u的父亲节点
sz[u] = 1; //sz[u]为u所在子树大小(包括自己)
for (int i = head[u]; i; i = e[i].nex)
{
int ne = e[i].to;
if (ne == fa) continue;
dfs1(ne, u);
sz[u] += sz[ne];
if (sz[ne] > sz[son[u]])
son[u] = ne;
}
}
注意的是:如果u为叶节点,它是没有重儿子的。如果u有多个子结点子树大小相等,随便谁当重儿子都行。
第二遍dfs要将重边连成重链,保证一条重链上的节点dfs序连续,以便用用数据结构维护(比如线段树肯定是对连续的区间进行维护)同时要处理出重链的链头,也就是假如一条重链深度最小的点为①,然后有②③④,则top[1],top[2],top[3]和top[4]均为1。
void dfs2(int u, int top_fa)
{
xu[u] = ++inde;
v[inde] = w[u];
top[u] = top_fa;
if (!son[u]) return ;//如果为叶节点,返回
dfs2(son[u], top_fa);//优先走重边
for (int i = head[u]; i; i = e[i].nex)
{
int ne = e[i].to;
if (ne == f[u] || ne == son[u]) continue;
dfs2(ne, ne);
}
}
剖分的工作就做完了,接下来就是用数据结构维护
因为一条重链上的节点dfs序连续,那么路径和以及路径修改就可以转化成区间求和以及区间修改,那么可以用线段树维护。
以查询为例 ,还是用求lca的方式,用top可以直接跳到一条重链的起始位置,让top[x]和top[y]中比较深的节点来跳,直接跳到对应top[]的父节点,可以保证两个节点一定不会擦肩而过,也能保证最好两个点一定能跳到同一条重链上。
ll query(int rt, int l, int r, int x, int y)
{
if (x <= l && r <= y) return sum[rt];
pushdown(rt, l, r);
int mid = (l + r) >> 1;
ll res = 0;
if (x <= mid)
res = (res + query(lson, l, mid, x, y)) % mod;
if (mid < y)
res = (res + query(rson, mid+1, r, x, y)) % mod;
return res;
}
ll qRange(int x, int y)
{
ll ans=0;
while (top[x] != top[y])
{
if (deg[top[x]] < deg[top[y]])
swap(x,y);
ans = (ans + query(1, 1, n, xu[top[x]], xu[x])) % mod;
x = f[top[x]];
}
if (deg[x] > deg[y])
swap(x, y);
ans = (ans + query(1, 1, n, xu[x], xu[y])) % mod;
return ans;
}
可以用上面的图跳几对点手动模拟一下,加深理解。
完整代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<sstream>
#include<algorithm>
#include<vector>
#include<queue>
#include<map>
#include<cstdlib>
#include<cmath>
using namespace std;
#define re register int
#define ull unsigned long long
#define ll long long
#define inf 0x3f3f3f3f
#define N 1000010
#define lson rt<<1
#define rson rt<<1|1
#define lowbit(x) (x)&(-(x))
void FRE(){freopen("subsets.in","r",stdin);freopen("subsets.out","w",stdout);}
inline ll read()
{
ll x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ll)(ch-'0');ch=getchar();}
return x*f;
}
int cnt, head[N], n, m, root, xu[N], deg[N], top[N];
int f[N],sz[N],son[N],inde;
ll mod,sum[N],v[N],w[N],tag[N];
struct node
{
int to,nex;
}e[N*2];
void add(int u, int v)
{
e[++cnt].nex = head[u];
head[u] = cnt;
e[cnt].to=v;
}
void dfs1(int u, int fa)
{
deg[u] = deg[fa] + 1;//deg为节点深度
f[u] = fa; //f[u]为u的父亲节点
sz[u] = 1; //sz[u]为u所在子树大小(包括自己)
for (int i = head[u]; i; i = e[i].nex)
{
int ne = e[i].to;
if (ne == fa) continue;
dfs1(ne, u);
sz[u] += sz[ne];
if (sz[ne] > sz[son[u]])
son[u] = ne;
}
}
void dfs2(int u, int top_fa)
{
xu[u] = ++inde;
v[inde] = w[u];
top[u] = top_fa;
if (!son[u]) return ;//如果为叶节点,返回
dfs2(son[u], top_fa);//优先走重边
for (int i = head[u]; i; i = e[i].nex)
{
int ne = e[i].to;
if (ne == f[u] || ne == son[u]) continue;
dfs2(ne, ne);
}
}
void pushup(int rt)
{
sum[rt] = (sum[lson] + sum[rson])%mod;
}
void pushdown(int rt, int l, int r)
{
if (!tag[rt]) return;
int mid = (l + r) >> 1;
tag[lson] = (tag[lson] + tag[rt]) % mod;
tag[rson] = (tag[rson] + tag[rt]) %mod;
sum[lson] = (sum[lson] + (tag[rt] * (ll)(mid - l + 1)) % mod) % mod;
sum[rson] = (sum[rson] + (tag[rt] * (ll)(r-mid)) % mod) % mod;
tag[rt] = 0;
}
void build(int rt, int l, int r)
{
if (l == r)
{
sum[rt] = v[l];
return ;
}
int mid = (l + r)>>1;
build(lson, l, mid);
build(rson, mid+1, r);
pushup(rt);
}
void update(int rt, int l, int r, int x, int y, int val)
{
if (x <= l && r<=y)
{
sum[rt] = (sum[rt] + (r-l+1) * val) % mod;
tag[rt] = (tag[rt] + val) % mod;
return ;
}
pushdown(rt, l, r);
int mid = (l + r) >> 1;
if (x <= mid)
update(lson, l, mid, x, y, val);
if (mid < y)
update(rson, mid+1, r, x, y, val);
pushup(rt);
}
void upRange(int x, int y, int val)
{
while (top[x] != top[y])
{
if (deg[top[x]] < deg[top[y]])
swap(x,y);
update(1, 1, n, xu[top[x]], xu[x], val);
x = f[top[x]];
}
if (deg[x] >deg[y])
swap(x,y);
update(1, 1, n, xu[x], xu[y], val);
}
ll query(int rt, int l, int r, int x, int y)
{
if (x <= l && r <= y) return sum[rt];
pushdown(rt, l, r);
int mid = (l + r) >> 1;
ll res = 0;
if (x <= mid)
res = (res + query(lson, l, mid, x, y)) % mod;
if (mid < y)
res = (res + query(rson, mid+1, r, x, y)) % mod;
return res;
}
ll qRange(int x, int y)
{
ll ans=0;
while (top[x] != top[y])
{
if (deg[top[x]] < deg[top[y]])
swap(x,y);
ans = (ans + query(1, 1, n, xu[top[x]], xu[x])) % mod;
x = f[top[x]];
}
if (deg[x] > deg[y])
swap(x, y);
ans = (ans + query(1, 1, n, xu[x], xu[y])) % mod;
return ans;
}
int main()
{
n = read(); m = read();
root = read(); mod = read();
for (int i = 1; i <= n; i++)
w[i] = read();
for (int i = 1; i < n; i++)
{
int x = read(), y = read();
add(x, y), add(y, x);
}
dfs1(root, 0); dfs2(root, root);
build(1, 1, n);
while (m--)
{
int ty = read();
if (ty == 1)
{
int x = read(), y = read();
ll z = read() % mod;
upRange(x, y, z);
}
if (ty == 2)
{
int x = read(), y = read();
printf("%lld\n", qRange(x, y));
}
if (ty == 3)
{
int x = read();
ll z = read();
update(1, 1, n, xu[x], xu[x] + sz[x] - 1, z);
}
if (ty == 4)
{
int x = read();
printf("%lld\n", query(1, 1, n, xu[x], xu[x] + sz[x] - 1));
}
}
return 0;
}
这个题目就可以理解成区间修改,每次修改的是根节点到x路径
完整代码
#include <bits/stdc++.h>
#define inf 1e18
#define ll long long
#define N 1000010
#define lson rt << 1
#define rson rt << 1 | 1
#define mo 998244353
using namespace std;
typedef pair<int, int> P;
inline ll read()
{
ll x = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9')
{
if (ch == '-')
f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
x = x * 10 + ch - '0', ch = getchar();
return x * f;
}
int f[N], sz[N], son[N], head[N], cnt, deg[N], xu[N], len, tag[N], sum[N], top[N], n;
char ty[20];
struct node
{
int nex, to;
} e[N * 2];
void add(int u, int v)
{
e[++cnt].nex = head[u];
head[u] = cnt;
e[cnt].to = v;
}
void dfs1(int u, int fa)
{
f[u] = fa;
sz[u] = 1;
deg[u] = deg[fa] + 1;
for (int i = head[u]; i; i = e[i].nex)
{
int ne = e[i].to;
if (ne == fa)
continue;
dfs1(ne, u);
sz[u] += sz[ne];
if (sz[ne] > sz[son[u]])
son[u] = ne;
}
}
void dfs2(int u, int fa)
{
top[u] = fa;
xu[u] = ++len;
if (!son[u])
return;
dfs2(son[u], fa);
for (int i = head[u]; i; i = e[i].nex)
{
int ne = e[i].to;
if (ne == f[u] || ne == son[u])
continue;
dfs2(ne, ne);
}
}
void pushdown(int rt, int l, int r)
{
if (tag[rt] == -1)
return;
int mid = (l + r) >> 1;
tag[lson] = tag[rt], tag[rson] = tag[rt];
sum[lson] = (mid - l + 1) * tag[rt], sum[rson] = (r - mid) * tag[rt];
tag[rt] = -1;
}
void pushup(int rt) { sum[rt] = sum[lson] + sum[rson]; }
void modify(int rt, int l, int r, int x, int y, int val)
{
if (x <= l && r <= y)
{
sum[rt] = (r - l + 1) * val;
tag[rt] = val;
return;
}
pushdown(rt, l, r);
int mid = (l + r) >> 1;
if (x <= mid)
modify(lson, l, mid, x, y, val);
if (y > mid)
modify(rson, mid + 1, r, x, y, val);
pushup(rt);
}
void upRange(int x, int y)
{
while (top[x] != top[y])
{
if (deg[top[x]] < deg[top[y]])
swap(x, y);
modify(1, 1, n, xu[top[x]], xu[x], 1);
x = f[top[x]];
}
if (deg[x] > deg[y])
swap(x, y);
modify(1, 1, n, xu[x], xu[y], 1);
}
int query(int rt, int l, int r, int x, int y)
{
if (x <= l && r <= y)
return sum[rt];
pushdown(rt, l, r);
int mid = (l + r) >> 1, res = 0;
if (x <= mid)
res += query(lson, l, mid, x, y);
if (y > mid)
res += query(rson, mid + 1, r, x, y);
return res;
}
int main()
{
n = read();
for (int i = 2; i <= n; i++)
{
int x = read();
add(x + 1, i), add(i, x + 1);
}
dfs1(1, 0);
dfs2(1, 1);
int Q = read();
while (Q--)
{
scanf("%s", ty);
int x = read() + 1;
int tmp = sum[1];
if (ty[0] == 'i')
{
upRange(1, x);
printf("%d\n", sum[1] - tmp);
}
else
{
modify(1, 1, n, xu[x], xu[x] + sz[x] - 1, 0);
printf("%d\n", tmp - sum[1]);
}
}
return 0;
}