AT2112 Non-redundant Drive
题目:https://www.luogu.org/problemnew/show/AT2112
对于这种找路径的就直接上点分治就好。
分治时,算出每一个点到分治重心的后能剩多少油,从分治重心走到每个点最少需要多少起始油量。
对这两个数组排序后合并即可。
注意,合并的时候要保证不属于同一棵子树,这个可以利用boruvka时用到的那个技巧来实现。
#include<bits/stdc++.h>
#define N 220000
#define db double
#define ll long long
#define ldb long double
using namespace std;
inline ll read()
{
char ch=0;
ll x=0,flag=1;
while(!isdigit(ch)){ch=getchar();if(ch=='-')flag=-1;}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return x*flag;
}
const ll inf=1e9;
struct edge{ll to,nxt,w;}e[N*2];
ll num,head[N];
inline void add(ll x,ll y,ll z){e[++num]=(edge){y,head[x],z};head[x]=num;}
bool vis[N];
ll rt,ans,size,cnt1,cnt2,w[N],sz[N],bal[N];
struct node{ll x,w,k;}f[N],g[N];
bool cmp(node a,node b){return a.x<b.x;}
bool operator<(node a,node b){return a.w<b.w;}
void get_rt(ll x,ll fa)
{
sz[x]=1;bal[x]=0;
for(ll i=head[x];i!=-1;i=e[i].nxt)
{
ll to=e[i].to;
if(vis[to]||to==fa)continue;
get_rt(to,x);sz[x]+=sz[to];
bal[x]=max(bal[x],sz[to]);
}
bal[x]=max(bal[x],size-sz[x]);
if(bal[rt]>bal[x])rt=x;
}
void get_sz(ll x,ll fa)
{
sz[x]=1;
for(ll i=head[x];i!=-1;i=e[i].nxt)
{
ll to=e[i].to;
if(vis[to]||to==fa)continue;
get_sz(to,x);sz[x]+=sz[to];
}
}
void cal1(ll x,ll k,ll t,ll fa,ll dep,ll flag)
{
if(fa==rt)flag=x;
if(w[x]>=k)f[++cnt1]=(node){t+w[x],dep,flag};
for(ll i=head[x];i!=-1;i=e[i].nxt)
{
ll to=e[i].to,v=e[i].w;
if(vis[to]||to==fa)continue;
cal1(to,max(v,k+v-w[x]),t-v+w[x],x,dep+1,flag);
}
}
void cal2(ll x,ll k,ll t,ll fa,ll dep,ll flag)
{
if(fa==rt)flag=x;
g[++cnt2]=(node){k,dep,flag};
for(ll i=head[x];i!=-1;i=e[i].nxt)
{
ll to=e[i].to,v=e[i].w;
if(vis[to]||to==fa)continue;
cal2(to,max(k,v-t),t-v+w[to],x,dep+1,flag);
}
}
void solve(ll x)
{
bal[rt=0]=inf;get_rt(x,x);get_sz(rt,rt);vis[rt]=true;
cnt1=0;cal1(rt,0,0,rt,1,rt);sort(f+1,f+cnt1+1,cmp);
cnt2=0;cal2(rt,0,0,rt,0,rt);sort(g+1,g+cnt2+1,cmp);
node a={0,-inf,0},b={0,-inf,0};
for(ll i=1,j=0;i<=cnt1;i++)
{
while(j!=cnt2&&g[j+1].x<=f[i].x)
{
j++;
if(g[j].k==a.k)a=max(a,g[j]);
else
{
b=max(b,g[j]);
if(a<b)swap(a,b);
}
}
if(f[i].k!=a.k)ans=max(ans,f[i].w+a.w);
else ans=max(ans,f[i].w+b.w);
}
for(ll i=head[rt];i!=-1;i=e[i].nxt)
{
ll to=e[i].to;
if(vis[to])continue;
size=sz[to];solve(to);
}
}
int main()
{
ll n=read();
for(ll i=1;i<=n;i++)w[i]=read();
num=-1;memset(head,-1,sizeof(head));
for(ll i=1;i<n;i++)
{
ll x=read(),y=read(),z=read();
add(x,y,z);add(y,x,z);
}
ans=1;size=n;solve(1);printf("%lld\n",ans);
return 0;
}