树链剖分

树链剖分

树链剖分是为了解决树上链信息维护的问题。
利用线段树等结构,可以维护数组中的增删改查问题。
树链剖分的作用其实就是将整颗结构树分解成若干段数组,然后使用一个数据结构来维护每一条链的信息
(一般来讲不会使用平衡树去维护树链,不然你还不如直接上LCT呢)

如果是数组上的问题,一般直接使用线段树来维护数组的区间信息。那么如何把这种思路扩展到树上呢?

链分解

先从最简单的问题入手,假设整棵树只有一个分叉,我们就叫它“Y树”好了。

右边的图就是一个所谓的”Y树”,如果现在问题比较简单,只需要你维护这个“Y树”的信息,查询某两个点的简单路径信息。
维护方法也比较多
可以左边的链建一颗线段树,右边的链建一颗线段树。
也可以把中间的交叉点拿出来作为根建三颗线段树。
像这样,把一颗“树”拆解成若干个连续的“数组”,且每一部分互不包含就称之为是一个树的链分解。
如何随便的给出一个树的链分解呢,很简单,利用dfn,也就是多叉树的dfs序。
在dfs的过程中从当前节点一路向下到叶子的过程天然的就是一种链分解。
我们定义top[]数组为链分解的链首,dfs的过程中对于第一个孩子继承当前节点的top值,否则另起新链即可。
这个代码非常简单,就是一个dfs就写完了。

链分解和树链剖分

问题来了,既然树的任意链分解这么好写,而且貌似这样也可以把树变成若干个连续的数组,为什么还需要轻重树链剖分这样相对逻辑复杂的链分解?
我们看一种叫做“毛毛虫树”的结构。
image
所谓“毛毛虫”树就是指伸展度较低的树。
所谓“伸展度t”是指在造树的时候往往需要经过以下几个步骤。
1、连边i<->random(min(max(1+t,1),i-1),min(max(i+t,1),i-1))(大概是这个意思,实际上伸展度只是一个影响随机分布的参数)
2、节点随机重标号
3、边随机排列
伸展度越大,全局树深度越小,当伸展度=INF时,树退化为菊花。
伸展度越小,全局树深度越大,当伸展度=-INF时,树退化为单链。
毛毛虫树是伸展度较小,并且不为-INF的情况。
这种情况下会使得随机dfs的链分解出的链非常短。

轻重树链剖分

所以在做链分解之前,首先要做的一个预处理过程是对于子树轻重的划分。
首先利用树DP求出所有子树的大小size。
定义重儿子为当前节点中所有子树size最大的子树。
定义轻儿子为当前节点中所有子树size最小的子树。
定理:轻儿子的size必定小于等于当前子树size的一半。
定理的扩展:树上每经过一条轻边后,子树尺寸缩减一半。

在对树进行链分解以后,借助面向对象的思想,我们把一整个重链看成一个整体,作为一个大的“节点”看待,树上的轻链看成是结构树的部分。
image
这样的话,整个树就被“压缩”了。
可以证明,压缩后的结构树,深度不超过log(N)

树链剖分实现LCA

树链剖分后的结构树深度可控
直接暴力复杂度就是O(log)级别的

树链剖分

所以说树链剖分的本质是,对于结构树的压缩重构,使得在新的结构树上面暴力移动的复杂度不多于log(即进行树链剖分后重链形成的结构树,树的直径变为logN)。
树链剖分本身不具备任何维护信息的功能,换句话说树链剖分本身不是数据结构,只是利用dfs对树的整理过程。

时间复杂度

对于路径上的更新权值/求和的时间复杂度是\(O(log^2n)\)

因为重链的条数不超过\(log_2 n\)条(一般的话数量可能远远低于这个值),每次线段树维护的时间复杂度是\(O(log n)\)的,所以时间复杂度是\(O(log^2n)\)
子树修改/查询时间复杂度:\(O(log n)\),很明显,因为子树的dfs序编号是连续的,时间复杂度就是线段树做区间查询/修改的复杂度

例题

1.【模板】轻重链剖分/树链剖分

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
using namespace std;
const int maxn=4e5+101;
const int inf=2147483647;
const double eps=1e-9;

int read(){
    int x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int n,m,r,MOD,a[maxn],w[maxn];
struct SegTree{             //线段树
    int tr[maxn<<1],lz[maxn<<1];
    void build(int k,int l,int r){
        if(l==r){
            tr[k]=w[l];     //w[]是经过id数组转化的a[]
            return ;
        }
        int mid=(l+r)>>1;
        build(k<<1,l,mid);build(k<<1|1,mid+1,r);
        tr[k]=(tr[k<<1]+tr[k<<1|1])%MOD;
        return ;
    }

    void add(int k,int val,int l,int r){
        lz[k]=(lz[k]+val)%MOD;
        tr[k]=(tr[k]+val*(r-l+1)%MOD)%MOD;
        return ;
    }

    void pushdown(int k,int l,int r){
        if(!lz[k])return ;
        int mid=(l+r)>>1;
        add(k<<1,lz[k],l,mid);add(k<<1|1,lz[k],mid+1,r);
        lz[k]=0;
        return ;
    }

    void modify(int k,int l,int r,int L,int R,int val){
        pushdown(k,l,r);
        if(r<L || l>R)return ;
        if(L<=l && r<=R){
            add(k,val,l,r);
            return ;
        }
        int mid=(l+r)>>1;
        modify(k<<1,l,mid,L,R,val);modify(k<<1|1,mid+1,r,L,R,val);
        tr[k]=(tr[k<<1]+tr[k<<1|1])%MOD;
        return ;
    }

    int query(int k,int l,int r,int L,int R){
        pushdown(k,l,r);
        if(r<L || l>R)return 0;
        if(L<=l && r<=R)return tr[k];
        int mid=(l+r)>>1;
        return (query(k<<1,l,mid,L,R)+query(k<<1|1,mid+1,r,L,R))%MOD;
    }
}ST;

int tot,head[maxn],nx[maxn],to[maxn];
void add(int x,int y){to[++tot]=y;nx[tot]=head[x];head[x]=tot;}

struct Tree{            //树链剖分
    int cnt,id[maxn];
    int fa[maxn],sz[maxn],son[maxn],top[maxn],dep[maxn];
    void dfs1(int x){
        sz[x]=1;dep[x]=dep[fa[x]]+1;
        for(int i=head[x];i;i=nx[i]){
            int v=to[i];if(v==fa[x])continue;
            fa[v]=x;dfs1(v);sz[x]+=sz[v];
            if(sz[son[x]]<sz[v])son[x]=v;
        }
        return ;
    }
    void dfs2(int x){
        id[x]=++cnt;top[son[x]]=top[x];
        w[cnt]=a[x];        //根据情况添加
        if(!son[x])return ;
        dfs2(son[x]);
        for(int i=head[x];i;i=nx[i]){
            int v=to[i];if(v==son[x] || v==fa[x])continue;
            top[v]=v;dfs2(v);
        }
        return ;
    }

    void change(int x,int y,int z){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            ST.modify(1,1,n,id[top[x]],id[x],z);
            x=fa[top[x]];
        }
        if(dep[x]>dep[y])swap(x,y);
        ST.modify(1,1,n,id[x],id[y],z);
        return ;
    }

    int getans(int x,int y){
        int ans=0;
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            ans+=ST.query(1,1,n,id[top[x]],id[x]);
            ans%=MOD;x=fa[top[x]];
        }
        if(dep[x]>dep[y])swap(x,y);
        ans+=ST.query(1,1,n,id[x],id[y]);ans%=MOD;
        return ans;
    }

}T;

int main(){
    n=read();m=read();r=read();MOD=read();
    for(int i=1;i<=n;i++)a[i]=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    T.dfs1(r);T.top[r]=r;T.dfs2(r);ST.build(1,1,n);
    while(m--){
        int opt=read(),x=read();
        if(opt==1){
            int y=read(),z=read();
            T.change(x,y,z);
        }
        else if(opt==2){
            int y=read();
            printf("%d\n",T.getans(x,y));
        }
        else if(opt==3){
            int z=read();
            ST.modify(1,1,n,T.id[x],T.id[x]+T.sz[x]-1,z);
        }
        else printf("%d\n",ST.query(1,1,n,T.id[x],T.id[x]+T.sz[x]-1));
    }
    return 0;
}

2.软件包管理器
install x-->把根到x点的权值全设为1
uninstall x-->把x的子树节点权值全设为0
(代码中MOD懒得删了)

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
using namespace std;
const int maxn=2e6+101;
const int MOD=201314;
const int inf=2147483647;
const double eps=1e-9;

int read(){
    int x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int n,q,a[maxn],w[maxn];
struct SegTree{             //线段树
    int tr[maxn<<1],lz[maxn<<1];
    void build(int k,int l,int r){
        lz[k]=0;
        if(l==r){tr[k]=w[l];return ;}
        int mid=(l+r)>>1;
        build(k<<1,l,mid);build(k<<1|1,mid+1,r);
        tr[k]=(tr[k<<1]+tr[k<<1|1])%MOD;
        return ;
    }

    void add(int k,int val,int l,int r){
        if(val==-1){
            lz[k]=-1;
            tr[k]=0;
        }
        else {
            lz[k]=1;
            tr[k]=r-l+1;
        }
        return ;
    }

    void pushdown(int k,int l,int r){
        if(!lz[k])return ;
        int mid=(l+r)>>1;
        add(k<<1,lz[k],l,mid);add(k<<1|1,lz[k],mid+1,r);
        lz[k]=0;
        return ;
    }

    void modify(int k,int l,int r,int L,int R,int val){
        pushdown(k,l,r);
        if(r<L || l>R)return ;
        if(L<=l && r<=R){
            add(k,val,l,r);
            return ;
        }
        int mid=(l+r)>>1;
        modify(k<<1,l,mid,L,R,val);modify(k<<1|1,mid+1,r,L,R,val);
        tr[k]=(tr[k<<1]+tr[k<<1|1])%MOD;
        return ;
    }

    int query(int k,int l,int r,int L,int R){
        pushdown(k,l,r);
        if(r<L || l>R)return 0;
        if(L<=l && r<=R)return tr[k];
        int mid=(l+r)>>1;
        return (query(k<<1,l,mid,L,R)+query(k<<1|1,mid+1,r,L,R))%MOD;
    }
}ST;

int tot,head[maxn],nx[maxn],to[maxn];
void add(int x,int y){to[++tot]=y;nx[tot]=head[x];head[x]=tot;}

struct Tree{            //树链剖分
    int cnt,id[maxn];
    int fa[maxn],sz[maxn],son[maxn],top[maxn],dep[maxn];
    void dfs1(int x){
        sz[x]=1;dep[x]=dep[fa[x]]+1;
        for(int i=head[x];i;i=nx[i]){
            int v=to[i];if(v==fa[x])continue;
            fa[v]=x;dfs1(v);sz[x]+=sz[v];
            if(sz[son[x]]<sz[v])son[x]=v;
        }
        return ;
    }
    void dfs2(int x){
        id[x]=++cnt;top[son[x]]=top[x];
        //w[cnt]=a[i];
        if(!son[x])return ;
        dfs2(son[x]);
        for(int i=head[x];i;i=nx[i]){
            int v=to[i];if(v==son[x] || v==fa[x])continue;
            top[v]=v;dfs2(v);
        }
        return ;
    }

    void change(int x,int y,int z){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            ST.modify(1,1,n,id[top[x]],id[x],z);
            x=fa[top[x]];
        }
        if(dep[x]>dep[y])swap(x,y);
        ST.modify(1,1,n,id[x],id[y],z);
        return ;
    }

    int getans(int x,int y){
        int ans=0;
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            ans+=ST.query(1,1,n,id[top[x]],id[x]);
            ans%=MOD;x=fa[top[x]];
        }
        if(dep[x]>dep[y])swap(x,y);
        ans+=ST.query(1,1,n,id[x],id[y]);ans%=MOD;
        return ans;
    }

}T;

int main(){
    n=read();
    for(int i=2;i<=n;i++){
        int x=read()+1;
        add(i,x);add(x,i);
    }
    T.dfs1(1);T.top[1]=1;T.dfs2(1);ST.build(1,1,n);
    q=read();
    while(q--){
        string s;cin>>s;
        int x=read()+1;
        if(s=="install"){
            printf("%d\n",T.dep[x]-T.getans(1,x));
            T.change(1,x,1);
        }
        else {
            printf("%d\n",ST.query(1,1,n,T.id[x],T.id[x]+T.sz[x]-1));
            ST.modify(1,1,n,T.id[x],T.id[x]+T.sz[x]-1,-1);
        }
    }
    return 0;
}

3. [LNOI2014] LCA
直接查询lca肯定不行
不妨考虑dep有什么性质,dep定义是深度即根到节点的长度。由lca性质,我们知道,对于[l,r]区间里的任意i,它与z的lca一定是z的祖先或者z自己。
考虑i对lca的dep的贡献,我们把i到根的路径上的所有结点+1,对区间所以i操作完后,dep之和即为z到根路径上的权值和。
若查询一个点z和区间中点的LCA深度和就可以对于区间每个点到根节点的路径上点权值+1,然后查询1~z路径上数字和即可。
那么区间[l,r]答案可以转化为z到[1,r]的ans减去z到[1,l-1]的ans
这道题可以离线做,对于每个询问,先把[1,l-1]的答案减去,再把[1,r]的答案加上

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
using namespace std;
const int maxn=4e5+101;
const int MOD=201314;
const int inf=2147483647;
const double eps=1e-9;

int read(){
    int x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int n,q,w[maxn];
struct SegTree{             //线段树
    int tr[maxn<<1],lz[maxn<<1];
    void build(int k,int l,int r){
        lz[k]=0;
        if(l==r){tr[k]=w[l];return ;}
        int mid=(l+r)>>1;
        build(k<<1,l,mid);build(k<<1|1,mid+1,r);
        tr[k]=(tr[k<<1]+tr[k<<1|1])%MOD;
        return ;
    }

    void add(int k,int val,int l,int r){
        lz[k]=(lz[k]+val)%MOD;
        tr[k]=(tr[k]+val*(r-l+1)%MOD)%MOD;
        return ;
    }

    void pushdown(int k,int l,int r){
        if(!lz[k])return ;
        int mid=(l+r)>>1;
        add(k<<1,lz[k],l,mid);add(k<<1|1,lz[k],mid+1,r);
        lz[k]=0;
        return ;
    }

    void modify(int k,int l,int r,int L,int R,int val){
        pushdown(k,l,r);
        if(r<L || l>R)return ;
        if(L<=l && r<=R){
            add(k,val,l,r);
            return ;
        }
        int mid=(l+r)>>1;
        modify(k<<1,l,mid,L,R,val);modify(k<<1|1,mid+1,r,L,R,val);
        tr[k]=(tr[k<<1]+tr[k<<1|1])%MOD;
        return ;
    }

    int query(int k,int l,int r,int L,int R){
        pushdown(k,l,r);
        if(r<L || l>R)return 0;
        if(L<=l && r<=R)return tr[k];
        int mid=(l+r)>>1;
        return (query(k<<1,l,mid,L,R)+query(k<<1|1,mid+1,r,L,R))%MOD;
    }
}ST;

int tot,head[maxn],nx[maxn],to[maxn];
void add(int x,int y){to[++tot]=y;nx[tot]=head[x];head[x]=tot;}

struct Tree{            //树链剖分
    int cnt,id[maxn];
    int fa[maxn],sz[maxn],son[maxn],top[maxn],dep[maxn];
    void dfs1(int x){
        sz[x]=1;dep[x]=dep[fa[x]]+1;
        for(int i=head[x];i;i=nx[i]){
            int v=to[i];if(v==fa[x])continue;
            fa[v]=x;dfs1(v);sz[x]+=sz[v];
            if(sz[son[x]]<sz[v])son[x]=v;
        }
        return ;
    }
    void dfs2(int x){
        id[x]=++cnt;top[son[x]]=top[x];
        if(!son[x])return ;
        dfs2(son[x]);
        for(int i=head[x];i;i=nx[i]){
            int v=to[i];if(v==son[x] || v==fa[x])continue;
            top[v]=v;dfs2(v);
        }
        return ;
    }

    void change(int x,int y,int z){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            ST.modify(1,1,n,id[top[x]],id[x],z);
            x=fa[top[x]];
        }
        if(dep[x]>dep[y])swap(x,y);
        ST.modify(1,1,n,id[x],id[y],z);
        return ;
    }

    int getans(int x,int y){
        int ans=0;
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            ans+=ST.query(1,1,n,id[top[x]],id[x]);
            ans%=MOD;x=fa[top[x]];
        }
        if(dep[x]>dep[y])swap(x,y);
        ans+=ST.query(1,1,n,id[x],id[y]);ans%=MOD;
        return ans;
    }

}T;

struct wzq{
    int l,r,z;
    int id,ans;
}p[maxn];
int main(){
    n=read();q=read();
    for(int i=2;i<=n;i++){
        int x=read()+1;add(x,i);add(i,x);
    }
    T.dfs1(1);T.top[1]=1;T.dfs2(1);ST.build(1,1,n);
    for(int i=1;i<=q;i++){
        p[i].l=read()+1;p[i].r=read()+1;p[i].z=read()+1;p[i].id=i;
    }
    sort(p+1,p+q+1,[](wzq i,wzq j){return i.l<j.l;});
    int l=1;
    for(int i=1;i<=q;i++){
        while(p[i].l>l)T.change(1,l++,1);
        p[i].ans-=T.getans(1,p[i].z);
        p[i].ans%=MOD;
    }
    ST.build(1,1,n);l=1;
    sort(p+1,p+q+1,[](wzq i,wzq j){return i.r<j.r;});
    for(int i=1;i<=q;i++){
        while(p[i].r>=l)T.change(1,l++,1);
        p[i].ans+=T.getans(1,p[i].z);p[i].ans%=MOD;
    }
    sort(p+1,p+q+1,[](wzq i,wzq j){return i.id<j.id;});
    for(int i=1;i<=q;i++)printf("%d\n",(p[i].ans%MOD+MOD)%MOD);
    return 0;
}

4.小 Q 与树
考虑min不太好处理,我们按照点权从大到小加入
设权值第i大的点为\(rk_i\)
那么这个点和前面所有点产生的贡献是\(2*a_{rk_i}*\sum_{k=1}^i dis[lca(rk_i,rk_k)]\)

前两个都可以\(O(1)\)求出
\(\sum_{k=1}^I dis[lca(rk_i,rk_k)]\)就跟上一题一样,每加入一个点,就把当前点到根的链权值加1,最后询问\(rk_i\)到根链的和即可

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#define ll long long 
#define pa pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
using namespace std;
const int maxn=800000+101;
const int MOD=998244353;
const int inf=2147483647;
const double pi=acos(-1);
int read(){
    int x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}

int n,m,w[maxn];
struct SegTree{             //线段树
    int tr[maxn<<1],lz[maxn<<1];
    void build(int k,int l,int r){
        if(l==r){
            tr[k]=w[l];     //w[]是经过id数组转化的a[]
            return ;
        }
        int mid=(l+r)>>1;
        build(k<<1,l,mid);build(k<<1|1,mid+1,r);
        tr[k]=(tr[k<<1]+tr[k<<1|1])%MOD;
        return ;
    }

    void add(int k,int val,int l,int r){
        lz[k]=(lz[k]+val)%MOD;
        tr[k]=(tr[k]+val*(r-l+1)%MOD)%MOD;
        return ;
    }

    void pushdown(int k,int l,int r){
        if(!lz[k])return ;
        int mid=(l+r)>>1;
        add(k<<1,lz[k],l,mid);add(k<<1|1,lz[k],mid+1,r);
        lz[k]=0;
        return ;
    }

    void modify(int k,int l,int r,int L,int R,int val){
        pushdown(k,l,r);
        if(r<L || l>R)return ;
        if(L<=l && r<=R){
            add(k,val,l,r);
            return ;
        }
        int mid=(l+r)>>1;
        modify(k<<1,l,mid,L,R,val);modify(k<<1|1,mid+1,r,L,R,val);
        tr[k]=(tr[k<<1]+tr[k<<1|1])%MOD;
        return ;
    }

    int query(int k,int l,int r,int L,int R){
        pushdown(k,l,r);
        if(r<L || l>R)return 0;
        if(L<=l && r<=R)return tr[k];
        int mid=(l+r)>>1;
        return (query(k<<1,l,mid,L,R)+query(k<<1|1,mid+1,r,L,R))%MOD;
    }
}ST;

int tot,head[maxn],nx[maxn],to[maxn];
void add(int x,int y){to[++tot]=y;nx[tot]=head[x];head[x]=tot;}

struct Tree{            //树链剖分
    int cnt,id[maxn];
    int fa[maxn],sz[maxn],son[maxn],top[maxn],dep[maxn];
    void dfs1(int x){
        sz[x]=1;dep[x]=dep[fa[x]]+1;
        for(int i=head[x];i;i=nx[i]){
            int v=to[i];if(v==fa[x])continue;
            fa[v]=x;dfs1(v);sz[x]+=sz[v];
            if(sz[son[x]]<sz[v])son[x]=v;
        }
        return ;
    }
    void dfs2(int x){
        id[x]=++cnt;top[son[x]]=top[x];
       // w[cnt]=a[x];        //根据情况添加
        if(!son[x])return ;
        dfs2(son[x]);
        for(int i=head[x];i;i=nx[i]){
            int v=to[i];if(v==son[x] || v==fa[x])continue;
            top[v]=v;dfs2(v);
        }
        return ;
    }

    void change(int x,int y,int z){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            ST.modify(1,1,n,id[top[x]],id[x],z);
            x=fa[top[x]];
        }
        if(dep[x]>dep[y])swap(x,y);
        ST.modify(1,1,n,id[x],id[y],z);
        return ;
    }

    ll getans(int x,int y){
        int ans=0;
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            ans+=ST.query(1,1,n,id[top[x]],id[x]);
            ans%=MOD;x=fa[top[x]];
        }
        if(dep[x]>dep[y])swap(x,y);
        ans+=ST.query(1,1,n,id[x],id[y]);ans%=MOD;
        return ans;
    }

}T;

struct wzq{
    int id;
    ll val;
}a[maxn];
ll ans=0;
int main(){
    n=read();
    for(int i=1;i<=n;i++){
        a[i].val=read();a[i].id=i;
    }
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    T.dfs1(1);T.top[1]=1;T.dfs2(1);ST.build(1,1,n);
    sort(a+1,a+n+1,[](wzq i,wzq j){return i.val>j.val;});
    ll pre=0;
    for(ll i=1;i<=n;i++){
        T.change(1,a[i].id,1);
        pre+=T.dep[a[i].id];
        ans+=a[i].val*(i*T.dep[a[i].id]+pre-2ll*T.getans(1,a[i].id));
        ans%=MOD;
    }
    ans=ans*2%MOD;
    printf("%lld\n",(ans%MOD+MOD)%MOD);
    return 0;
}

5.树上路径
注意3操作是任何两个节点的乘积和,uv=vu是一种情况
考虑在线段树上如何做
有一个性质是\((区间和)^2-(区间每个数的平方的和)=区间内两两节点的乘积和\)
那么线段树维护每个链的区间和和区间数平方和

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<ll,ll>
#define mp make_pair
#define se second
#define fi first
using namespace std;
const int maxn=4e5+101;
const int inf=2147483647;
const int MOD=1e9+7;
const double eps=1e-9;

int read(){
    int x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
ll power(ll x,ll y){
    ll ans=1;
    while(y){
        if(y&1)ans=ans*x%MOD;
        y>>=1;x=x*x%MOD;
    }
    return ans;
}
int n,m,r=1,a[maxn],w[maxn];
struct SegTree{             //线段树
    ll lz[maxn<<1];
    struct wzq{
        ll sum,fas;
        wzq operator+(wzq x){
            wzq now;
            now.sum=(sum+x.sum)%MOD;
            now.fas=(fas+x.fas)%MOD;
            return now;
        } 
    }tr[maxn<<1];
    void build(int k,int l,int r){
        if(l==r){
            tr[k].sum=w[l];     
            tr[k].fas=w[l]*w[l]%MOD;
            return ;
        }
        int mid=(l+r)>>1;
        build(k<<1,l,mid);build(k<<1|1,mid+1,r);
        tr[k]=tr[k<<1]+tr[k<<1|1];
        return ;
    }

    void add(int k,ll val,int l,int r){
        (lz[k]+=val)%=MOD;
        (tr[k].fas+=((r-l+1)*val%MOD)*val%MOD+2*val*tr[k].sum%MOD)%=MOD;
        (tr[k].sum+=val*(r-l+1)%MOD)%=MOD;
        return ;
    }

    void pushdown(int k,int l,int r){
        if(!lz[k])return ;
        int mid=(l+r)>>1;
        add(k<<1,lz[k],l,mid);add(k<<1|1,lz[k],mid+1,r);
        lz[k]=0;
        return ;
    }

    void modify(int k,int l,int r,int L,int R,ll val){
        pushdown(k,l,r);
        if(r<L || l>R)return ;
        if(L<=l && r<=R){
            add(k,val,l,r);
            return ;
        }
        int mid=(l+r)>>1;
        modify(k<<1,l,mid,L,R,val);modify(k<<1|1,mid+1,r,L,R,val);
        tr[k]=tr[k<<1]+tr[k<<1|1];
        return ;
    }

    pa query(int k,int l,int r,int L,int R){
        pushdown(k,l,r);
        if(r<L || l>R)return mp(0,0);
        if(L<=l && r<=R)return mp(tr[k].sum,tr[k].fas);
        int mid=(l+r)>>1;
        pa ls=query(k<<1,l,mid,L,R),rc=query(k<<1|1,mid+1,r,L,R);
        return mp(ls.fi+rc.fi,ls.se+rc.se);
    }
}ST;

int tot,head[maxn],nx[maxn],to[maxn];
void add(int x,int y){to[++tot]=y;nx[tot]=head[x];head[x]=tot;}

struct Tree{            //树链剖分
    int cnt,id[maxn];
    int fa[maxn],sz[maxn],son[maxn],top[maxn],dep[maxn];
    void dfs1(int x){
        sz[x]=1;dep[x]=dep[fa[x]]+1;
        for(int i=head[x];i;i=nx[i]){
            int v=to[i];if(v==fa[x])continue;
            fa[v]=x;dfs1(v);sz[x]+=sz[v];
            if(sz[son[x]]<sz[v])son[x]=v;
        }
        return ;
    }
    void dfs2(int x){
        id[x]=++cnt;top[son[x]]=top[x];
        w[cnt]=a[x];        //根据情况添加
        if(!son[x])return ;
        dfs2(son[x]);
        for(int i=head[x];i;i=nx[i]){
            int v=to[i];if(v==son[x] || v==fa[x])continue;
            top[v]=v;dfs2(v);
        }
        return ;
    }

    void change(int x,int y,int z){
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            ST.modify(1,1,n,id[top[x]],id[x],z);
            x=fa[top[x]];
        }
        if(dep[x]>dep[y])swap(x,y);
        ST.modify(1,1,n,id[x],id[y],z);
        return ;
    }

    pa getans(int x,int y){
        pa ans;
        while(top[x]!=top[y]){
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            pa now=ST.query(1,1,n,id[top[x]],id[x]);
            ans.se+=now.se;ans.fi+=now.fi;
            ans.se%=MOD;ans.fi%=MOD;
           // cout<<ans.fi<<" "<<ans.se<<endl;
            x=fa[top[x]];
        }
        if(dep[x]>dep[y])swap(x,y);
        pa now=ST.query(1,1,n,id[x],id[y]);
        ans.se+=now.se;ans.fi+=now.fi;
        ans.se%=MOD;ans.fi%=MOD;
        return ans;
    }

}T;

int main(){
    n=read();m=read();
    for(int i=1;i<=n;i++)a[i]=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    T.dfs1(1);T.top[1]=1;T.dfs2(1);ST.build(1,1,n);
    while(m--){
        int opt=read(),u,v,val;
        if(opt==1){
            u=read();val=read();
            ST.modify(1,1,n,T.id[u],T.id[u]+T.sz[u]-1,val);
        }
        else if(opt==2){
            u=read(),v=read();val=read();
            T.change(u,v,val);
        }
        else {
            u=read();v=read();
            pa now=T.getans(u,v);
            ll ans=(now.fi*now.fi)%MOD-now.se;
            ans%=MOD;ans=ans*power(2,MOD-2)%MOD;
            printf("%lld\n",(ans%MOD+MOD)%MOD);
        }
    }
    return 0;
}

带顺序的信息维护

在暴力爬lca过程中,一般顺序都是u到lca和v到lca,两个顺序是相反的,对于矩阵相乘一些问题是不行的。
对于树链剖分后,带有从u到v的顺序关系的信息,维护起来相对麻烦一点,在于LCA“左侧”和LCA“右侧”的处理不同。
典型的例如树上的旅行经商问题,树上的矩阵链乘问题。
一般来说你全都遵循一个“左上,右下”的顺序关系比较好,避免给自己挖坑。

总结

我非常建议你把树链剖分理解成“结构树套数据结构”的这种“树套树”的理解。
原因是这对于后面学习动态树,乃至点分树这种巨型数据结构,在理解上是有好处的。

posted @ 2022-07-08 15:31  I_N_V  阅读(54)  评论(0编辑  收藏  举报