树链剖分学习

        近期学习树链剖分,在网上搜了非常多文章,这里推荐一篇博客:点击打开链接

        这篇博客讲的非常细致。在了解了基本原理之后,能够学习一下kuangbin大大的模板:


定义部分:

struct Edge
{
    int to,next;
}edge[2*maxn];//树的边

int head[maxn],tot;//邻接表
int top[maxn];//节点所在重链的最高点
int fa[maxn];//节点的父亲节点
int deep[maxn];//节点的深度
int num[maxn];//节点子树的节点数
int p[maxn];//p与父节点间连线在线段树中的位置
int fp[maxn];//线段树中fp位置相应的的节点
int tson[maxn];//重儿子
int pos;//线段树中的位置


初始化部分:

void init()   //初始化操作
{
    tot=1;
    memset(head,-1,sizeof(head));
    pos=1;
    memset(tson,-1,sizeof(tson));
}


邻接表的加边操作:

void addedge(int u,int v)     //邻接表的加边操作
{
    edge[tot].to=v;edge[tot].next=head[u];head[u]=tot++;
}


以下就是两个重点操作:1.求fa,deep,son,num   2.求top,p

void  dfs(int u,int pre,int d) //求fa,deep,son,num
{
    fa[u]=pre;
    deep[u]=d;
    num[u]=1;
    for(int i=head[u];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if(v!=pre)
        {
            dfs(v,u,d+1);
            num[u]+=num[v];
            if(tson[u]==-1||num[tson[u]]<num[v])
            tson[u]=v;
        }
    }
}

void getpos(int u,int sp)//求top,p
{
    top[u]=sp;
    p[u]=pos++;
    fp[p[u]]=u;
    if(tson[u]==-1) return;
    getpos(tson[u],sp);
    for(int i=head[u];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if(v!=tson[u]&&v!=fa[u])
          getpos(v,v);
    }
}


以下这个是依据题目而订的操作,就是訪问u->v链,并运行操作:

void Change(int u,int v,int value)  //区间更新
{
    int f1=top[u],f2=top[v];
    while(f1!=f2)
    {
        if(deep[f1]<deep[f2])
        {
            swap(f1,f2);
            swap(u,v);
        }
        //详细操作
        u=fa[f1];f1=top[u];
    }
    if(deep[v]<deep[u])
    {
        swap(u,v);
    }
      //详细操作
}

注意:树链剖分能够分为两种,一种是边权处理,一种是点权处理。点权处理没什么好说的,边权处理时,我们用儿子节点代表该节点与父节点之间的边的权值。


hdu 3966 Aragorn's Story(点权)

这道题使用的是树链剖分+树状数组,注意查询t节点时,查询的是p[t](常常忘记,WA了好几次 = =。)


#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <set>
#include <map>
#include <queue>
#include <string>
#define maxn 50010
using namespace std;
typedef long long  ll;

struct Edge
{
    int to,next;
}edge[2*maxn];//树的边

int head[maxn],tot;//邻接表
int top[maxn];//节点所在重链的最高点
int fa[maxn];//节点的父亲节点
int deep[maxn];//节点的深度
int num[maxn];//节点子树的节点数
int p[maxn];//p与父节点间连线在线段树中的位置
int fp[maxn];//线段树中fp位置相应的的节点
int tson[maxn];//重儿子
int pos;//线段树中的位置


void init()   //初始化操作
{
    tot=1;
    memset(head,-1,sizeof(head));
    pos=1;
    memset(tson,-1,sizeof(tson));
}

void addedge(int u,int v)     //邻接表的加边操作
{
    edge[tot].to=v;edge[tot].next=head[u];head[u]=tot++;
}

void  dfs(int u,int pre,int d) //求fa,deep,son,num
{
    fa[u]=pre;
    deep[u]=d;
    num[u]=1;
    for(int i=head[u];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if(v!=pre)
        {
            dfs(v,u,d+1);
            num[u]+=num[v];
            if(tson[u]==-1||num[tson[u]]<num[v])
            tson[u]=v;
        }
    }
}

void getpos(int u,int sp)//求top,p
{
    top[u]=sp;
    p[u]=pos++;
    fp[p[u]]=u;
    if(tson[u]==-1) return;
    getpos(tson[u],sp);
    for(int i=head[u];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if(v!=tson[u]&&v!=fa[u])
          getpos(v,v);
    }
}

int son[maxn];
int n,m,k;

int lowbit(int x)
{
    return x&(-x);
}

void add(int d,int value)
{
    while(d<=n)
    {
        son[d]+=value;
        d+=lowbit(d);
    }
}

int Query(int d)
{
    int sum=0;
    while(d>0)
    {
        sum+=son[d];
        d-=lowbit(d);
    }
    return sum;
}


void Change(int u,int v,int value)  //区间更新
{
    int f1=top[u],f2=top[v];
    while(f1!=f2)
    {
        if(deep[f1]<deep[f2])
        {
            swap(f1,f2);
            swap(u,v);
        }
        add(p[f1],value);
        add(p[u]+1,-value);
        u=fa[f1];f1=top[u];
    }
    if(deep[v]<deep[u])
    {
        swap(u,v);
    }
    add(p[u],value);
    add(p[v]+1,-value);
}
int v[maxn];
int main()
{
    char s[10];
    int l,r,t;
    while(scanf("%d%d%d",&n,&m,&k)!=EOF)
    {
        init();
        memset(son,0,sizeof(son));
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&v[i]);
        }
        for(int i=1;i<=m;i++)
        {
            scanf("%d%d",&l,&r);
            addedge(l,r);
            addedge(r,l);
        }
        dfs(1,0,0);
        getpos(1,1);
        for(int i=1;i<=n;i++)
        {
            add(p[i],v[i]);
            add(p[i]+1,-v[i]);
        }
        while(k--)
        {
            scanf("%s",s);
            if(s[0]=='Q')
            {
                scanf("%d",&t);
                printf("%d\n",Query(p[t]));   //就是这里,不是t
            }
            else
            {
                scanf("%d%d%d",&l,&r,&t);
                if(s[0]=='D')t=-t;
                Change(l,r,t);
            }

        }
    }
    return 0;
}


SPOJ QTREE(边权) 树链剖分+线段树


#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <set>
#include <map>
#include <queue>
#include <string>
#define maxn 10010
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
using namespace std;
typedef long long  ll;

struct Edge
{
    int to,next;
}edge[2*maxn];//树的边

int head[maxn],tot;//邻接表
int top[maxn];//节点所在重链的最高点
int fa[maxn];//节点的父亲节点
int deep[maxn];//节点的深度
int num[maxn];//节点子树的节点数
int p[maxn];//p与父节点间连线在线段树中的位置
int fp[maxn];//线段树中fp位置相应的的节点
int tson[maxn];//重儿子
int pos;//线段树中的位置


void init()   //初始化操作
{
    tot=1;
    memset(head,-1,sizeof(head));
    pos=1;
    memset(tson,-1,sizeof(tson));
}

void addedge(int u,int v)     //邻接表的加边操作
{
    edge[tot].to=v;edge[tot].next=head[u];head[u]=tot++;
}

void  dfs(int u,int pre,int d) //求fa,deep,son,num
{
    fa[u]=pre;
    deep[u]=d;
    num[u]=1;
    for(int i=head[u];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if(v!=pre)
        {
            dfs(v,u,d+1);
            num[u]+=num[v];
            if(tson[u]==-1||num[tson[u]]<num[v])
            tson[u]=v;
        }
    }
}

void getpos(int u,int sp)//求top,p
{
    top[u]=sp;
    p[u]=pos++;
    fp[p[u]]=u;
    if(tson[u]==-1) return;
    getpos(tson[u],sp);
    for(int i=head[u];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if(v!=tson[u]&&v!=fa[u])
          getpos(v,v);
    }
}


//以下是线段树模板了,不介绍了
struct segment
{
    int l,r;
    int value;
    int nv;
} son[maxn<<2];


void PushUp(int rt)
{
    son[rt].value=max(son[rt<<1].value,son[rt<<1|1].value);
}


void PushDown(int rt)
{
    if(son[rt].nv)
    {
        son[rt<<1].nv=son[rt].nv;
        son[rt<<1|1].nv=son[rt].nv;
        son[rt<<1].value=son[rt<<1].nv;
        son[rt<<1|1].value=son[rt<<1|1].nv;
        son[rt].nv=0;
    }
}


void Build(int l,int r,int rt)
{
    son[rt].l=l;
    son[rt].r=r;
    if(l==r)
    {
        son[rt].value=1;
        return;
    }
    int m=(l+r)/2;
    Build(lson);
    Build(rson);
    PushUp(rt);
}

//线段树单点更新
void Update_1(int p,int value,int rt)
{
    if(son[rt].l==son[rt].r)
    {
        son[rt].value=value;
        return;
    }

    //PushDown(rt);

    int m=(son[rt].l+son[rt].r)/2;
    if(p<=m)
        Update_1(p,value,rt<<1);
    else
        Update_1(p,value,rt<<1|1);

    PushUp(rt);
}

//线段树区间更新
void Update_n(int w,int l,int r,int rt)
{
    if(son[rt].l==l&&son[rt].r==r)
    {
        son[rt].value+=w*(r-l+1);
        son[rt].nv+=w;
        return;
    }

    PushDown(rt);

    int m=(son[rt].l+son[rt].r)/2;

    if(r<=m)
        Update_n(w,l,r,rt<<1);
    else if(l>m)
        Update_n(w,l,r,rt<<1|1);
    else
    {
        Update_n(w,lson);
        Update_n(w,rson);
    }
    PushUp(rt);
}


int  Query(int l,int r,int rt)
{
    if(son[rt].l==l&&son[rt].r==r)
    {
        return son[rt].value;
    }

    //PushDown(rt);

    int ret=0;
    int m=(son[rt].l+son[rt].r)/2;

    if(r<=m)
        ret=Query(l,r,rt<<1);
    else if(l>m)
        ret=Query(l,r,rt<<1|1);
    else
    {
        ret=Query(lson);
        ret=max(Query(rson),ret);
    }
    return ret;
}

int find(int u,int v)  //寻找u->v的链上的最值
{                            //主要思想就是比較u和v的是否在同一条重链上
    int f1=top[u],f2=top[v];//若在同一条重链上(f1=f2),直接求就能够了
    int tmp=0;              //由于同一条重链上的点在线段树中是连续的
    while(f1!=f2)           //若不在同一条重链上。那么我们比較两个点的深度
    {                       //一直向上寻找,直到连个点在同一条重链上
        if(deep[f1]<deep[f2])
        {
            swap(f1,f2);
            swap(u,v);
        }
        tmp=max(tmp,Query(p[f1],p[u],1));
        u=fa[f1];f1=top[u];
    }
    if(u==v) return tmp;
    if(deep[u]<deep[v])
    {
        swap(u,v);
    }
    return max(tmp,Query(p[tson[v]],p[u],1));
}

int e[maxn][3];
int main()
{
    char s[10];
    int cas,l,r;
    scanf("%d",&cas);
    while(cas--)
    {
        int n;
        scanf("%d",&n);
        init();
        for(int i=1;i<n;i++)
        {
            scanf("%d%d%d",&e[i][0],&e[i][1],&e[i][2]);
            addedge(e[i][0],e[i][1]);
            addedge(e[i][1],e[i][0]);
        }
        dfs(1,1,0);
        getpos(1,1);
        Build(1,pos-1,1);
        for(int i=1;i<n;i++)
        {
            if(deep[e[i][0]]<deep[e[i][1]])
                swap(e[i][0],e[i][1]);
            Update_1(p[e[i][0]],e[i][2],1);
        }
        do
        {
            scanf("%s",s);
            if(s[0]=='D') break;
            scanf("%d%d",&l,&r);
            if(s[0]=='Q') printf("%d\n",find(l,r));
            else Update_1(p[e[l][0]],r,1);
        }while(1);
    }
    return 0;
}

 








posted @ 2016-02-18 18:59  phlsheji  阅读(194)  评论(0编辑  收藏  举报