[边分治+线段树合并]「CTSC2018」暴力写挂
题目梗概
给出两棵1为根的树,求\(d[x]+d[y]-d[lca(x,y)]-d'[lca(x,y)]\)的最大值
解题思路
套路化简之后\((d[x]+d[y]+dis(x,y)-2*d'[lca(x,y)])/2\)
第二棵树上的lca化不掉,所以考虑在第二棵上枚举lca
先说说这题的解法,边分树的合并.
边分和点分有什么区别,边分在合并类似\(d[x]+d[y]+dis(x,y)\)这种贡献是很方便,只要记录一条边两端的点集中\(d[x]+dis(x,u)\)最大值即可,而点分合并这种贡献时复杂度与度数有关.
所以我们边分治第一棵树,建出边分树之后,遍历第二棵树,每次加入一个点,在边分树上维护答案
考虑左右儿子的答案如何合并,因为边分树是二叉树,像线段树一样合并即可,复杂度\(O(n\) \(log n)\).
边分治:将原图转成二叉树(保证复杂度),每次找左右两端节点数最大值最小的边分开即可.
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
#define LL long long
using namespace std;
const int maxn=370005;
const LL INF=(LL)1e18;
inline int _read(){
int num=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9') num=num*10+ch-48,ch=getchar();
return f*num;
}
struct jz{
int x,y,w;
jz(int x=0,int y=0,int w=0):x(x),y(y),w(w){}
};
int lnk[maxn<<2],son[maxn<<3],nxt[maxn<<3],w[maxn<<3],tot=1;
int n,cnt,zo,rx,ry,mi,rh,s[maxn<<2],rs[maxn<<3],ls[maxn<<3],d[maxn<<2],val[maxn<<3],fa[maxn<<3];
int rt[maxn],id,lc[maxn*46],rc[maxn*46],h[maxn*46];
LL rw[maxn*46],lw[maxn*46],dis[23][maxn<<2],ans=-INF;
vector<jz> Q;
bool vis[maxn<<3];
void add(int x,int y,int z){nxt[++tot]=lnk[x];lnk[x]=tot;son[tot]=y;w[tot]=z;}
void DFS(int x,int fa){
int pre=x;
for (int j=lnk[x];j;j=nxt[j]) if (son[j]!=fa){
Q.push_back(jz(pre,son[j],w[j]));
Q.push_back(jz(pre,++cnt,0));pre=cnt;
DFS(son[j],x);
}
}
void get_ro(int x,int fa,int dep,int lst){
s[x]=1;for (int j=lnk[x];j;j=nxt[j]) if (!vis[j]&&son[j]!=fa){
dis[dep][son[j]]=dis[dep][x]+w[j];
get_ro(son[j],x,dep,j),s[x]+=s[son[j]];
}
int w=max(zo-s[x],s[x]);
if (w<=mi) mi=w,rx=x,ry=fa,rh=lst;
}
int work(int x,int dep,int sz){
if (sz<=1) return d[x]=dep,x;
mi=zo=sz;int now=++cnt;
get_ro(x,0,dep,0);vis[rh]=vis[rh^1]=1;val[now]=w[rh];
int X=rx,Y=ry;
rs[now]=work(Y,dep+1,sz-s[X]);
ls[now]=work(X,dep+1,s[X]);
fa[ls[now]]=fa[rs[now]]=now;
return now;
}
int add(int x){
int lst=0,now=x;
for (int i=d[x];i;i--){
h[++id]=fa[now];lw[id]=rw[id]=-INF;
if (now==ls[fa[now]]) lw[id]=dis[i][x]+dis[0][x],lc[id]=lst;
if (now==rs[fa[now]]) rw[id]=dis[i][x]+dis[0][x],rc[id]=lst;
lst=id;now=fa[now];
}
return id;
}
void merge(int &x,int y,LL dep){
if (!x||!y){x=x+y;return;}
ans=max(ans,lw[x]+rw[y]+val[h[x]]-dep);
ans=max(ans,lw[y]+rw[x]+val[h[x]]-dep);
lw[x]=max(lw[x],lw[y]);rw[x]=max(rw[x],rw[y]);
merge(lc[x],lc[y],dep);merge(rc[x],rc[y],dep);
}
void solve(int x,int fa,LL dep){
rt[x]=add(x);ans=max(ans,dis[0][x]*2-dep*2);
for (int j=lnk[x];j;j=nxt[j]) if (son[j]!=fa){
solve(son[j],x,dep+w[j]);
merge(rt[x],rt[son[j]],dep*2);
}
}
int main(){
freopen("exam.in","r",stdin);
freopen("exam.out","w",stdout);
n=_read();cnt=n;
for (int i=1;i<n;i++){
int x=_read(),y=_read(),z=_read();
add(x,y,z);add(y,x,z);
}
DFS(1,0);memset(lnk,0,sizeof(lnk));tot=1;
for (int i=0;i<Q.size();i++) add(Q[i].x,Q[i].y,Q[i].w),add(Q[i].y,Q[i].x,Q[i].w);
work(1,0,cnt);memset(lnk,0,sizeof(lnk));tot=1;
for (int i=1;i<n;i++){
int x=_read(),y=_read(),z=_read();
add(x,y,z);add(y,x,z);
}
solve(1,0,0);printf("%lld\n",ans/2);
return 0;
}