Loj #2553. 「CTSC2018」暴力写挂
Loj #2553. 「CTSC2018」暴力写挂
题目描述
temporaryDO 是一个很菜的 OIer 。在 4 月,他在省队选拔赛的考场上见到了《林克卡特树》一题,其中 \(k = 0\) 的部分分是求树 \(T\) 上的最长链。可怜的 temporaryDO 并不会做这道题,他在考场上抓猫耳挠猫腮都想不出一点思路。
这时,善良的板板出现在了空中,他的身上发出璀璨却柔和的光芒,荡漾在考场上。‘‘题目并不难。’’ 板板说。那充满磁性的声音,让 temporaryDO 全身充满了力量。
他决定:写一个枚举点对求 LCA 算距离的 \(k = 0\) 的 \(O(n^2\log\ n)\) 的部分分程序!于是, temporaryDO 选择以 \(1\) 为根,建立了求 LCA 的树链剖分结构,然后写了二重 for 循环枚举点对。
然而,菜菜的 temporaryDO 不小心开小了数组,于是数组越界到了一片神秘的内存区域。但恰好的是,那片内存区域存储的区域恰好是另一棵树 \(T′\) 。这样一来,程序并没有 RE ,但他求 \(x\) 和 \(y\) 的距离的时候,计算的是
最后程序会输出每一对点对 \(i, j (i \le j)\) 的如上定义的‘‘距离’’ 的最大值。
temporaryDO 的程序在评测时光荣地爆零了。但他并不服气,他决定花好几天把自己的程序跑出来。请你根据 \(T\) 和 \(T′\) 帮帮可怜的 temporaryDO 求出他程序的输出。
输入格式
第一行包含一个整数 \(n\) ,表示树上的节点个数;
第 \(2\) 到第 \(n\) 行,每行三个整数 \(x , y , v\) ,表示 \(T\) 中存在一条从 \(x\) 到 \(y\) 的边,其长度为 \(v\) ;
第 \(n + 1\) 到第 \(2n - 1 行\) ,每行三个整数 \(x , y , v\) ,表示 \(T′\) 中存在一条从 \(x\) 到 \(y\) 的边,其长度为 \(v\) 。
输出格式
输出一行一个整数,表示 temporaryDO 的程序的输出。
数据范围与提示
对于所有数据, \(n \le 366666 , |v| \le 2017011328\) 。
以前一直觉得边分治和点分治没什么区别,做了这道题才发现我太naive了。
首先题目中给的式子,不好看,所以我们把它变一下形:
于是我们可以枚举第二颗树的\(lca\),然后计算其子树之间的\(dep_x+dep_y+dis_{x,y}\)的最大值。后面部分就可以用边分治来维护。
可以类比点分治来理解边分治,就是在每个分治连通块中找到一条边使得边两端的连通块大小尽量平均。但是我们发现,一个菊花就可以把这个分治卡死。原因是某个点的度数太大了。于是我们考虑转成二叉树。具体就是每个点的上面连一个额外点,然后一个父亲节点连向其中一个儿子的额外点,几个儿子的额外点再连成一条线(代码一看就懂)。
假设分治中心边是\((x,y)\),我们像点分治那样统计分支块内每个点\(p\)到分治中心边的其中一个点的距离(\(dep_p+dis_{x,p}\))。这个距离有两种方向,分别对应\(x,y\)所在的连通块。我们发现,边分治的分治树是棵二叉树。所以对于每个点,我们可以开个二叉树,记录其在每一个分治连通块内的到中心点的距离。
然后就可以算答案了。当枚举了第二棵树上的\(lca\)的时候,合并每个子树的二叉树,同一种节点(到根路径相同)代表同一个分支连通块,我们可以在合并过程中更新答案。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 800005
using namespace std;
inline ll Get() {ll x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
int n;
struct graph {
int to[N<<2],nxt[N<<2],dis[N<<2];
int h[N<<1],cnt=1;
void add(int i,int j,int d) {
to[++cnt]=j;
nxt[cnt]=h[i];
dis[cnt]=d;
h[i]=cnt;
}
void Init() {
memset(h,0,sizeof(h));
cnt=1;
}
}s,g;
int vertex;
ll dep[N];
void build_edge(int v,int fr) {
int lst=v;
for(int i=s.h[v];i;i=s.nxt[i]) {
int to=s.to[i];
if(to==fr) continue ;
vertex++;
g.add(lst,vertex,0);
g.add(vertex,lst,0);
g.add(vertex,to,s.dis[i]);
g.add(to,vertex,s.dis[i]);
lst=vertex;
dep[to]=dep[v]+s.dis[i];
build_edge(to,v);
}
}
int size[N<<1];
int sum,E,mx;
bool vis[N<<1];
void Find_edge(int v,int fr) {
size[v]=1;
for(int i=g.h[v];i;i=g.nxt[i]) {
int to=g.to[i];
if(vis[i]||to==fr) continue ;
Find_edge(to,v);
size[v]+=size[to];
int now=max(size[to],sum-size[to]);
if(mx>now) {
mx=now;
E=i;
}
}
}
struct node {
ll dis;
int dir;
node() {}
node(ll _dis,int _dir) {dis=_dis,dir=_dir;}
};
vector<node>st[N<<1];
void statis(int v,int fr,ll dis,int dir) {
size[v]=1;
if(v<=n) {
st[v].push_back(node(dep[v]+dis,dir));
}
for(int i=g.h[v];i;i=g.nxt[i]) {
int to=g.to[i];
if(to==fr||vis[i]) continue ;
statis(to,v,dis+g.dis[i],dir);
size[v]+=size[to];
}
}
void solve(int v) {
sum=size[v];
mx=1e9;
Find_edge(v,0);
int x=g.to[E],y=g.to[E^1];
vis[E]=vis[E^1]=1;
statis(x,0,0,0),statis(y,0,g.dis[E],1);
if(size[x]>1) solve(x);
if(size[y]>1) solve(y);
}
int rt[N<<1];
int ls[N*20],rs[N*20];
ll lmx[N*20],rmx[N*20];
int tot;
void build_tree(int &v,vector<node>&a,int now) {
if(now==a.size()) return ;
v=++tot;
if(a[now].dir==0) {
lmx[v]=a[now].dis;
rmx[v]=-1ll<<60;
build_tree(ls[v],a,now+1);
} else {
rmx[v]=a[now].dis;
lmx[v]=-1ll<<60;
build_tree(rs[v],a,now+1);
}
}
ll ans=-1ll<<60;
ll Dis;
int Merge(int a,int b) {
if(!a||!b) return a+b;
ans=max(ans,max(lmx[a]+rmx[b],lmx[b]+rmx[a])-2*Dis);
lmx[a]=max(lmx[a],lmx[b]);
rmx[a]=max(rmx[a],rmx[b]);
ls[a]=Merge(ls[a],ls[b]);
rs[a]=Merge(rs[a],rs[b]);
return a;
}
void dfs2(int v,int fr,ll dis) {
ans=max(ans,2*dep[v]-2*dis);
for(int i=s.h[v];i;i=s.nxt[i]) {
int to=s.to[i];
if(to==fr) continue ;
dfs2(to,v,dis+s.dis[i]);
Dis=dis;
rt[v]=Merge(rt[v],rt[to]);
}
}
int main() {
n=Get();
for(int i=1;i<n;i++) {
int a=Get(),b=Get(),c=Get();
s.add(a,b,c),s.add(b,a,c);
}
vertex=n;
build_edge(1,0);
size[1]=vertex;
solve(1);
for(int i=1;i<=n;i++) {
build_tree(rt[i],st[i],0);
}
s.Init();
for(int i=1;i<n;i++) {
int a=Get(),b=Get(),c=Get();
s.add(a,b,c),s.add(b,a,c);
}
dfs2(1,0,0);
cout<<ans/2;
return 0;
}