51nod1812树的双直径(换根树DP)
传送门:http://www.51nod.com/Challenge/Problem.html#!#problemId=1812
题解:头一次写换根树DP。
求两条不相交的直径乘积最大,所以可以这样考虑:把一条边割掉,然后分别求两棵子树内的最长链乘起来就行了。由于负负得正,所以要再求一次最短链,就是把边权全部取负求一下就行了。然后就能通过dfs维护子树i内的答案dn[i]和不含以i为根的子树的答案up[i],dn[i]很好维护,重点是维护up[i],共5种可能:(1)从父亲的up继承过来(2)前后缀中的最大值f+出边+入边(3)父亲的g+兄弟节点中最大的f+出边(4)前驱/后继中的最大和次大(5)前驱/后继中的子树中的直径。然后转移状态就行了。
细节太多……还要__int128。为了方便,计算时答案用long long维护,乘起来再转long long……
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=4e5+7; int n,tot,hd[N],v[N<<1],w[N<<1],nxt[N<<1],p[N<<1],len[N<<1]; ll f[N],g[N],pre[N],suf[N],dn[N],up[N]; __int128 ans; void print(__int128 x){if(x>9)print(x/10);putchar('0'+x%10);} void add(int x,int y,int z){v[++tot]=y,nxt[tot]=hd[x],hd[x]=tot,w[tot]=z;} void dfs1(int u, int fa) { f[u]=dn[u]=0; for(int i=hd[u];i;i=nxt[i]) if(v[i]!=fa) { dfs1(v[i],u); dn[u]=max(dn[u],f[u]+f[v[i]]+w[i]); f[u]=max(f[u],f[v[i]]+w[i]); dn[u]=max(dn[u],dn[v[i]]); } } void dfs2(int u,int fa) { int cnt=0; for(int i=hd[u];i;i=nxt[i])if(v[i]!=fa)p[++cnt]=v[i],len[cnt]=w[i]; pre[0]=suf[cnt+1]=0; for(int i=1;i<=cnt;i++)pre[i]=max(pre[i-1],f[p[i]]+len[i]); for(int i=cnt;i;i--)suf[i]=max(suf[i+1],f[p[i]]+len[i]); /*一个点向上的直径: (1)从父亲的up继承过来 (2)前后缀中的最大值f+出边+入边 (3)父亲的g+兄弟节点中最大的f+出边 (4)前驱/后继中的最大和次大 (5)前驱/后继中的子树中的直径*/ for(int i=1;i<=cnt;i++) { g[p[i]]=max(g[p[i]],g[u]+len[i]); g[p[i]]=max(g[p[i]],max(pre[i-1],suf[i+1])+len[i]); up[p[i]]=max(up[p[i]],up[u]); up[p[i]]=max(up[p[i]],pre[i-1]+suf[i+1]); up[p[i]]=max(up[p[i]],g[u]+max(pre[i-1],suf[i+1])); } ll mx1=-1e18,mx2=-1e18,mx=-1e18,tmp; for(int i=1;i<=cnt;i++) { up[p[i]]=max(up[p[i]],max(mx1+mx2,mx)); tmp=f[p[i]]+len[i]; if(tmp>mx1)mx2=mx1,mx1=tmp;else if(tmp>mx2)mx2=tmp; mx=max(mx,dn[p[i]]); } mx1=mx2=mx=-1e18; for(int i=cnt;i;i--) { up[p[i]]=max(up[p[i]],max(mx1+mx2,mx)); tmp=f[p[i]]+len[i]; if(tmp>mx1)mx2=mx1,mx1=tmp;else if(tmp>mx2)mx2=tmp; mx=max(mx,dn[p[i]]); } for(int i=hd[u];i;i=nxt[i])if(v[i]!=fa)dfs2(v[i],u); } int main() { scanf("%d",&n); for(int i=1,x,y,z;i<n;i++)scanf("%d%d%d",&x,&y,&z),add(x,y,z),add(y,x,z); dfs1(1,0),dfs2(1,0); for(int i=2;i<=n;i++)ans=max(ans,(__int128)dn[i]*up[i]); for(int i=1;i<=tot;i++)w[i]=-w[i]; memset(up,0,sizeof up); memset(g,0,sizeof g); dfs1(1,0),dfs2(1,0); for(int i=2;i<=n;i++)ans=max(ans,(__int128)dn[i]*up[i]); print(ans); }