[模板]树链剖分

link

 树链剖分,感觉是一个很神奇的东西,但是其实并不是那样的

树链剖分其实就是一个线段树

线段树处理的是连续区间,所以当你要加的时候都是连续区间修改

所以可以用轻重链的方式将树分解成为链条,然后用线段树处理

可以很容易看到,为什么用的是dfs但不是用的是bfs呢

因为dfs保持了重链是连续的,所以可以用top[x]记录已x为节点的重链最上方,一个点也包含在重链内

若修改区间为(u,v),但是重链的祖先是一起的,所以当他们的LCA相同时,边break

所以现在u,v是连续的

所以查询(u,v)的简单路径和也就处理了

所以说线段树中可以进行的操作在树上也可以执行了

在处理一个问题

在u的子树上加w

所以修改的区间是u在线段树中的位置$(t)$ 到 $t+size(u)-1$

$size$ 记录以它为根 的子节点个数

$deep(x)$  深度

$father(x)$  记录父亲

$son(x)$  它的重儿子

$top(x)$ 所在重路径的顶部节点

$seg(x)$ x在线段树中的编号

$rev(x)$ 线段树中x的位置所对应的树中节点编号

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
inline int read()
{
    int f=1,ans=0;char c;
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
    return ans*f;
}
int n,val[1200001];
struct node{
    int u,v,nex;
}x[1200001];
int head[1200001],cnt;
int size[1200001];
int deep[1200001];
int father[1200001];
int son[1200001];
int top[1200001];
int seg[1200001];
int rev[1200001];
int q,root,mod;
void dfs1(int f,int fath)
{
    deep[f]=deep[fath]+1;
    father[f]=fath;
    size[f]=1;
    for(int i=head[f];i!=-1;i=x[i].nex)
    {
        if(x[i].v==fath) continue;
        dfs1(x[i].v,f);
        size[f]+=size[x[i].v];
        if(size[x[i].v]>size[son[f]]) son[f]=x[i].v;
    }
    return;
}
void dfs2(int f,int fath)
{
    if(son[f])
    {
        top[son[f]]=top[f];
        seg[son[f]]=++seg[0];
        rev[seg[0]]=son[f];
        dfs2(son[f],f);
    }
    for(int i=head[f];i!=-1;i=x[i].nex)
    {
        if(x[i].v==fath) continue;
        if(top[x[i].v]) continue;
        top[x[i].v]=x[i].v;
        seg[x[i].v]=++seg[0];
        rev[seg[0]]=x[i].v;
        dfs2(x[i].v,f);
    }
    return;
}
void add(int u,int v)
{
    x[cnt].u=u,x[cnt].v=v,x[cnt].nex=head[u],head[u]=cnt++;
}
int ans[1200001],sum[1200001];
void build(int k,int l,int r)
{
    if(l==r) 
    {
        ans[k]=val[rev[l]];
        return; 
    }
    int mid=l+r>>1;
    build(k<<1,l,mid);
    build(k<<1|1,mid+1,r);
    ans[k]=ans[k<<1]+ans[k<<1|1];
    return; 
}
void push_down(int k,int l,int r)
{
    int mid=l+r>>1;
    ans[k<<1]+=sum[k]*(mid-l+1);sum[k<<1]%=mod;
    sum[k<<1]+=sum[k];sum[k<<1]%=mod;
    
    ans[k<<1|1]+=sum[k]*(r-mid);ans[k<<1|1]%=mod;
    sum[k<<1|1]+=sum[k];sum[k<<1|1]%=mod;
    
    sum[k]=0;
    return; 
}
void add(int k,int l,int r,int x,int y,int v)
{
    if(x<=l&&r<=y){
        sum[k]+=v;
        sum[k]%=mod;
        ans[k]+=((r-l+1)%mod)*v%mod;
        ans[k]%=mod;
        return;
    } 
    push_down(k,l,r);
    int mid=l+r>>1;
    if(x<=mid) add(k<<1,l,mid,x,y,v);
    if(mid<y) add(k<<1|1,mid+1,r,x,y,v);
    ans[k]=ans[k<<1]+ans[k<<1|1];
    ans[k]%=mod;
}
void ask_add(int x,int y,int w)
{
    int fx=top[x],fy=top[y];
    while(fx!=fy)
    {
        if(deep[fx]<deep[fy]) swap(x,y),swap(fx,fy);
        add(1,1,seg[0],seg[fx],seg[x],w%mod); 
        x=father[fx],fx=top[x];
    }
    if(deep[x]>deep[y]) swap(x,y);
    add(1,1,seg[0],seg[x],seg[y],w);
}
int summ;
int query(int k,int l,int r,int x,int y)
{
    if(x<=l&&r<=y) return ans[k]%mod;
    push_down(k,l,r);
    int res=0,mid=l+r>>1;
    if(x<=mid) res+=query(k<<1,l,mid,x,y)%mod;
    if(mid<y) res+=query(k<<1|1,mid+1,r,x,y)%mod;
    return res;
}
int ask(int x,int y)
{
    summ=0;
    int fx=top[x],fy=top[y];
    while(fx!=fy)
    {
        if(deep[fx]<deep[fy]) swap(x,y),swap(fx,fy);
        summ+=query(1,1,seg[0],seg[fx],seg[x])%mod; 
        x=father[fx],fx=top[x];
    }
    if(deep[x]>deep[y]) swap(x,y);
    summ+=query(1,1,seg[0],seg[x],seg[y])%mod;
    return summ%mod;
}
int main()
{
    memset(head,-1,sizeof(head));
    n=read(),q=read(),root=read(),mod=read();
    for(int i=1;i<=n;i++) val[i]=read(); 
    for(int i=1;i<n;i++) 
    {
        int u=read(),v=read();
        add(u,v),add(v,u);
    }
    dfs1(root,0);
    seg[0]=1;seg[root]=1;
    top[root]=root;
    rev[1]=root;
    dfs2(root,0);
    build(1,1,seg[0]);
    while(q--)
    {
        int s=read();
        if(s==1) 
        {
            int u=read(),v=read();
            int w=read();
            ask_add(u,v,w%mod);
        }
        if(s==3)
        {
            summ=0;
            int u=read(),v=read();
            add(1,1,seg[0],seg[u],seg[u]+size[u]-1,v%mod);
        }
        if(s==2)
        {
            int u=read(),v=read();
            printf("%lld\n",ask(u,v)%mod);
        }
        if(s==4)
        {
            int u=read();
            printf("%lld\n",query(1,1,seg[0],seg[u],seg[u]+size[u]-1)%mod);
        }
    }
    return 0;
}
View Code
posted @ 2018-09-25 20:57  siruiyang_sry  阅读(173)  评论(0编辑  收藏  举报