【WC2010】重建计划(分数规划+长链剖分)

长链剖分

因为有很多巨佬只是讲了一下大致的做法,并没有详细地解释如何维护,所以就有了这篇题解。巨佬们都不屑于详细写,我太弱了/kk

首先先对原树进行长链剖分。

先讲一些定义:

  • 一条路径的权值和指的是这条路径上的所有边权之和

  • 一条路径的长度指的是这条路径包含多少条边

  • \(dep_i\) 表示 \(i\) 的深度。

  • \(maxdep_i\) 表示在 \(i\) 子树内的最深的节点的深度。

  • \(dis_i\) 表示从根节点到节点 \(i\) 的路径权值和是什么。

  • \(id_i\) 表示点 \(i\) 在线段树中所代表的位置

  • \(val_i\) 表示的是线段树的点的权值

然后容易想到先套一个二分套在最外层来二分 \(AugValue\) (AV)

不放假设现在的二分值为 \(mid\),然后根据分数规划,我们可以把所有边的边权减去 \(mid\),那么只要找到一条长度符合要求的路径,并且它的边权和大于 \(0\) 即可,同时也更新一下 \(dis_i\)

那我们就不妨找到树中长度符合要求的最大路径,然后再判断是否大于 \(0\)

让我们先用朴素的树形 dp 的做法想一下,设 \(dp_{i,j}\) 为以 \(i\) 为根的子树,往下走 \(j\) 步的最大权值和。

那么我们每次枚举一个 \(u\) 的儿子 \(a\)\(b\),有:

\[dp_{i,j}=\max_{k=0}^{j}(dp_{a,k}+dp_{b,j-k}) \]

然而这个时间复杂度我们是承受不起的。

所以考虑用长链剖分进行优化。

每一次 \(\operatorname{check}\),我们都先按着重儿子的方向进行递归,然后再递归其他的儿子。

先说一下每一次递归到 \(u\),并递归完其重儿子后,我们先循环枚举 \(u\) 的每一个非重儿子 \(son\),每一次循环中,我们先递归这个非重儿子 \(son\)

可能看代码好解释:(伪代码)

void solve(int u)
{
   solve(重儿子);
   for(son=非重儿子)
   {
   	solve(son);
    	......
   }
}

如图,我们现在已经遍历完了重儿子(红色节点)、\(son\) 以及 \(son\) 左边的儿子,现在要更新经过 \(u\)\(son\) 的且在 \(u\) 的子树内的路径的最大权值和,就好像这样的路径:(蓝色路径)

这时,我们循环枚举这条路径中 \(u\rightarrow b\)长度(不是权值和) \(j\)。那么 \(a\rightarrow u\) (不含 \(u\))长度的取值范围是 \(A=[\max(1,L-j),\min(R-j,maxdep_u-dep_u)]\),那么 \(dep_a\) 的取值范围是 \(B=dep_u+A\)

那么我们就要求出 \(u\) 的子树中,所有深度在 \(B\) 的点中,\(dis_i\) 最大的那个点 \(i\) 当做 \(a\),那么这条路径就是 \(u\rightarrow b\) 的长度为 \(j\) 的路径中权值和最大的那一条。

考虑如何维护这些点中 \(dis\) 最大的那一个。

由于他们需要满足的条件只是在一个深度范围内,所以想到了线段树。

关键是如何用较低的时间复杂度查询这么多子树中所有符合条件的点的 \(dis\) 的最大值。

所以在递归 \(son\) 左边的儿子以及 \(u\) 的重儿子时,要预先维护好这个东西。

那么我们现在进入到递归到 \(u\) 的重儿子或 \(son\) 左边的某个儿子 \(a\) 的情景:

如图,蓝色为长链,当我们递归完 \(a\) 准备回溯的时候,我们需要维护好这棵树内每一层的节点的 \(dis\) 的最大值。

对于某一层(比如 \(dep=k\) 这一层)的所有节点的 \(dis\),我们可以把它们全部合并入 \(a\) 所在的重链的满足 \(dep=k\) 的这个点 \(v\) 里面。

也就是说,不妨设在 \(a\) 子树内的 \(dep=k\) 的这一层的所有节点的集合为 \(\{a_1,a_2,\dots,a_m\}\),那么我们只用把点 \(v\) 在线段树中所代表的位置 \(id_v\) 的权值 \(val_{id_v}\) 设为 \(\max_{i=1}^m a_i\)。所以我们要查询这一层只用查询这一个点的权值就好了。

至于如何合并,我们只需在遍历 \(a\) 的儿子的时候更新就好了。

回到一开始的问题,我们现在就可以维护一个深度区间内最大的 \(dis_i\) 了。

同样,对于 \(u\) 来说,也要做上述的处理。

事到如此,整道题就已经做完了。

接下来有一些需要注意的事情:

  • INF 要开够,至少 \(N\times \dfrac{\max v_i}{2}=50000000000\)

  • 有很多东西需要开 double。

代码:

#include<bits/stdc++.h>
 
#define N 100010
#define INF 50000000010
 
using namespace std;
 
int n,L,R;
int cnt,head[N],nxt[N<<1],to[N<<1],w[N<<1];
int tot,fa[N],num[N],dep[N],maxd[N],son[N],sonw[N],id[N];
double ans,last[N],dis[N],ww[N<<1],sonww[N],maxn[N<<2];
 
void dfs1(int u)
{
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(to[i]==fa[u]) continue;
        maxd[v]=dep[v]=dep[u]+1;
        fa[v]=u;
        dfs1(v);
        if(maxd[v]>maxd[son[u]])
        {
            maxd[u]=maxd[v];
            son[u]=v;
            sonw[u]=w[i];
        }
    }
}
 
void dfs2(int u)
{
    id[u]=++tot;
    if(son[u])
        dfs2(son[u]);
    for(int i=head[u];i;i=nxt[i])
        if(to[i]!=fa[u]&&to[i]!=son[u])
            dfs2(to[i]);
}
 
void adde(int u,int v,int wi)
{
    to[++cnt]=v;
    w[cnt]=wi;
    nxt[cnt]=head[u];
    head[u]=cnt;
}
 
void getnum(int k,int l,int r)
{
    if(l==r)
    {
        num[l]=k;
        return;
    }
    int mid=(l+r)>>1;
    getnum(k<<1,l,mid);
    getnum(k<<1|1,mid+1,r);
}
 
void clear(int k,int l,int r)
{
    maxn[k]=-INF;
    if(l==r) return;
    int mid=(l+r)>>1;
    clear(k<<1,l,mid);
    clear(k<<1|1,mid+1,r);
}
 
void update(int k,int l,int r,int x,double y)
{
    maxn[k]=max(maxn[k],y);
    if(l==r) return;
    int mid=(l+r)>>1;
    if(x<=mid) update(k<<1,l,mid,x,y);
    else update(k<<1|1,mid+1,r,x,y);
}
 
double query(int k,int l,int r,int ql,int qr)
{
    if(ql<=l&&r<=qr) return maxn[k];
    int mid=(l+r)>>1;
    double ans=-INF;
    if(ql<=mid) ans=max(ans,query(k<<1,l,mid,ql,qr));
    if(qr>mid) ans=max(ans,query(k<<1|1,mid+1,r,ql,qr));
    return ans;
}
 
void solve(int u)
{
    update(1,1,n,id[u],dis[u]);
    if(son[u])
    {
        int v=son[u];
        dis[v]=dis[u]+sonww[u];
        solve(v);
    }
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa[u]||v==son[u]) continue;
        dis[v]=dis[u]+ww[i];
        solve(v);
        for(int j=1;j<=maxd[v]-dep[u];j++)
        {
            last[j]=maxn[num[id[v]+j-1]];
            if(j<=R)
            {
            	int ql=max(1,id[u]+L-j),qr=min(id[u]+R-j,id[u]+maxd[u]-dep[u]);
                double tmp=query(1,1,n,ql,qr);
                ans=max(ans,tmp+last[j]-dis[u]*2);
            }
        }
        for(int j=1;j<=maxd[v]-dep[u];j++)
            update(1,1,n,id[u]+j,last[j]);
    }
    ans=max(ans,query(1,1,n,id[u]+L,min(id[u]+R,id[u]+maxd[u]-dep[u]))-dis[u]);
}
 
bool check(double mid)
{
    for(int i=1;i<=cnt;i++) 
        ww[i]=1.0*w[i]-mid;
    for(int i=1;i<=n;i++)
        sonww[i]=1.0*sonw[i]-mid;
    ans=-INF;
    clear(1,1,n);
    solve(1);
    return ans>0;
}
 
int main()
{
    scanf("%d%d%d",&n,&L,&R);
    for(int i=1;i<n;i++)
    {
        int u,v,w;
        scanf("%d%d%d",&u,&v,&w);
        adde(u,v,w),adde(v,u,w);
    }
    dfs1(1);
    dfs2(1);
    getnum(1,1,n);
    double l=0,r=1000000;
    while(r-l>1e-5)
    {
        double mid=0.5*(l+r);
        if(check(mid)) l=mid;
        else r=mid;
    }
    printf("%.3lf\n",l);
    return 0;
}
posted @ 2022-10-30 10:08  ez_lcw  阅读(25)  评论(0编辑  收藏  举报