树链剖分详解

树链剖分模板题

1,将树从x到y结点最短路径上所有节点的值都加上z

2,求树从x到y结点最短路径上所有节点的值之和

3,将x为根节点的子树内所有节点的值加上z

4,求x为根节点的子树内所有节点值之和

(以下都基于这个题目展开讲解)

如果没有操作3和4,这题可以用树上差分和lca解决,也是模板题

树上差分指路:[https://www.cnblogs.com/gzh-red/p/11185914.html]

求lca指路:[https://www.cnblogs.com/lsdsjy/p/4071041.html]


好,进入正题

树链剖分,顾名思义,就是通过轻重边的划分将树分割成多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、BST、SPLAY、线段树等)来维护每一条链。(摘自百度百科)

几个基本概念:
  • 重儿子:当前节点所有子结点中子树节点数(size)最多的结点,为当前节点的重儿子
  • 轻儿子:除了重节点的儿子
  • 重边:连接父亲节点和重儿子的边
  • 轻边:连接父亲节点和轻儿子的边
  • 重链:多条重边连接成的路径(也就是说有多条重边在一条连续的路径上)
  • 轻链:多条轻边组成的路径

比如这样一棵树

这样一棵树

①的子节点中,④的size最大,所以④是①的重儿子,连接①④的边为重边。

最终dfs一次这张图应该有这些信息

![ZqqA1S](C:\Users\wawa\Pictures\Camera Roll\ZqqA1S.jpg)

红圈表示重儿子,加粗边为重边

这个过程的代码实现

void dfs1(int u, int fa)
{
    deg[u] = deg[fa] + 1;//deg为节点深度
    f[u] = fa;           //f[u]为u的父亲节点
    sz[u] = 1;           //sz[u]为u所在子树大小(包括自己)
    for (int i = head[u]; i; i = e[i].nex)
    {
        int ne = e[i].to;
        if (ne == fa) continue;
        dfs1(ne, u);
        sz[u] += sz[ne];
        if (sz[ne] > sz[son[u]]) 
            son[u] = ne;
    }
}

注意的是:如果u为叶节点,它是没有重儿子的。如果u有多个子结点子树大小相等,随便谁当重儿子都行。


第二遍dfs要将重边连成重链,保证一条重链上的节点dfs序连续,以便用用数据结构维护(比如线段树肯定是对连续的区间进行维护)同时要处理出重链的链头,也就是假如一条重链深度最小的点为①,然后有②③④,则top[1],top[2],top[3]和top[4]均为1。

void dfs2(int u, int top_fa)
{
    xu[u] = ++inde; 
    v[inde] = w[u]; 
    top[u] = top_fa;
    if (!son[u]) return ;//如果为叶节点,返回
    dfs2(son[u], top_fa);//优先走重边
    for (int i = head[u]; i; i = e[i].nex)
    {
        int ne = e[i].to;
        if (ne == f[u] || ne == son[u]) continue;
        dfs2(ne, ne);
    }
}

剖分的工作就做完了,接下来就是用数据结构维护

因为一条重链上的节点dfs序连续,那么路径和以及路径修改就可以转化成区间求和以及区间修改,那么可以用线段树维护。

以查询为例 ,还是用求lca的方式,用top可以直接跳到一条重链的起始位置,让top[x]和top[y]中比较深的节点来跳,直接跳到对应top[]的父节点,可以保证两个节点一定不会擦肩而过,也能保证最好两个点一定能跳到同一条重链上。

ll query(int rt, int l, int r, int x, int y)
{
    if (x <= l && r <= y) return sum[rt];
    pushdown(rt, l, r);
    int mid = (l + r) >> 1;
    ll res = 0;
    if (x <= mid) 
        res = (res + query(lson, l, mid, x, y)) % mod;
    if (mid < y) 
        res = (res + query(rson, mid+1, r, x, y)) % mod;
    return res;
}
ll qRange(int x, int y)
{
    ll ans=0;
    while (top[x] != top[y])
    {
        if (deg[top[x]] < deg[top[y]])
            swap(x,y);
        ans = (ans + query(1, 1, n, xu[top[x]], xu[x])) % mod;
        x = f[top[x]];
    }
    if (deg[x] > deg[y])
        swap(x, y);
    ans = (ans + query(1, 1, n, xu[x], xu[y])) % mod;
    return ans;
}

可以用上面的图跳几对点手动模拟一下,加深理解。

完整代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<sstream>
#include<algorithm>
#include<vector>
#include<queue>
#include<map>
#include<cstdlib>
#include<cmath>
using namespace std;
#define re register int
#define ull unsigned long long
#define ll long long
#define inf 0x3f3f3f3f
#define N 1000010
#define lson rt<<1
#define rson rt<<1|1
#define lowbit(x) (x)&(-(x))
void FRE(){freopen("subsets.in","r",stdin);freopen("subsets.out","w",stdout);}
inline ll read()
{
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ll)(ch-'0');ch=getchar();}
    return x*f;
}

int cnt, head[N], n, m, root, xu[N], deg[N], top[N];
int f[N],sz[N],son[N],inde;
ll mod,sum[N],v[N],w[N],tag[N];
struct node
{
    int to,nex;
}e[N*2];

void add(int u, int v)
{
    e[++cnt].nex = head[u];
    head[u] = cnt;
    e[cnt].to=v;
}

void dfs1(int u, int fa)
{
    deg[u] = deg[fa] + 1;//deg为节点深度
    f[u] = fa;           //f[u]为u的父亲节点
    sz[u] = 1;           //sz[u]为u所在子树大小(包括自己)
    for (int i = head[u]; i; i = e[i].nex)
    {
        int ne = e[i].to;
        if (ne == fa) continue;
        dfs1(ne, u);
        sz[u] += sz[ne];
        if (sz[ne] > sz[son[u]]) 
            son[u] = ne;
    }
}

void dfs2(int u, int top_fa)
{
    xu[u] = ++inde; 
    v[inde] = w[u]; 
    top[u] = top_fa;
    if (!son[u]) return ;//如果为叶节点,返回
    dfs2(son[u], top_fa);//优先走重边
    for (int i = head[u]; i; i = e[i].nex)
    {
        int ne = e[i].to;
        if (ne == f[u] || ne == son[u]) continue;
        dfs2(ne, ne);
    }
}

void pushup(int rt)
{
    sum[rt] = (sum[lson] + sum[rson])%mod;
}

void pushdown(int rt, int l, int r)
{
    if (!tag[rt]) return;
    int mid = (l + r) >> 1;
    tag[lson] = (tag[lson] + tag[rt]) % mod;
    tag[rson] = (tag[rson] + tag[rt]) %mod;
    sum[lson] = (sum[lson] + (tag[rt] * (ll)(mid - l + 1)) % mod) % mod;
    sum[rson] = (sum[rson] + (tag[rt] * (ll)(r-mid))  % mod) % mod;
    tag[rt] = 0;
}
void build(int rt, int l, int r)
{
    if (l == r)
    {
        sum[rt] = v[l];
        return ;
    }
    int mid = (l + r)>>1;
    build(lson, l, mid);
    build(rson, mid+1, r);
    pushup(rt);
}

void update(int rt, int l, int r, int x, int y, int val)
{
    if (x <= l && r<=y) 
    {
        sum[rt] = (sum[rt] + (r-l+1) * val) % mod;
        tag[rt] = (tag[rt] + val) % mod;
        return ;
    }
    pushdown(rt, l, r);
    int mid = (l + r) >> 1;
    if (x <= mid)
        update(lson, l, mid, x, y, val);
    if (mid < y) 
        update(rson, mid+1, r, x, y, val);
    pushup(rt);
}

void upRange(int x, int y, int val)
{
    while (top[x] != top[y])
    {
        if (deg[top[x]] < deg[top[y]])
            swap(x,y);
        update(1, 1, n, xu[top[x]], xu[x], val);
        x = f[top[x]];
    }
    if (deg[x] >deg[y])
        swap(x,y);
    update(1, 1, n, xu[x], xu[y], val);
}

ll query(int rt, int l, int r, int x, int y)
{
    if (x <= l && r <= y) return sum[rt];
    pushdown(rt, l, r);
    int mid = (l + r) >> 1;
    ll res = 0;
    if (x <= mid) 
        res = (res + query(lson, l, mid, x, y)) % mod;
    if (mid < y) 
        res = (res + query(rson, mid+1, r, x, y)) % mod;
    return res;
}
ll qRange(int x, int y)
{
    ll ans=0;
    while (top[x] != top[y])
    {
        if (deg[top[x]] < deg[top[y]])
            swap(x,y);
        ans = (ans + query(1, 1, n, xu[top[x]], xu[x])) % mod;
        x = f[top[x]];
    }
    if (deg[x] > deg[y])
        swap(x, y);
    ans = (ans + query(1, 1, n, xu[x], xu[y])) % mod;
    return ans;
}

int main()
{
	n = read(); m = read();
    root = read(); mod = read();
    for (int i = 1; i <= n; i++)
        w[i] = read();
    for (int i = 1; i < n; i++)
    {
        int x = read(), y = read();
        add(x, y), add(y, x);
    }
    dfs1(root, 0); dfs2(root, root);
    build(1, 1, n);
    while (m--)
    {
        int ty = read();
        if (ty == 1)
        {
            int x = read(), y = read();
            ll z = read() % mod;
            upRange(x, y, z);
        }
        if (ty == 2)
        {
            int x = read(), y = read();
            printf("%lld\n", qRange(x, y));
        }
        if (ty == 3)
        {
            int x = read();
            ll z = read();
            update(1, 1, n, xu[x], xu[x] + sz[x] - 1, z);
        }
        if (ty == 4)
        {
            int x = read();
            printf("%lld\n", query(1, 1, n, xu[x], xu[x] + sz[x] - 1));
        }
    }
	return 0;
}

练习:[NOI2015]软件包管理器

这个题目就可以理解成区间修改,每次修改的是根节点到x路径

完整代码

#include <bits/stdc++.h>
#define inf 1e18
#define ll long long
#define N 1000010
#define lson rt << 1
#define rson rt << 1 | 1
#define mo 998244353
using namespace std;
typedef pair<int, int> P;
inline ll read()
{
    ll x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9')
    {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
        x = x * 10 + ch - '0', ch = getchar();
    return x * f;
}
int f[N], sz[N], son[N], head[N], cnt, deg[N], xu[N], len, tag[N], sum[N], top[N], n;
char ty[20];
struct node
{
    int nex, to;
} e[N * 2];
void add(int u, int v)
{
    e[++cnt].nex = head[u];
    head[u] = cnt;
    e[cnt].to = v;
}
void dfs1(int u, int fa)
{
    f[u] = fa;
    sz[u] = 1;
    deg[u] = deg[fa] + 1;
    for (int i = head[u]; i; i = e[i].nex)
    {
        int ne = e[i].to;
        if (ne == fa)
            continue;
        dfs1(ne, u);
        sz[u] += sz[ne];
        if (sz[ne] > sz[son[u]])
            son[u] = ne;
    }
}
void dfs2(int u, int fa)
{
    top[u] = fa;
    xu[u] = ++len;
    if (!son[u])
        return;
    dfs2(son[u], fa);
    for (int i = head[u]; i; i = e[i].nex)
    {
        int ne = e[i].to;
        if (ne == f[u] || ne == son[u])
            continue;
        dfs2(ne, ne);
    }
}
void pushdown(int rt, int l, int r)
{
    if (tag[rt] == -1)
        return;
    int mid = (l + r) >> 1;
    tag[lson] = tag[rt], tag[rson] = tag[rt];
    sum[lson] = (mid - l + 1) * tag[rt], sum[rson] = (r - mid) * tag[rt];
    tag[rt] = -1;
}
void pushup(int rt) { sum[rt] = sum[lson] + sum[rson]; }
void modify(int rt, int l, int r, int x, int y, int val)
{
    if (x <= l && r <= y)
    {
        sum[rt] = (r - l + 1) * val;
        tag[rt] = val;
        return;
    }
    pushdown(rt, l, r);
    int mid = (l + r) >> 1;
    if (x <= mid)
        modify(lson, l, mid, x, y, val);
    if (y > mid)
        modify(rson, mid + 1, r, x, y, val);
    pushup(rt);
}
void upRange(int x, int y)
{
    while (top[x] != top[y])
    {
        if (deg[top[x]] < deg[top[y]])
            swap(x, y);
        modify(1, 1, n, xu[top[x]], xu[x], 1);
        x = f[top[x]];
    }
    if (deg[x] > deg[y])
        swap(x, y);
    modify(1, 1, n, xu[x], xu[y], 1);
}
int query(int rt, int l, int r, int x, int y)
{
    if (x <= l && r <= y)
        return sum[rt];
    pushdown(rt, l, r);
    int mid = (l + r) >> 1, res = 0;
    if (x <= mid)
        res += query(lson, l, mid, x, y);
    if (y > mid)
        res += query(rson, mid + 1, r, x, y);
    return res;
}
int main()
{
    n = read();
    for (int i = 2; i <= n; i++)
    {
        int x = read();
        add(x + 1, i), add(i, x + 1);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    int Q = read();
    while (Q--)
    {
        scanf("%s", ty);
        int x = read() + 1;
        int tmp = sum[1];
        if (ty[0] == 'i')
        {
            upRange(1, x);
            printf("%d\n", sum[1] - tmp);
        }
        else
        {
            modify(1, 1, n, xu[x], xu[x] + sz[x] - 1, 0);
            printf("%d\n", tmp - sum[1]);
        }
    }
    return 0;
}
posted @ 2020-11-08 00:33  蛙蛙1551  阅读(128)  评论(0编辑  收藏  举报