LOJ 2339 「WC2018」通道——边分治+虚树

题目:https://loj.ac/problem/2339

两棵树的话,可以用 CTSC2018 暴力写挂的方法,边分治+虚树。O(nlogn)。

考虑怎么在这个方法上再加一棵树。发现很难弄。

看了看题解,发现两棵树还有别的做法。

  就是要最大化 d1[ x ] + d2[ x ] + d1[ y ] + d2[ y ] - 2*d1[ lca1(x,y) ] - 2*d2[ lca2(x,y) ] ,考虑在第一棵树 T1 上 dfs 地枚举 lca1 ,那么考虑的答案就是 T1 上在当前点 cr 的不同子树里的 x 和 y 。

  考虑 cr 的之前子树 v1 和当前子树 v2 怎么合并。 v1 和 v2 都记录着自己子树里的答案的两个点 x 和 y 。

  似乎根据树的直径证明的类似方法可以得知 cr 的 x 就是 v1 和 v2 的 x 中的一个, cr 的 y 就是 v1 和 v2 的 y 中的一个。

  所以把两个 x 和两个 y 组合一下,看看谁的 d1[ x ] + d2[ x ] + d1[ y ] + d2[ y ] - 2*d2[ lca2(x,y) ] 最小,谁就是 cr 的 x 和 y 。

  做完 cr 之后,因为要换 lca1 了,所以先贡献一下答案,就是把 cr 记录的 x 和 y 按上面要最大化的那个式子贡献给答案。

  只要 RMQ 求 lca 就可以 O(n) 。

所以三棵树就是在这个两棵树的做法上给第三棵树套一个边分治。

就是在当前边分治的情况下,枚举 lca1 ,式子变成最大化 d1[ x ] + d2[ x ] + d3[ x ] + d1[ y ] + d2[ y ] + d2[ y ] - 2*d1[ lca1(x,y) ] - 2*d2[ lca2(x,y) ] + tw,其中 d1[ ] , d2[ ] 是在树 T1 和 T2 上的带权深度,d3[ ] 是在 T3 上到分治中心边的距离, tw 是分治中心边的权值。

枚举 lca1 之后,不仅 x 和 y 不能在 T1 上 lca1 的同一棵子树中,且 x 和 y 还得分别是 T3 的分治中心两边的点,所以 T2 上 DP 的时候,每个点要记两对 ( x , y ) ,表示 T3 两边的直径。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ll long long
#define pil pair<int,ll>
#define pb push_back
#define mkp make_pair
using namespace std;
ll rdn()
{
  ll ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
ll Mx(ll a,ll b){return a>b?a:b;}
ll Mn(ll a,ll b){return a<b?a:b;}
const int N=2e5+5,K=17; const ll INF=1e18;
int n,tn,hd[N],xnt=1,to[N<<1],nxt[N<<1];ll w[N<<1];
int siz[N],mn,Rt,lx[N]; vector<pil> vt[N];
ll d1[N],d2[N],d3[N],ans;
namespace T3{
  int hd[N],xnt,to[N<<1],nxt[N<<1];ll w[N<<1];
  int dep[N],bg[N],en[N],tim,st[N][K],lg[N],bin[K+5];
  void add(int x,int y,ll z)
  {
    to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;w[xnt]=z;
    to[++xnt]=x;nxt[xnt]=hd[y];hd[y]=xnt;w[xnt]=z;
  }
  void ini_dfs(int cr,int fa)
  {
    bg[cr]=++tim; st[tim][0]=cr;
    for(int i=hd[cr],v;i;i=nxt[i])
      if((v=to[i])!=fa)
    {
      dep[v]=dep[cr]+1; d3[v]=d3[cr]+w[i];
      ini_dfs(v,cr); st[++tim][0]=cr;/////
    }
    en[cr]=tim;
  }
  void init()
  {
    ll z;
    for(int i=1,u,v;i<n;i++)
      u=rdn(),v=rdn(),z=rdn(),add(u,v,z);
    ini_dfs(1,0); int tn=n<<1;
    for(int i=2;i<=tn;i++)lg[i]=lg[i>>1]+1;
    bin[0]=1;for(int i=1;i<=lg[tn];i++)bin[i]=bin[i-1]<<1;
    for(int t=1;t<=lg[tn];t++)
      for(int i=1;i+bin[t]-1<=tn;i++)
    {
      int u=st[i][t-1], v=st[i+bin[t-1]][t-1];
      st[i][t]=(dep[u]<dep[v]?u:v);
    }
  }
  int get_lca(int x,int y)
  {
    if(bg[x]>bg[y])swap(x,y); int d=lg[en[y]-bg[x]+1];
    int c1=st[bg[x]][d], c2=st[en[y]-bin[d]+1][d];
    int ret=dep[c1]<dep[c2]?c1:c2;
    return ret;
  }
}
namespace T2{
  int hd[N],xnt,to[N<<1],nxt[N<<1];ll w[N<<1];
  int tim,a[N],ta[N],lca[N],tlca[N],dep[N],sta[N],top;
  struct Node{
    int x,y;ll w;
    Node(int x=0,int y=0,ll w=0):x(x),y(y),w(w) {}
  }dp[N][2];
  void add(int x,int y,ll z)
  {
    to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;w[xnt]=z;
    to[++xnt]=x;nxt[xnt]=hd[y];hd[y]=xnt;w[xnt]=z;
  }
  void ini_dfs(int cr,int fa)
  {
    while(top&&dep[sta[top]]>=dep[cr])top--;
    a[++tim]=cr; lca[tim]=sta[top]; sta[++top]=cr;
    for(int i=hd[cr],v;i;i=nxt[i])
      if((v=to[i])!=fa)
    {
      dep[v]=dep[cr]+1; d2[v]=d2[cr]+w[i];
      ini_dfs(v,cr);
    }
  }
  void init()
  {
    ll z;
    for(int i=1,u,v;i<n;i++)
      u=rdn(),v=rdn(),z=rdn(),add(u,v,z);
    ini_dfs(1,0);
  }
  ll calc(int x,int y)
  {
    ll ret=d1[x]+d1[y]+d2[x]+d2[y]+d3[x]+d3[y];
    int tmp=T3::get_lca(x,y);
    ret-=(d3[tmp]<<1ll);
    return ret;
  }
  Node operator+ (const Node &a,const Node &b)
  {
    int x1=a.x,y1=a.y,x2=b.x,y2=b.y;
    Node ret=Node(0,0,-1); ll tmp;
    if(x1&&x2)
      { tmp=calc(x1,x2); if(tmp>ret.w)ret=Node(x1,x2,tmp);}
    if(x1&&y2)
      { tmp=calc(x1,y2); if(tmp>ret.w)ret=Node(x1,y2,tmp);}
    if(y1&&x2)
      { tmp=calc(y1,x2); if(tmp>ret.w)ret=Node(y1,x2,tmp);}
    if(y1&&y2)
      { tmp=calc(y1,y2); if(tmp>ret.w)ret=Node(y1,y2,tmp);}
    return ret;
  }
  Node mx(Node a,Node b){ return a.w>b.w?a:b;}
  void link(int cr,int v,ll tw)
  {
    ll tmp=d2[cr]<<1ll;
    ans=Mx(ans,(dp[cr][0]+dp[v][1]).w+tw-tmp);
    ans=Mx(ans,(dp[cr][1]+dp[v][0]).w+tw-tmp);
    dp[cr][0]=mx(dp[cr][0],mx(dp[v][0],dp[cr][0]+dp[v][0]));
    dp[cr][1]=mx(dp[cr][1],mx(dp[v][1],dp[cr][1]+dp[v][1]));
    dp[v][0]=dp[v][1]=Node(0,0,-1);
  }
  int solve(int l,int r,ll tw)
  {
    sta[top=1]=a[l];
    for(int i=l+1;i<=r;i++)
      {
    int lm=dep[lca[i]];
    while(top&&dep[sta[top]]>lm)
      {
        if(dep[sta[top-1]]>lm)link(sta[top-1],sta[top],tw);
        else link(lca[i],sta[top],tw);
        top--;
      }
    if(sta[top]!=lca[i])sta[++top]=lca[i];
    sta[++top]=a[i];
      }
    for(int i=top-1;i;i--)link(sta[i],sta[i+1],tw);
    dp[sta[1]][0]=dp[sta[1]][1]=Node(0,0,-1);
    int mid=l-1;
    for(int i=l,tl=0;i<=r;i++)
      {
    if(!tl||dep[lca[i]]<dep[tl])tl=lca[i];
    if(!lx[a[i]]) ta[++mid]=a[i], tlca[mid]=tl, tl=0;
      }
    int ret=mid;
    for(int i=l,tl=0;i<=r;i++)
      {
    if(!tl||dep[lca[i]]<dep[tl])tl=lca[i];
    if(lx[a[i]]) ta[++mid]=a[i], tlca[mid]=tl, tl=0;
      }
    for(int i=l;i<=r;i++)a[i]=ta[i];
    for(int i=l;i<=r;i++)lca[i]=tlca[i];
    return ret;
  }
}
void add(int x,int y,ll z)
{
  to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;w[xnt]=z;
  to[++xnt]=x;nxt[xnt]=hd[y];hd[y]=xnt;w[xnt]=z;
}
void del_ed(int x,int y)
{
  if(to[hd[x]]==y)hd[x]=nxt[hd[x]];
  else
    {
      for(int i=hd[x],pr;i;pr=i,i=nxt[i])
    if(to[i]==y){nxt[pr]=nxt[i];break;}
    }
  if(to[hd[y]]==x)hd[y]=nxt[hd[y]];
  else
    {
      for(int i=hd[y],pr;i;pr=i,i=nxt[i])
    if(to[i]==x){nxt[pr]=nxt[i];break;}
    }
}
void Rbuild(int cr,int fa)
{
  for(int i=0,lst=0,lm=vt[cr].size();i<lm;i++)
    {
      int v=vt[cr][i].first;ll z=vt[cr][i].second;
      if(v==fa)continue;
      if(!lst)add(cr,v,z), lst=cr;
      else{ tn++; add(lst,tn,0); add(tn,v,z); lst=tn;}
    }
  for(int i=0,v,lm=vt[cr].size();i<lm;i++)
    if((v=vt[cr][i].first)!=fa) Rbuild(v,cr);
}
void get_rt(int cr,int fa,int s)
{
  siz[cr]=1;
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)
      {
    get_rt(v,cr,s); siz[cr]+=siz[v];
    int mx=Mx(siz[v],s-siz[v]);
    if(mx<mn)mn=mx,Rt=i;
      }
}
void dfs(int cr,int fa,ll lj,bool fx)
{
  d1[cr]=lj; lx[cr]=fx;
  T2::dp[cr][fx]=T2::Node(cr,cr,0);
  T2::dp[cr][!fx]=T2::Node(0,0,-1);
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa) dfs(v,cr,lj+w[i],fx);
}
void solve(int cr,int s,int l,int r)
{
  int u=to[cr^1], v=to[cr]; del_ed(u,v);
  dfs(u,0,0,0); dfs(v,0,0,1); ll tw=w[cr];
  int mid=T2::solve(l,r,tw);
  int ts=siz[v];
  if(ts>1){mn=N;get_rt(v,0,ts);solve(Rt,ts,mid+1,r);}
  ts=s-ts;
  if(ts>1){mn=N;get_rt(u,0,ts);solve(Rt,ts,l,mid);}
}
int main()
{
  n=rdn();ll z;
  for(int i=1,u,v;i<n;i++)
    {
      u=rdn();v=rdn();z=rdn();
      vt[u].pb(mkp(v,z)); vt[v].pb(mkp(u,z));
    }
  T2::init(); T3::init(); tn=n; Rbuild(1,0);
  mn=N;get_rt(1,0,tn);solve(Rt,tn,1,n);
  printf("%lld\n",ans);
  return 0;
}

 

posted on 2019-03-07 08:14  Narh  阅读(341)  评论(0编辑  收藏  举报

导航