YZhe的头像

简述树链剖分

@

题目链接:luogu P3384 【模板】树链剖分
先上完整代码,变量名解释[1]

#include<cstdio>
#include<algorithm>
#include<iostream>
using namespace std;
typedef long long ll;
#define N 500005
#define RI register int
int tot=0,n,m,rt,md;
int fa[ N ],deep[ N ],head[ N ],size[ N ],son[ N ],id[ N ],w[ N ],nw[ N ],top[ N ];

struct EDGE{
    int to,next;
}e[ N ];

inline void add( int from , int to ){
    e[ ++ tot ].to = to;
    e[ tot ].next = head[ from ];
    head[ from ] = tot;
}

template<class T>
inline void read(T &res){
    static char ch;T flag = 1;
    while( ( ch = getchar() ) < '0' || ch > '9' ) if( ch == '-' ) flag = -1;
    res = ch - 48;
    while( ( ch = getchar() ) >= '0' && ch <= '9' ) res = res * 10 + ch - 48;
    res *= flag;
}

struct NODE{
    ll sum,flag;
    NODE *ls,*rs;
    NODE(){
        sum = flag = 0;
        ls = rs = NULL;
    }
    inline void pushdown( int l , int r )
    {
        if( flag )
        {
            int midd = ( l + r ) >> 1;
            ls->flag += flag;
            rs->flag += flag;
            ls->sum += flag * ( midd - l + 1 );
            rs->sum += flag * ( r - midd );
            flag = 0;
        }
    }
    inline void update()
    {
        sum = ls->sum + rs->sum;
    }
}tree[ N * 2 + 5 ],*p = tree,*root;

NODE *build( int l , int r )
{
    NODE *nd = ++p;
    if( l == r )
    {
        nd->sum = nw[ l ];
        return nd;
    }
    int mid = ( l + r ) >> 1;
    nd->ls = build( l , mid );
    nd->rs = build( mid + 1 , r );
    nd->update();
    return nd;
}

ll sum( int l , int r , int x , int y , NODE *nd )
{
    if( x <= l && r <= y )
    {
        return nd->sum;
    }
    nd->pushdown( l , r );
    int mid = ( l + r ) >> 1;
    ll res = 0;
    if( x <= mid )
      res += sum( l , mid , x , y , nd->ls );
    if( y >= mid + 1 )
      res += sum( mid + 1 , r , x , y , nd->rs );
    return res;
}

void modify( int l , int r , int x , int y , ll add , NODE *nd )
{
    if( x <= l && r <= y ) 
    {
        nd->sum += ( r - l + 1 ) * add;
        nd->flag += add;
        return;
    }
    int mid = ( l + r ) >> 1;
    nd->pushdown( l , r );
    if( x <= mid )
        modify( l , mid , x , y , add , nd->ls );
    if( y > mid )
        modify( mid + 1 , r , x , y , add , nd->rs );
    nd->update();
}

void dfs1( int p ){
    size[ p ] = 1;
    deep[ p ] = deep[ fa[ p ] ] + 1;
    for( int i = head[ p ] ; i ; i = e[ i ].next ){
        int k = e[ i ].to;
        if( k == fa[ p ] )
          continue;
        fa[ k ] = p;
        dfs1( k );
        size[ p ] += size[ k ];
        if( size[ son[ p ] ] < size[ k ] || !son[ p ] )
          son[ p ] = k;
    }
}

void dfs2( int p , int tp ){ 
    id[ p ] = ++tot;
    nw[ tot ] = w[ p ];
    top[ p ] = tp;
    if( son[ p ] )
      dfs2( son[ p ] , tp );
    for( int i = head[ p ] ; i ; i = e[ i ].next ){
        int k = e[ i ].to;
        if( k == fa[ p ] || k == son[ p ] ) 
          continue;
        dfs2( k , k );
    }
} 

inline void ope1( int x , int y , ll add ){
    while( top[ x ] != top[ y ] ){
        if( deep[ top[ x ] ] < deep[ top[ y ] ] )
          swap( x , y );
        modify( 1 , n , id[ top[ x ] ] , id[ x ] , add , root );
        x = fa[ top[ x ] ];
    }
    if( deep[ x ] > deep[ y ] )
      swap( x , y );
    modify( 1 , n , id[ x ] , id[ y ] , add , root );
}

inline ll ope2( int x , int y ){
    ll res = 0;
    while( top[ x ] != top[ y ] ){
        if( deep[ top[ x ] ] < deep[ top[ y ] ] )
          swap( x , y );
        res += sum( 1 , n , id[ top[ x ] ] , id[ x ] , root );
        x = fa[ top[ x ] ];
    }
    if( deep[ x ] > deep[ y ] )
      swap( x , y );
    res += sum( 1 , n , id[ x ] , id[ y ] , root );
    return res;
}

inline void ope3( int x , int add ){
    modify( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , add , root );
} 

inline ll ope4( int x ){
    return sum( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , root );
}

int main()
{
    cin>>n>>m>>rt>>md;
    for( RI i = 1 ; i <= n ; i ++ )
      read( w[ i ] ); 
    for( RI i = 1 ; i <= n - 1 ; i ++ ){
        int x,y;
        read( x ),read( y );
        add( x , y );
        add( y , x );
    }
    dfs1( rt ),tot = 0;
    dfs2( rt , rt );
    root = build( 1 , n );
    for( RI i = 1 ; i <= m ; i ++ ){
        int f;
        read( f );
        switch( f ){
            case 1:{
                int x,y;
                ll add;
                read( x ),read( y ),read( add );
                ope1( x , y , add ); 
                break;
            }
            case 2:{
                int x,y;
                read( x ),read( y );
                printf( "%lld\n" , ope2( x , y ) % md );
                break;
            }
            case 3:{ 
                int x;
                ll add;
                read( x ),read( add );
                ope3( x , add );
                break;
            }
            case 4:{ 
                int x;
                read( x );
                printf( "%lld\n" , ope4( x ) % md );
                break;
            }
        }
    }
    return 0;
}

前置知识

请先能够熟练写出线段树并了解\(dfs\)序的性质

预处理

预处理分两次\(dfs\)
第一次处理出各个结点的深度,\(size\),重儿子,父亲。
第二次处理出重链,\(dfs\)序和每个点的\(top\)
dfs1:

void dfs1( int p ){
    size[ p ] = 1;
    deep[ p ] = deep[ fa[ p ] ] + 1;
    for( int i = head[ p ] ; i ; i = e[ i ].next ){
        int k = e[ i ].to;
        if( k == fa[ p ] )
          continue;
        fa[ k ] = p;
        dfs1( k );
        size[ p ] += size[ k ];
        if( size[ son[ p ] ] < size[ k ] || !son[ p ] )
          son[ p ] = k;
    }
}

dfs2:

void dfs2( int p , int tp ){ 
    id[ p ] = ++tot;//每个点在dfs序里的位置
    nw[ tot ] = w[ p ];
    top[ p ] = tp;
    if( son[ p ] )
      dfs2( son[ p ] , tp );//重链
    for( int i = head[ p ] ; i ; i = e[ i ].next ){
        int k = e[ i ].to;
        if( k == fa[ p ] || k == son[ p ] ) 
          continue;
        dfs2( k , k );//轻链
    }
} 

维护

为了更加高效的查询,我们选择用线段树来维护\(dfs\)序(树状数组等数据结构也可)。
没什么技术含量,直接套模板即可。

struct NODE{
    ll sum,flag;
    NODE *ls,*rs;
    NODE(){
        sum = flag = 0;
        ls = rs = NULL;
    }
    inline void pushdown( int l , int r ) 
    {
        if( flag )
        {
            int midd = ( l + r ) >> 1;
            ls->flag += flag;
            rs->flag += flag;
            ls->sum += flag * ( midd - l + 1 );
            rs->sum += flag * ( r - midd );
            flag = 0;
        }
    }
    inline void update()
    {
        sum = ls->sum + rs->sum;
    }
}tree[ N * 2 + 5 ],*p = tree,*root;

NODE *build( int l , int r )
{
    NODE *nd = ++p;
    if( l == r )
    {
        nd->sum = nw[ l ];
        return nd;
    }
    int mid = ( l + r ) >> 1;
    nd->ls = build( l , mid );
    nd->rs = build( mid + 1 , r );
    nd->update();
    return nd;
}

ll sum( int l , int r , int x , int y , NODE *nd )
{
    if( x <= l && r <= y )
    {
        return nd->sum;
    }
    nd->pushdown( l , r );
    int mid = ( l + r ) >> 1;
    ll res = 0;
    if( x <= mid )
      res += sum( l , mid , x , y , nd->ls );
    if( y >= mid + 1 )
      res += sum( mid + 1 , r , x , y , nd->rs );
    return res;
}

void modify( int l , int r , int x , int y , ll add , NODE *nd )
{
    if( x <= l && r <= y ) 
    {
        nd->sum += ( r - l + 1 ) * add;
        nd->flag += add;
        return;
    }
    int mid = ( l + r ) >> 1;
    nd->pushdown( l , r );
    if( x <= mid )
        modify( l , mid , x , y , add , nd->ls );
    if( y > mid )
        modify( mid + 1 , r , x , y , add , nd->rs );
    nd->update();
}

查询

这是核心操作(敲黑板)。

子树有关操作

子树查询

由于\(dfs\)序的性质,以一个点为根的子树在\(dfs\)序中一定是连续的,所以我们只需要进行一次区间查询,需要查询的区间为:

[根结点在\(dfs\)序中的位置,根结点在\(dfs\)序中的位置+\(size\) - 1 ]

复杂度为\(O(logn)\)
代码如下:

inline ll ope4( int x ){
    return sum( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , root );
}

子树修改

同理,进行一次区间修改
复杂度为\(O(logn)\)
代码如下:

inline void ope3( int x , int add ){
    modify( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , add , root );
} 

树链有关操作

这才是树剖的精髓所在啊!(战术后仰
这里主要会利用重链在\(dfs\)序中一定是连续的性质,一定要记住,否则你将无法理解接下来的操作

链查询

操作流程:

  • 若两个点的top不同,则让top较深的点爬升到它的topfather,每次爬升进行一次区间查询[2],把结果加到res上,直到top相等为止
  • 此时两点的top为原来两点的LCA,且其中深度较浅的点就是LCA,再进行一次区间查询即可。

最坏时间复杂度\(O(log_{2}n)\)
代码如下:

inline ll ope2( int x , int y ){
    ll res = 0;
    while( top[ x ] != top[ y ] ){
        if( deep[ top[ x ] ] < deep[ top[ y ] ] )//把x调整为top深度更深的的点
          swap( x , y );
        res += sum( 1 , n , id[ top[ x ] ] , id[ x ] , root );
        x = fa[ top[ x ] ];
    }
    if( deep[ x ] > deep[ y ] )
      swap( x , y );
    res += sum( 1 , n , id[ x ] , id[ y ] , root );
    return res;
}

链修改

同理,爬升过程一模一样,只需要将链查询的区间查询改为区间修改即可。

最坏时间复杂度O(log2n)

代码如下:

inline void ope1( int x , int y , ll add ){
    while( top[ x ] != top[ y ] ){
        if( deep[ top[ x ] ] < deep[ top[ y ] ] )
          swap( x , y );
        modify( 1 , n , id[ top[ x ] ] , id[ x ] , add , root );
        x = fa[ top[ x ] ];
    }
    if( deep[ x ] > deep[ y ] )
      swap( x , y );
    modify( 1 , n , id[ x ] , id[ y ] , add , root );
}

  1. 变量名解释
    fa:每个结点的父结点
    deep:每个结点所在位置在树中的深度
    size:以每个结点为根的子树的大小
    son:每个结点的重儿子(即所有儿子中size最大的那个)
    id:每个结点在dfs序中的位置
    w:每个结点的权值
    nw:dfs序中,每个结点的权值
    top:每个点所在的重链的顶端 ↩︎

  2. 链查询和修改的区间为:[id[ top[ x ] ] , id[ x ]],即是这条重链。 ↩︎

posted @ 2019-10-12 08:07  YZhe  阅读(135)  评论(0编辑  收藏  举报
ヾ(≧O≦)〃嗷~