zoj 3649 lca与倍增dp
参考:http://www.xuebuyuan.com/609502.html
先说题意:
给出一幅图,求最大生成树,并在这棵树上进行查询操作:给出两个结点编号x和y,求从x到y的路径上,由每个结点的权值构成的序列中的极差大小——要求,被减数要在减数的后面,即形成序列{a1,a2…aj …ak…an},求ak-aj (k>=j)的最大值。
求路径,显然用到lca。
太孤陋寡闻,才知道原来倍增dp能用来求LCA。
用p[u][i]表示结点u的第1<< i 个祖先结点,则有递推如下:
for(int i=0;i<POW;i++) p[u][i]=p[p[u][i-1]][i-1]。
在对图dfs的时候即完成递推。
要想求两个结点的lca,首先使得两结点高度相同,若二者的父亲结点不同,则一直向上查找。dep数组表示结点的深度。
int LCA(int a,int b){
if(dep[a]>dep[b]) swap(a,b);
if(dep[a]<dep[b]){
//这一部分使得dep[a]==dep[b]
int tmp=dep[b]-dep[a];
for(int i=0;i<POW;i++) if(tmp&(1<<i))
//这里从POW-1到0来遍历也是一样的
b=p[b][i];
}
if(a!=b){
for(int i=POW-1;i>=0;i--) if(p[a][i]!=p[b][i])
a=p[a][i],b=p[b][i];
a=p[a][0],b=p[b][0];
}
return a;
}
如此即返回结点的lca。
用倍增遍历的思路:
因为一段路被二进制分成了一截一截,或者说路径长度被用二进制表示了出来。而两个结点的深度差即为“路径长度”,所以只要tmp&(1<<i),则表示这是“路径”的其中一个结点,以此类推,从而得到两个深度相同的结点。
有了这个基础之后,用相同的方式构建——
mx数组,mx[u][i]表示从u到其第1<<i个祖先结点路径上的最大值
mn数组,mn[u][i]表示从u到其第1<<i个祖先结点路径上的最小值
dp数组,dp[u][i],表示从u到其第1<<i个祖先结点路径上的最大差值
dp2数组,dp2[u][i],表示从其第1<<i个祖先结点到u路径上的最大差值
构建好后是查询部分。给出结点x和y,获得lca。
则路径被分成两段—— x->lca->y。则有三种可能性:
x到lca上的最大差值;lca到y上的最大差值;x到y上的最大差值(即lca到y的最大值减去x到lca的最小值)。比较一下即可。
这题真心涨姿势。代码如下:
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int N=3e4+10,M=N<<1,POW=16,inf=21e8; int mx[N][POW],mn[N][POW],p[N][POW],dp[N][POW],dp2[N][POW]; int head[N],nxt[M],to[M],cnt,val[N],vis[N],dep[N]; int n,m,q,fa[N]; struct Edge{ int u,v,w; bool operator < (const Edge e) const{ return w>e.w; } }E[M]; void ini(int n){ memset(head,-1,sizeof(head)); cnt=0; memset(vis,0,sizeof(vis)); fill(p[0],p[n+1],0); fill(mx[0],mx[n+1],-inf); fill(mn[0],mn[n+1],inf); fill(dp[0],dp[n+1],-inf); fill(dp2[0],dp2[n+1],-inf); dep[0]=0; } int find_(int x){ return x==fa[x]?x:fa[x]=find_(fa[x]); } void addedge(int u,int v){ to[cnt]=v; nxt[cnt]=head[u]; head[u]=cnt++; } int Kruskal(){ for(int i=0;i<=n;i++) fa[i]=i; sort(E,E+m); int sum=0; for(int i=0;i<m;i++){ int a=find_(E[i].u),b=find_(E[i].v); if(a!=b){ fa[a]=b; addedge(E[i].u,E[i].v); addedge(E[i].v,E[i].u); sum+=E[i].w; } } return sum; } void dfs(int u,int f){ dep[u]=dep[f]+1; vis[u]=1; for(int i=head[u];~i;i=nxt[i]) if(!vis[to[i]]){ int v=to[i]; p[v][0]=u; mx[v][0]=max(val[u],val[v]); mn[v][0]=min(val[u],val[v]); dp[v][0]=val[u]-val[v]; dp2[v][0]=val[v]-val[u]; for(int j=1;j<POW;j++){ p[v][j]=p[p[v][j-1]][j-1]; mx[v][j]=max(mx[v][j-1],mx[p[v][j-1]][j-1]); mn[v][j]=min(mn[v][j-1],mn[p[v][j-1]][j-1]); dp[v][j]=max(dp[v][j-1],dp[p[v][j-1]][j-1]); dp[v][j]=max(dp[v][j],mx[p[v][j-1]][j-1]-mn[v][j-1]); dp2[v][j]=max(dp2[v][j-1],dp2[p[v][j-1]][j-1]); dp2[v][j]=max(dp2[v][j],mx[v][j-1]-mn[p[v][j-1]][j-1]); } dfs(v,u); } } int LCA(int a,int b){ //第一次看到这样的LCA,holy high //有点不明觉厉 if(dep[a]>dep[b]) swap(a,b); if(dep[a]<dep[b]){ //这一部分使得dep[a]==dep[b] int tmp=dep[b]-dep[a]; for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)) b=p[b][i]; } if(a!=b){ //如果高度相等,而a!=b for(int i=POW-1;i>=0;i--) if(p[a][i]!=p[b][i]) a=p[a][i],b=p[b][i]; a=p[a][0],b=p[b][0]; } return a; } int getmax(int x,int lca){ int ans=0,tmp=dep[x]-dep[lca]; for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)){ ans=max(ans,mx[x][i]); x=p[x][i]; } return ans; } int getmin(int x,int lca){ int ans=inf,tmp=dep[x]-dep[lca]; for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)){ ans=min(ans,mn[x][i]); x=p[x][i]; } return ans; } int getleft(int x,int lca){ int ans=0,minn=inf; int tmp=dep[x]-dep[lca]; for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)){ ans=max(ans,dp[x][i]); ans=max(ans,mx[x][i]-minn); minn=min(minn,mn[x][i]); x=p[x][i]; } return ans; } int getright(int x,int lca){ int ans=0,maxx=0; int tmp=dep[x]-dep[lca]; for(int i=POW-1;i>=0;i--) if(tmp&(1<<i)){ ans=max(ans,dp2[x][i]); ans=max(ans,maxx-mn[x][i]); maxx=max(maxx,mx[x][i]); x=p[x][i]; } return ans; } int main(){ freopen("in.txt","r",stdin); while(~scanf("%d",&n)){ for(int i=1;i<=n;i++) scanf("%d",&val[i]); ini(n); scanf("%d",&m); for(int i=0;i<m;i++) scanf("%d%d%d",&E[i].u,&E[i].v,&E[i].w); printf("%d\n",Kruskal()); dfs(1,0); scanf("%d",&q); int x,y; while(q--){ scanf("%d%d",&x,&y); int lca=LCA(x,y); int ans=getmax(y,lca)-getmin(x,lca); ans=max(ans,getleft(x,lca)); ans=max(ans,getright(y,lca)); printf("%d\n",ans); } } return 0; }