树链剖分

树链剖分

0x00 绪言

在阅读这篇文章前,确保你已学会你下内容:

  • 线段树
  • 深度优先遍历

会了这些就可以开始阅读本篇文章了。

0x01 什么是树剖

把一棵树拆成若干个不相交的链,然后用一些数据结构去维护这些链

那么如何把树拆成链?

首先明确一些定义:

重儿子:该节点的子树中,节点个数最多的子树的根节点(也就是和该节点相连的点),即为该节点的重儿子

重边:连接该节点与它的重儿子的边

重链:由一系列重边相连得到的链

轻链:由一系列非重边相连得到的链

这样就不难得到拆树的方法

对于每一个节点,找出它的重儿子,那么这棵树就自然而然的被拆成了许多重链与许多轻链。

0x02 如何维护这些链

首先,要对这些链进行维护,就要确保每个链上的节点都是连续的,

因此我们需要对整棵树进行重新编号,然后利用 dfs 序的思想,用线段树等数据结构进行维护。

注意在进行重新编号的时候先访问重链。

这样可以保证重链内的节点编号连续。

0x03 操作

dfs1

按照我们上面说的,我们首先要对整棵树 dfs 一遍,找出每个节点的重儿子

顺便处理出每个节点的深度,以及他们的父亲节点

void dfs1(int x, int father, int depth)
{
    dep[x] = depth;
    fa[x] = father;
    size[x] = 1;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == father)
        {
            continue;
        }
        dfs1(y, x, depth + 1);
        size[x] += size[y];
        if (size[son[x]] < size[y])
        {
            son[x] = y;
        }
    }
}

dfs2

然后我们需要对整棵树进行重新编号

我把一开始的每个节点的权值存在了 wt[]

void dfs2(int x, int t)
{
    id[x] = ++times;
    wt[times] = w[x];
    top[x] = t;
    if (son[x])
    {
        dfs2(son[x], t);
        for (rint i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            if (y != fa[x] && y != son[x])
            {
                dfs2(y, y);
            }
        }
    }
}

线段树维护

我们需要根据重新编完号的树,把这棵树的上每个点映射到线段树上。

struct Tree
{
    int l, r;
    long long sum, add;
} t[N << 2];

void build(int p, int l, int r)
{
    t[p].l = l;
    t[p].r = r;
    if (l == r)
    {
        t[p].sum = wt[r];
        return;
    }
    int mid = (l + r) >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
    push_up(p);
}

另外线段树的基本操作。会在后边的代码里加入一些注释方便理解。

0x04 AcWing 模板代码实现

题目链接

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;
const int M = 2e5 + 5;

int n, m;
int w[N], h[N], e[M], ne[M], idx;
int dep[N], top[N], son[N], fa[N];
//son[] 记录的是节点的重儿子 
int wt[N], id[N], size[N], times;

struct Tree
{
    int l, r;
    long long sum, add;
} t[N << 2];

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dfs1(int x, int father, int depth)
{
    dep[x] = depth;
    fa[x] = father;
    size[x] = 1;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == father)
        {
            continue;
        }
        dfs1(y, x, depth + 1);
        size[x] += size[y];
        if (size[son[x]] < size[y])
        {
            son[x] = y;
        }
    }
}

void dfs2(int x, int t)
{
    id[x] = ++times;
    wt[times] = w[x];
    top[x] = t;
    if (son[x])
    {
        dfs2(son[x], t);
        for (rint i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            if (y != fa[x] && y != son[x])
            {
                dfs2(y, y);
            }
        }
    }
}

void push_up(int u)
{
    t[u].sum = t[u << 1].sum + t[u << 1 | 1].sum;
}

void push_down(int u)
{
    if (t[u].add)
    {
        t[u << 1].add += t[u].add, t[u << 1 | 1].add += t[u].add;
        t[u << 1].sum += t[u].add * (t[u << 1].r - t[u << 1].l + 1);
        t[u << 1 | 1].sum += t[u].add * (t[u << 1 | 1].r - t[u << 1 | 1].l + 1);
        t[u].add = 0;
    }
}

void build(int p, int l, int r)
{
    t[p].l = l;
    t[p].r = r;
    if (l == r)
    {
        t[p].sum = wt[r];
        return;
    }
    int mid = (l + r) >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
    push_up(p);
}

void change(int p, int l, int r, int x)
{
    if (t[p].l >= l && t[p].r <= r)
    {
        t[p].sum += x * (t[p].r - t[p].l + 1);
        t[p].add += x;
        return;
    }
    push_down(p);
    int mid = (t[p].l + t[p].r) >> 1;
    if (l <= mid)
    {
        change(p << 1, l, r, x);
    }
    if (mid < r)
    {
        change(p << 1 | 1, l, r, x);
    }
    push_up(p);
}

long long query(int p, int l, int r)
{
    if (t[p].l >= l && t[p].r <= r)
    {
        return t[p].sum;
    }
    push_down(p);
    int mid = (t[p].l + t[p].r) >> 1;
    if (r <= mid)
    {
        return query(p << 1, l, r);
    }
    if (mid < l)
    {
        return query(p << 1 | 1, l, r);
    }
    return query(p << 1, l, r) + query(p << 1 | 1, l, r);
}

void change_path(int a, int b, int c)
{
    while (top[a] != top[b])//当两个点不在同一条链上 
    {
        if (dep[top[a]] < dep[top[b]])//把 a 点改为所在链顶端的深度更深的那个点
        {
            std::swap(a, b);
        }
        change(1, id[top[a]], id[a], c);
        a = fa[top[a]];//把 a 跳到 a 所在链顶端的那个点的上面一个点
    }
    if (dep[b] < dep[a])
    {
        std::swap(a, b);
    }
    change(1, id[a], id[b], c);
}

void change_tree(int a, int c)
{
    change(1, id[a], id[a] + size[a] - 1, c);
}

long long query_path(int a, int b)
{
    long long ans = 0;
    while (top[a] != top[b])
    {
        if (dep[top[a]] < dep[top[b]])
        {
            std::swap(a, b);
        }
        ans += query(1, id[top[a]], id[a]);
        a = fa[top[a]];
    }
    if (dep[b] < dep[a])
    {
        std::swap(a, b);
    }
    ans += query(1, id[a], id[b]);
    return ans;
}

long long query_tree(int a)
{
    return query(1, id[a], id[a] + size[a] - 1);
}

int main()
{
    scanf("%d", &n);
    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &w[i]);
    }
    for (rint i = 1; i < n; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }
    dfs1(1, 1, 1);
    dfs2(1, 1);
    build(1, 1, n);
    scanf("%d", &m);
    while (m--)
    {
        int op, a, b, c;
        scanf("%d", &op);
        if (op == 1)
        {
            scanf("%d%d%d", &a, &b, &c);
            change_path(a, b, c);
        }
        if (op == 2)
        {
            scanf("%d%d", &a, &c);
            change_tree(a, c);
        }
        if (op == 3)
        {
            scanf("%d%d", &a, &b);
            printf("%lld\n", query_path(a, b));
        }
        if (op == 4)
        {
            scanf("%d", &a);
            printf("%lld\n", query_tree(a));
        }
    }
    return 0;
}
posted @ 2022-11-27 16:13  PassName  阅读(35)  评论(0编辑  收藏  举报