把博客园图标替换成自己的图标
把博客园图标替换成自己的图标end

树链剖分原理、实现及例题

参考博文:
http://www.cnblogs.com/George1994/p/7821357.html

知识点

重结点:子树结点数目最多的结点;
轻节点:父亲节点中除了重结点以外的结点;
重边:父亲结点和重结点连成的边;
轻边:父亲节点和轻节点连成的边;
重链:由多条重边连接而成的路径;
轻链:由多条轻边连接而成的路径;

dfs1用来计算出一些上述的值
而dfs2则是从根节点开始,连重边成重链,以便于后面的线段树等数据结构的操作


例题 Query on a tree系列

Query on a tree
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define MAXN 50005
#define INF 0x3fffffff
#define lch rt<<1
#define rch rt<<1|1
int n,dcnt;
int val[MAXN],siz[MAXN],top[MAXN],son[MAXN];
int dep[MAXN],tid[MAXN],rnk[MAXN],fa[MAXN],belong[MAXN];
struct Tnode{
    int v,w,id;
    Tnode *next;
}Edge[MAXN*2];
Tnode *Adj[MAXN],*ecnt;
struct Snode{
    int mx;
}Seg[MAXN*4];

void Init()
{
    memset(Adj,0,sizeof(Adj));
    memset(son,-1,sizeof(son));
    dcnt=0;
    ecnt=&Edge[0]; 
}

void AddEdge(int u,int v,int w,int id)
{
    Tnode *p=++ecnt;
    p->v=v,p->w=w,p->id=id,p->next=Adj[u],Adj[u]=p;
    p=++ecnt;
    p->v=u,p->w=w,p->id=id,p->next=Adj[v],Adj[v]=p;
}

void dfs1(int u,int f,int d)
{
    dep[u]=d;
    fa[u]=f;
    siz[u]=1;
    for(Tnode *p=Adj[u];p!=NULL;p=p->next)
    {
        int v=p->v;
        int id=p->id;
        if(v==f) continue;
        belong[id]=v;
        val[v]=p->w;
        dfs1(v,u,d+1);
        siz[u]+=siz[v];
        if(son[u]==-1||siz[v]>siz[son[u]])
            son[u]=v;
    }
}
void dfs2(int u,int tp)
{
    top[u]=tp;
    tid[u]=++dcnt;
    rnk[dcnt]=u;
    if(son[u]==-1)
        return ;
    dfs2(son[u],tp);
    for(Tnode *p=Adj[u];p!=NULL;p=p->next)
    {
        int v=p->v;
        if(v==son[u]||v==fa[u])
            continue;
        dfs2(v,v);
    }
}
void Pushup(int rt)
{
    Seg[rt].mx=max(Seg[lch].mx,Seg[rch].mx);
}
void Build(int rt,int l,int r)
{
    if(l==r)
    {
        Seg[rt].mx=val[rnk[l]];
        return ;
    }
    int mid=(l+r)>>1;
    Build(lch,l,mid);
    Build(rch,mid+1,r);
    Pushup(rt);
}
void update(int rt,int l,int r,int pos,int val)
{
    if(l==r)
    {
        Seg[rt].mx=val;
        return ;
    }
    int mid=(l+r)>>1;
    if(pos<=mid)
        update(lch,l,mid,pos,val);
    else update(rch,mid+1,r,pos,val);
    Pushup(rt);
}
int Query(int rt,int l,int r,int st,int ed)
{
    if(st<=l&&r<=ed)
        return Seg[rt].mx;
    int mid=(l+r)>>1;
    int ret=-INF;
    if(st<=mid)
        ret=max(ret,Query(lch,l,mid,st,ed));
    if(ed>mid)
        ret=max(ret,Query(rch,mid+1,r,st,ed));
    return ret;
}
void Change(int i,int val)
{
    int u=belong[i];
    update(1,2,n,tid[u],val);
}
int Ask(int u,int v)
{
    int ans=-INF;
    while(top[u]!=top[v])
    {
        if(dep[top[u]]<dep[top[v]])
            swap(u,v);
        ans=max(ans,Query(1,2,n,tid[top[u]],tid[u]));
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])
        swap(u,v);
    if(u!=v)
        ans=max(ans,Query(1,2,n,tid[u]+1,tid[v]));
    return ans;
}
int main()
{
    int cs;
    scanf("%d",&cs);
    while(cs--)
    {

        Init();
        scanf("%d",&n);
        for(int i=1;i<n;i++)
        {
            int u,v,w;
            scanf("%d %d %d",&u,&v,&w);
            AddEdge(u,v,w,i);
        }
        dfs1(1,1,1);
        dfs2(1,1);
        Build(1,2,n);
        char s[50];
        int u,v;
        while(1)
        {
            scanf("%s",s);
            if(s[0]=='D')
                break;
            scanf("%d %d",&u,&v);
            if(s[0]=='Q')
                printf("%d\n",Ask(u,v));
            else Change(u,v);
        }
    }
}
Query on a tree II
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define MAXN 10010
#define INF 0x3fffffff
#define lch rt<<1
#define rch rt<<1|1
struct Tnode{
    int v,w,id;
    Tnode *next;
}edge[MAXN*2];
Tnode *Adj[MAXN],*ecnt;
struct Snode{
    int sum;
}Seg[MAXN*4];
int n,k,dcnt;
int val[MAXN],siz[MAXN],top[MAXN],son[MAXN],dep[MAXN],tid[MAXN],rnk[MAXN],fa[MAXN];
int belong[MAXN];

void Init()
{
    memset(Adj,0,sizeof(Adj));
    memset(son,-1,sizeof(son));
    dcnt=0;
    ecnt=&edge[0]; 
}
void AddEdge(int u,int v,int w,int id)
{
    Tnode *p=++ecnt;
    p->v=v,p->w=w,p->id=id,p->next=Adj[u],Adj[u]=p;
    p=++ecnt;
    p->v=u,p->w=w,p->id=id,p->next=Adj[v],Adj[v]=p;
}
void dfs1(int u,int f,int d)
{
    dep[u]=d;
    fa[u]=f;
    siz[u]=1;
    for(Tnode *p=Adj[u];p!=NULL;p=p->next)
    {
        int v=p->v,id=p->id;
        if(v==f) continue;
        belong[id]=v;
        val[v]=p->w;
        dfs1(v,u,d+1);
        siz[u]+=siz[v];
        if(son[u]==-1||siz[v]>siz[son[u]])
            son[u]=v;
    }
}
void dfs2(int u,int tp)
{
    top[u]=tp;
    tid[u]=++dcnt;
    rnk[dcnt]=u;
    if(son[u]==-1)
        return ;
    dfs2(son[u],tp);
    for(Tnode *p=Adj[u];p!=NULL;p=p->next)
    {
        int v=p->v;
        if(v==fa[u]||v==son[u])
            continue;
        dfs2(v,v);
    }
}
void Pushup(int rt)
{
    Seg[rt].sum=Seg[lch].sum+Seg[rch].sum;
}
void Build(int rt,int l,int r)
{
    if(l==r)
    {
        Seg[rt].sum=val[rnk[l]];
        return ;
    }
    int mid=(l+r)>>1;
    Build(lch,l,mid);
    Build(rch,mid+1,r);
    Pushup(rt);
}
int Query(int rt,int l,int r,int st,int ed)
{
    if(st<=l&&r<=ed)
        return Seg[rt].sum;
    int mid=(l+r)>>1;
    int ret=0;
    if(st<=mid)
        ret+=Query(lch,l,mid,st,ed);
    if(ed>mid)
        ret+=Query(rch,mid+1,r,st,ed);
    return ret;
}
int Ask(int u,int v,bool typ)
{
    int ans=0,tu=u,tv=v;
    while(top[u]!=top[v])
    {
        if(dep[top[u]]<dep[top[v]])
            swap(u,v);
        ans+=Query(1,2,n,tid[top[u]],tid[u]);
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])
        swap(u,v);
    int lca=u;
    if(u!=v)
        ans+=Query(1,2,n,tid[u]+1,tid[v]);
    if(!typ)
        return ans;
    if(dep[tu]-dep[lca]+1<k)
    {
        k-=dep[tu]-dep[lca]+1;
        k=dep[tv]-dep[lca]-k+1;
        tu=tv;
    }
    while(dep[top[tu]]>dep[lca])
    {
        int dis=dep[tu]-dep[top[tu]]+1;
        if(dis>=k) break;
        k-=dis;
        tu=fa[top[tu]];
    }
    return rnk[tid[tu]-k+1];
}
int main()
{
    int cs;
    char s[10];
    scanf("%d",&cs);
    while(cs--)
    {
        Init();
        scanf("%d",&n);
        for(int i=1;i<n;i++)
        {
            int u,v,w;
            scanf("%d %d %d",&u,&v,&w);
            AddEdge(u,v,w,i);
        }
        dfs1(1,1,1);
        dfs2(1,1);
        Build(1,2,n);
        while(~scanf("%s",s))
        {
            if(s[1]=='O')
                break;
            int u,v;
            scanf("%d %d",&u,&v);
            if(s[1]=='I')
                printf("%d\n",Ask(u,v,0));
            else 
            {
                scanf("%d",&k);
                printf("%d\n",Ask(u,v,1));
            }
        }
    }
    return 0;
}
Query on a tree again!
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define MAXN 100005
#define INF 0x3fffffff
#define lch rt<<1
#define rch rt<<1|1
int n,m,dcnt;
int val[MAXN],siz[MAXN],top[MAXN],son[MAXN];
int dep[MAXN],tid[MAXN],rnk[MAXN],fa[MAXN];
struct Tnode{
    int v;
    Tnode *next;
}Edge[MAXN*2];
Tnode *Adj[MAXN],*ecnt;
struct Snode{
    int sum;
}Seg[MAXN*4];
void Init()
{
    memset(Adj,0,sizeof(Adj));
    memset(son,-1,sizeof(son));
    dcnt=0;
    ecnt=&Edge[0]; 
}
void AddEdge(int u,int v)
{
    Tnode *p=++ecnt;
    p->v=v,p->next=Adj[u],Adj[u]=p;
    p=++ecnt;
    p->v=u,p->next=Adj[v],Adj[v]=p;
}
//--------------------------------------------------------
void dfs1(int u,int f,int d)
{
    dep[u]=d;
    fa[u]=f;
    siz[u]=1;
    for(Tnode *p=Adj[u];p!=NULL;p=p->next)
    {
        int v=p->v;
        if(v==f) continue;
        dfs1(v,u,d+1);
        siz[u]+=siz[v];
        if(son[u]==-1||siz[v]>siz[son[u]])
            son[u]=v;
    }
}
void dfs2(int u,int tp)
{
    top[u]=tp;
    tid[u]=++dcnt;
    rnk[dcnt]=u;
    if(son[u]==-1)
        return ;
    dfs2(son[u],tp);
    for(Tnode *p=Adj[u];p!=NULL;p=p->next)
    {
        int v=p->v;
        if(v==son[u]||v==fa[u])
            continue;
        dfs2(v,v);
    }
}

//---------------------------------------------------------------
void pushup(int rt)
{
    Seg[rt].sum=Seg[lch].sum+Seg[rch].sum;
}
void build(int rt,int l,int r)
{
    if(l==r)
    {
        Seg[rt].sum=0;
        return ;
    }
    int mid=(l+r)>>1;
    build(lch,l,mid);
    build(rch,mid+1,r);
    pushup(rt);
}
void update(int rt,int l,int r,int pos)
{
    if(l==r)
    {
        Seg[rt].sum^=1;
        return ;
    }
    int mid=(l+r)>>1;
    if(pos<=mid)
        update(lch,l,mid,pos);
    else update(rch,mid+1,r,pos);
    pushup(rt);
}
int ask(int rt,int l,int r)
{
    if(!Seg[rt].sum) return -1;
    if(l==r) return l;
    int mid=(l+r)>>1;
    if(!Seg[lch].sum) return ask(rch,mid+1,r);//优先左区间,左区间没有返回右区间
    else return ask(lch,l,mid);
}
int query(int rt,int l,int r,int st,int ed)//查找路径上的第一个黑点的tid
{
    if(!Seg[rt].sum) return -1;//没有黑点
    if(l==st&&r==ed)
        return ask(rt,l,r);
    int mid=(l+r)>>1;
    if(ed<=mid) return query(lch,l,mid,st,ed);
    else if(st>mid) return query(rch,mid+1,r,st,ed);
    else//横跨左右儿子
    {
        int res=query(lch,l,mid,st,mid);//tid在线段树中有序,优先左区间,左区间没有才找右区间
        if(res==-1) return query(rch,mid+1,r,mid+1,ed);
        else return res;
    }
}
int sol(int u)
{
    int res=-1,t;
    while(top[u]!=1)//目标:走到节点1所在重链,就能顺着该重链到节点1
    {
        t=query(1,1,n,tid[top[u]],tid[u]);//沿着重链向上爬,查找这一条重链上的第一个黑点
        if(t!=-1) res=rnk[t];//越往上,答案越优,所以不断更新
        u=fa[top[u]];//走一条轻边,爬另一条重链
    }
    if(u==1) t=query(1,1,n,1,1);
    else t=query(1,1,n,1,tid[u]);
    if(t!=-1) res=rnk[t];
    return res;
}
int main()
{
    Init();
    scanf("%d %d",&n,&m);
    int u,v;
    for(int i=1;i<=n-1;i++)
    {
        scanf("%d %d",&u,&v);
        AddEdge(u,v);
    }
    dfs1(1,1,1);
    dfs2(1,1);
    build(1,1,n);
    int opt;
    for(int i=1;i<=m;i++)
    {
        scanf("%d %d",&opt,&u);
        if(opt) printf("%d\n",sol(u));
        else update(1,1,n,tid[u]);
    }
}
Query on a tree IV
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
using namespace std;
#define MAXN 100010
#define INF 0x3f3f3f3f
int n,m,dcnt;
int tid[MAXN],top[MAXN],rnk[MAXN],sz[MAXN],fa[MAXN],son[MAXN],fav[MAXN];
int presum[MAXN*4],cnt,col[MAXN],root[MAXN];
multiset<int> ch[MAXN],ansx;
bool flagx[MAXN*4];
struct node{
    int pl,pr,v,maxl,maxr;
}tree[MAXN*4];
void read(int& x)
{
    char c;
    bool flag=0;
    while(c=getchar(),c!=EOF&&(c<'0'||c>'9')&&c!='-');
    if(c=='-'){flag=1;c=getchar();}
    x=c-'0';
    while(c=getchar(),c!=EOF&&c>='0'&&c<='9')   x=x*10+c-'0';
    if(flag)    x=-x;
}
vector<int> a[MAXN],va[MAXN];
//-----------------------------------------------------
int Query(int x,int y)
{
    if(x>y) swap(x,y);
    return presum[y]-presum[x];
}
void dfs1(int x,int fax)
{
    fa[x]=fax;
    sz[x]=1;
    son[x]=-1;
    for(int i=0;i<a[x].size();i++)
        if(a[x][i]!=fax)
        {
            dfs1(a[x][i],x);
            sz[x]+=sz[a[x][i]];
            if(son[x]==-1||sz[son[x]]<sz[a[x][i]])
                son[x]=a[x][i];
        }
        else fav[x]=va[x][i];//顶点x与它父亲的边权
}
void dfs2(int x,int tp)
{
    top[x]=tp;
    tid[x]=++dcnt;
    sz[tp]++;
    if(x!=tp)
        presum[dcnt]=presum[tid[fa[x]]]+fav[x];
    rnk[dcnt]=x;
    if(son[x]==-1) return ;
    dfs2(son[x],tp);
    for(int i=0;i<a[x].size();i++)
        if(a[x][i]!=son[x]&&a[x][i]!=fa[x])
            dfs2(a[x][i],a[x][i]);
}
void Merge(int u,int x,int y,int l,int r)
{
    int mid=(l+r)>>1;
    flagx[u]=1;
    tree[u].maxl=max(tree[y].maxl+Query(l,mid+1),tree[x].maxl);//以左端点出发(含),向右的最大
    tree[u].maxr=max(tree[x].maxr+Query(mid,r),tree[y].maxr);//以右端点出发(含),向左的最大
    tree[u].v=max(tree[x].maxr+tree[y].maxl+Query(mid,mid+1)/*跨越左右儿子*/,max(tree[x].v/*在左儿子*/,tree[y].v/*在右儿子*/));
}
void update(int x,int id)
{
    int d1=-INF,d2=-INF;
    flagx[id]=1;
    if(ch[x].size()!=0)
        d1=*(ch[x].rbegin());//按从小到大排序 rbegin()是最大 由于是迭代器 *表示取值
    if(ch[x].size()>1)
        d2=*(++ch[x].rbegin());//d1,d2取出最长的链和次长的链
    if(col[x]==0)
    {
        tree[id].maxl=max(d1,0);
        tree[id].maxr=max(d1,0);
        tree[id].v=max(0,max(d1,d1+d2));
    }
    else
    {
        tree[id].maxl=d1;
        tree[id].maxr=d1;
        tree[id].v=max(-INF,d1+d2);
    }
}
void build(int l,int r,int id,bool flag)
{
    if(l==r)
    {
        int x=rnk[l];
        for(int i=0;i<a[x].size();i++)
        {
            int y=a[x][i];
            if(y!=fa[x]&&top[y]!=top[x])
            {
                root[y]=++cnt;// 每一条链都建了一颗线段树
                build(tid[y],tid[y]+sz[y]-1,root[y],1);
                ch[x].insert(tree[root[y]].maxl+va[x][i]);//上传信息
            }
        }
        update(x,id);
        if(flag==1)
            ansx.insert(tree[id].v);
    }
    else
    {
        int mid=(l+r)>>1;
        tree[id].pl=++cnt;
        tree[id].pr=++cnt;
        build(l,mid,tree[id].pl,0);
        build(mid+1,r,tree[id].pr,0);
        Merge(id,tree[id].pl,tree[id].pr,l,r);
        if(flag==1)
            ansx.insert(tree[id].v);
    }
}
vector<int> path,pathv;
void find_path(int x)
{
    while(x!=0)
    {
        path.push_back(x);
        pathv.push_back(fav[top[x]]);
        x=fa[top[x]];
    }
}
void modify(int l,int r,int id,int i,bool flag)
{
    if(l==r)
    {
        int x=rnk[l];
        if(i!=0)
        {
            int ne=top[path[i-1]];
            ch[x].erase(ch[x].find(tree[root[ne]].maxl+pathv[i-1]));
            modify(tid[ne],tid[ne]+sz[ne]-1,root[ne],i-1,1);
            ch[x].insert(tree[root[ne]].maxl+pathv[i-1]);
        }
        if(flag)
            ansx.erase(ansx.find(tree[id].v));
        update(x,id);
        if(flag)
            ansx.insert(tree[id].v);
    }
    else
    {
        int mid=(l+r)>>1;
        if(tid[path[i]]<=mid)
            modify(l,mid,tree[id].pl,i,0);
        else    modify(mid+1,r,tree[id].pr,i,0);
        if(flag)
            ansx.erase(ansx.find(tree[id].v));
        Merge(id,tree[id].pl,tree[id].pr,l,r);
        if(flag)
            ansx.insert(tree[id].v);
    }
}
//-----------------------------------------------------
int u,v,q,sum,val;
char t[10];
int main()
{
    scanf("%d",&n);
    sum=0;
    for(int i=1;i<n;i++)
    {
        read(u),read(v),read(val);
        a[u].push_back(v);
        a[v].push_back(u);
        va[u].push_back(val);
        va[v].push_back(val);
    }
    dfs1(1,0);
    memset(sz,0,sizeof(sz));
    dfs2(1,1);
    root[1]=1;
    cnt=1;
    build(tid[1],tid[1]+sz[1]-1,1,1);
    scanf("%d",&q);
    for(int i=1;i<=q;i++)
    {
        scanf("%s",t);
        if(t[0]=='C')
        {
            read(u);
            if(!col[u]) 
                sum++;
            else sum--;
            col[u]^=1;
            path.clear();
            pathv.clear();
            find_path(u);
            modify(tid[1],tid[1]+sz[1]-1,root[1],path.size()-1,1);
        }
        else
        {
            if(sum==n)
                printf("They have disappeared.\n");
            else if(sum==n-1)
                printf("0\n");
            else 
                printf("%d\n",*ansx.rbegin());
        }
    }
}
一句话总结:

树链剖分就是将树形结构转化为线性结构,通过将树划分成许多条链然后再利用线段树等数据结构进行维护,后续操作与树本身并没有什么关系,即对原树不再进行修改。

posted @ 2018-08-09 12:24  Starlight_Glimmer  阅读(6)  评论(0编辑  收藏  举报  来源
浏览器标题切换
浏览器标题切换end