[CTSC2018]暴力写挂
给定两棵带边权的树\(T,T'\),大小为\(n\),要求\(max{dep[x]+dep[y]-dep[LCA(x,y)]-dep'[LCA'(x,y)]}\)。\(n\leq 366666\)。
首先式子可以化成\((dep[x]+dep[y]+dist(x,y))/2-dep'[LCA'(x,y)]\),出现\(dist\)之后就可以考虑统计路径信息了。这一般就三种做法。
考虑长剖和点分治,发现在合并两棵子树的信息时似乎不好确定在另一棵树上的\(LCA\),所以考虑一下边分治。
假设现在用边分治已经把树分成两半了,一半染黑一半染白,考虑枚举\(LCA'\),分别维护\(T'\)中每棵子树黑点和白点最大的\(dep[u]+dist(u,分治边一端)\)即可。合并子树的时候就用黑点和白点的最大值计算以下答案。枚举\(LCA'\)时只需要枚举当前分治的子树的点即可,所以对这棵子树在\(T'\)中的对应点建出虚树。时间复杂度\(\mathcal{O}(n\log^2 n)\)。
考虑一下优化,瓶颈似乎是建虚树的\(n\log n\)。如果一开始就把所有点按照\(T'\)的\(dfn\)排序,在分治过程中就很容易按照\(dfn\)序从小到大把点分配给下一层分治的子树。求\(LCA\)的时候用\(RMQ\)算法\(\mathcal{O}(1)\)求即可,总复杂度就是\(\mathcal{O}(n\log n)\)。
#include<bits/stdc++.h>
#define rg register
#define il inline
#define cn const
#define gc getchar()
#define fp(i, a, b) for(int i = (a), ed = (b); i <= ed; ++i)
#define fb(i, a, b) for(int i = (a), ed = (b); i >= ed; --i)
#define go(u) for(int i = head[u]; ~i; i = e[i].nxt)
#define mp make_pair
#define pb push_back
#define fi first
#define se second
using namespace std;
typedef cn int cint;
typedef long long LL;
typedef pair<int, LL> pr;
typedef pair<int, int> pr2;
il void rd(int &x){
x = 0;
rg int f(1); rg char c(gc);
while(c < '0' || '9' < c){if(c == '-')f = -1; c = gc;}
while('0' <= c && c <= '9')x = (x<<1)+(x<<3)+(c^48), c = gc;
x *= f;
}
cint maxn = 366676, maxm = maxn<<1;
cn LL inf = 0x3f3f3f3f3f3f3f3f;
int n;
vector<pr> t1[maxn], t2[maxn];
LL ans = -inf;
int m;
struct edge{
int to, nxt;
LL dis;
}e[maxm<<1];
int head[maxm], k;
il void add(cint &u, cint &v, cn LL &w){e[k] = (edge){v, head[u], w}, head[u] = k++;}
il void add2(cint &u, cint &v, cn LL &w){add(u, v, w), add(v, u, w);}
int siz[maxm], vrt[maxn], tmp[2][maxn], dfn[maxm], lst[maxm], id[maxm];
LL Dep[maxn], dis[maxm];
bool mark[maxm<<1];
int dep[maxn], dfn2[maxn], tot, fir[maxn], fnl[maxn], elr[maxn<<1], lg[maxn<<1], len;
pr2 mn[2][maxn<<1][20];
int stk[maxn], tp, rt;
int typ[maxn];
LL dis2[maxn], f[maxn][2], val[maxn];
vector<int> g[maxn];
void rebuild(int u, int pre){
int lst = 0;
for(auto &x : t1[u])if(x.fi^pre){
if(!lst)add2(u, x.fi, x.se), lst = u;
else ++m, add2(lst, m, 0), add2(m, x.fi, x.se), lst = m;
rebuild(x.fi, u);
}
}
void getdep(int u, int pre){go(u)if(e[i].to^pre)Dep[e[i].to] = Dep[u]+e[i].dis, getdep(e[i].to, u);}
void dfs(int u, int pre){
elr[++len] = u, fir[u] = len, dep[u] = dep[pre]+1, dfn2[u] = ++tot;
for(auto &x : t2[u])if(x.fi^pre)dis2[x.fi] = dis2[u]+x.se, dfs(x.fi, u), elr[++len] = u;
fnl[u] = len;
}
il bool cmp(cint &x, cint &y){return dfn2[x] < dfn2[y];}
il int getlca(int u, int v){
if(dfn2[u] > dfn2[v])swap(u, v);
rg int len = fnl[v]-fir[u]+1;
return min(mn[0][fir[u]][lg[len]], mn[1][fnl[v]][lg[len]]).se;
}
il void link(cint &u, cint &v){
if(dep[u] < dep[rt] || !rt)rt = u;
g[u].pb(v), f[u][0] = f[u][1] = f[v][0] = f[v][1] = val[u] = val[v] = -inf;
}
il void ins(cint &u){
if(!tp)return stk[++tp] = u, void();
rg int lca = getlca(u, stk[tp]);
if(lca == stk[tp])return stk[++tp] = u, void();
while(tp > 1 && dfn2[stk[tp-1]] >= dfn2[lca])link(stk[tp-1], stk[tp]), --tp;
if(lca != stk[tp])link(lca, stk[tp]), stk[tp] = lca;
stk[++tp] = u;
}
il void build(int *nd, int n){
fp(i, 1, n)ins(nd[i]);
while(tp > 1)link(stk[tp-1], stk[tp]), --tp;
}
void getrt(int u, int pre, cint &n, int &edg, int &mn){
siz[u] = 1;
go(u)if(i != pre && !mark[i])getrt(e[i].to, i^1, n, edg, mn), siz[u] += siz[e[i].to];
if(max(siz[u], n-siz[u]) < mn)edg = pre, mn = max(siz[u], n-siz[u]);
}
void dfs2(int u, int pre, int &tot){
dfn[u] = ++tot, id[tot] = u, siz[u] = 1;
go(u)if(e[i].to != pre && !mark[i]){
dis[e[i].to] = dis[u]+e[i].dis;
dfs2(e[i].to, u, tot);
siz[u] += siz[e[i].to];
}
lst[u] = tot;
}
void dfs3(int u, cn LL &len){
f[u][typ[u]] = val[u];
for(auto &x : g[u]){
dfs3(x, len);
ans = max(ans, (f[u][0]+f[x][1]+len)/2-dis2[u]);
ans = max(ans, (f[u][1]+f[x][0]+len)/2-dis2[u]);
f[u][0] = max(f[u][0], f[x][0]);
f[u][1] = max(f[u][1], f[x][1]);
}
g[u].clear();
}
void slv(int nw, int m, int n, int *nd){
if(m == 1)return;
int edg, rt1, rt2, mn = 0x3f3f3f3f, tot = 0, d1 = 0, d2 = 0;
getrt(nw, -1, m, edg, mn), rt1 = e[edg].to, rt2 = e[edg^1].to, mark[edg] = mark[edg^1] = 1;
dis[rt1] = dis[rt2] = 0, dfs2(rt1, 0, tot), dfs2(rt2, 0, tot), build(nd, n);
fp(i, 1, n){
typ[nd[i]] = dfn[rt1] <= dfn[nd[i]] && dfn[nd[i]] <= lst[rt1];
val[nd[i]] = Dep[nd[i]]+dis[nd[i]];
}
dfs3(rt, e[edg].dis), rt = tp = 0;
fp(i, 1, n){
if(dfn[rt1] <= dfn[nd[i]] && dfn[nd[i]] <= lst[rt1])tmp[0][++d1] = nd[i];
else tmp[1][++d2] = nd[i];
}
fp(i, 1, d1)nd[i] = tmp[0][i];
fp(i, 1, d2)nd[i+d1] = tmp[1][i];
slv(rt1, siz[rt1], d1, nd), slv(rt2, siz[rt2], d2, nd+d1);
}
int main(){
rd(n), m = n;
fp(i, 2, n){
rg int u, v;
rg LL w;
rd(u), rd(v), scanf("%lld", &w);
t1[u].pb(mp(v, w)), t1[v].pb(mp(u, w));
}
fp(i, 2, n){
rg int u, v;
rg LL w;
rd(u), rd(v), scanf("%lld", &w);
t2[u].pb(mp(v, w)), t2[v].pb(mp(u, w));
}
memset(head, -1, sizeof head), rebuild(1, 0), getdep(1, 0), dfs(1, 0);
fp(i, 2, len)lg[i] = lg[i>>1]+1;
fp(i, 1, len)mn[0][i][0] = mn[1][i][0] = mp(dep[elr[i]], elr[i]);
fp(j, 1, 19){
rg int s = 1<<j;
fp(i, 1, len){
if(i+s > len+1)break;
mn[0][i][j] = min(mn[0][i][j-1], mn[0][i+s/2][j-1]);
}
fb(i, len, 1){
if(i-s < 0)break;
mn[1][i][j] = min(mn[1][i][j-1], mn[1][i-s/2][j-1]);
}
}
fp(i, 1, n)vrt[i] = i;
sort(vrt+1, vrt+1+n, cmp), slv(1, m, n, vrt);
fp(i, 1, n)ans = max(ans, Dep[i]-dis2[i]);
printf("%lld\n", ans);
return 0;
}