zoj 3649 lca与倍增dp

参考:http://www.xuebuyuan.com/609502.html

先说题意:

       给出一幅图,求最大生成树,并在这棵树上进行查询操作:给出两个结点编号x和y,求从x到y的路径上,由每个结点的权值构成的序列中的极差大小——要求,被减数要在减数的后面,即形成序列{a1,a2…aj …ak…an},求ak-aj (k>=j)的最大值。

 

求路径,显然用到lca。

太孤陋寡闻,才知道原来倍增dp能用来求LCA。

用p[u][i]表示结点u的第1<< i 个祖先结点,则有递推如下:

for(int i=0;i<POW;i++)  p[u][i]=p[p[u][i-1]][i-1]。

在对图dfs的时候即完成递推。

要想求两个结点的lca,首先使得两结点高度相同,若二者的父亲结点不同,则一直向上查找。dep数组表示结点的深度。

int LCA(int a,int b){

       if(dep[a]>dep[b]) swap(a,b);

       if(dep[a]<dep[b]){

              //这一部分使得dep[a]==dep[b]

              int tmp=dep[b]-dep[a];

              for(int i=0;i<POW;i++) if(tmp&(1<<i))

              //这里从POW-1到0来遍历也是一样的

                     b=p[b][i];

       }

       if(a!=b){

              for(int i=POW-1;i>=0;i--) if(p[a][i]!=p[b][i])

                     a=p[a][i],b=p[b][i];

              a=p[a][0],b=p[b][0];

       }

       return a;

}

如此即返回结点的lca。

用倍增遍历的思路:

因为一段路被二进制分成了一截一截,或者说路径长度被用二进制表示了出来。而两个结点的深度差即为“路径长度”,所以只要tmp&(1<<i),则表示这是“路径”的其中一个结点,以此类推,从而得到两个深度相同的结点。

 

有了这个基础之后,用相同的方式构建——

mx数组,mx[u][i]表示从u到其第1<<i个祖先结点路径上的最大值

mn数组,mn[u][i]表示从u到其第1<<i个祖先结点路径上的最小值

dp数组,dp[u][i],表示从u到其第1<<i个祖先结点路径上的最大差值

dp2数组,dp2[u][i],表示从其第1<<i个祖先结点到u路径上的最大差值

 

构建好后是查询部分。给出结点x和y,获得lca。

则路径被分成两段—— x->lca->y。则有三种可能性:

x到lca上的最大差值;lca到y上的最大差值;x到y上的最大差值(即lca到y的最大值减去x到lca的最小值)。比较一下即可。

 

这题真心涨姿势。代码如下:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=3e4+10,M=N<<1,POW=16,inf=21e8;
int mx[N][POW],mn[N][POW],p[N][POW],dp[N][POW],dp2[N][POW];
int head[N],nxt[M],to[M],cnt,val[N],vis[N],dep[N];
int n,m,q,fa[N];
struct Edge{
    int u,v,w;
    bool operator < (const Edge e) const{
        return w>e.w;
    }
}E[M];
void ini(int n){
    memset(head,-1,sizeof(head));
    cnt=0;
    memset(vis,0,sizeof(vis));
    fill(p[0],p[n+1],0);
    fill(mx[0],mx[n+1],-inf);
    fill(mn[0],mn[n+1],inf);
    fill(dp[0],dp[n+1],-inf);
    fill(dp2[0],dp2[n+1],-inf);
    dep[0]=0;
}
int find_(int x){
    return x==fa[x]?x:fa[x]=find_(fa[x]);
}
void addedge(int u,int v){
    to[cnt]=v;
    nxt[cnt]=head[u];
    head[u]=cnt++;
}
int Kruskal(){
    for(int i=0;i<=n;i++) fa[i]=i;
    sort(E,E+m);
    int sum=0;
    for(int i=0;i<m;i++){
        int a=find_(E[i].u),b=find_(E[i].v);
        if(a!=b){
            fa[a]=b;
            addedge(E[i].u,E[i].v);
            addedge(E[i].v,E[i].u);
            sum+=E[i].w;
        }
    }
    return sum;
}
void dfs(int u,int f){
    dep[u]=dep[f]+1;
    vis[u]=1;
    for(int i=head[u];~i;i=nxt[i]) if(!vis[to[i]]){
        int v=to[i];
        p[v][0]=u;
        mx[v][0]=max(val[u],val[v]);
        mn[v][0]=min(val[u],val[v]);
        dp[v][0]=val[u]-val[v];
        dp2[v][0]=val[v]-val[u];
        for(int j=1;j<POW;j++){
            p[v][j]=p[p[v][j-1]][j-1];
            mx[v][j]=max(mx[v][j-1],mx[p[v][j-1]][j-1]);
            mn[v][j]=min(mn[v][j-1],mn[p[v][j-1]][j-1]);
            
            dp[v][j]=max(dp[v][j-1],dp[p[v][j-1]][j-1]);
            dp[v][j]=max(dp[v][j],mx[p[v][j-1]][j-1]-mn[v][j-1]);

            dp2[v][j]=max(dp2[v][j-1],dp2[p[v][j-1]][j-1]);
            dp2[v][j]=max(dp2[v][j],mx[v][j-1]-mn[p[v][j-1]][j-1]);
        }
        dfs(v,u);
    }
}
int LCA(int a,int b){
    //第一次看到这样的LCA,holy high
    //有点不明觉厉
    if(dep[a]>dep[b]) swap(a,b);
    if(dep[a]<dep[b]){
        //这一部分使得dep[a]==dep[b]
        int tmp=dep[b]-dep[a];
        for(int i=POW-1;i>=0;i--) if(tmp&(1<<i))
            b=p[b][i];
    }
    if(a!=b){
        //如果高度相等,而a!=b
        for(int i=POW-1;i>=0;i--) if(p[a][i]!=p[b][i])
            a=p[a][i],b=p[b][i];
        a=p[a][0],b=p[b][0];
    }
    return a;
}
int getmax(int x,int lca){
    int ans=0,tmp=dep[x]-dep[lca];
    for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)){
        ans=max(ans,mx[x][i]);
        x=p[x][i];
    }
    return ans;
}
int getmin(int x,int lca){
    int ans=inf,tmp=dep[x]-dep[lca];
    for(int  i=POW-1;i>=0;i--) if(tmp&(1<<i)){
        ans=min(ans,mn[x][i]);
        x=p[x][i];
    }
    return ans;
}
int getleft(int x,int lca){
    int ans=0,minn=inf;
    int tmp=dep[x]-dep[lca];
    for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)){
        ans=max(ans,dp[x][i]);
        ans=max(ans,mx[x][i]-minn);
        minn=min(minn,mn[x][i]);
        x=p[x][i];
    }
    return ans;
}
int getright(int x,int lca){
    int ans=0,maxx=0;
    int tmp=dep[x]-dep[lca];
    for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)){
        ans=max(ans,dp2[x][i]);
        ans=max(ans,maxx-mn[x][i]);
        maxx=max(maxx,mx[x][i]);
        x=p[x][i];
    }
    return ans;
}
int main(){
    freopen("in.txt","r",stdin);
    while(~scanf("%d",&n)){
        for(int i=1;i<=n;i++)
            scanf("%d",&val[i]);
        ini(n);
        scanf("%d",&m);
        for(int i=0;i<m;i++)
            scanf("%d%d%d",&E[i].u,&E[i].v,&E[i].w);
        printf("%d\n",Kruskal());
        dfs(1,0);
        scanf("%d",&q);
        int x,y;
        while(q--){
            scanf("%d%d",&x,&y);
            int lca=LCA(x,y);
            int ans=getmax(y,lca)-getmin(x,lca);
            ans=max(ans,getleft(x,lca));
            ans=max(ans,getright(y,lca));
            printf("%d\n",ans);
        }
    }
    return 0;
}
View Code

 

posted @ 2015-10-30 11:47  轶辰  阅读(377)  评论(0编辑  收藏  举报