UOJ #822 街头庆典
题目描述
给定一棵 \(n\) 个节点的无根树,树上每条边都有相同的长度 \(D\)。
你可以割掉树上的若干条边,割掉第 \(i\) 条边需要花费 \(w_i\) 的代价。
把一些边割掉后,树变成了若干个连通块。你想使得每个连通块的直径长度之和加上割边付出的代价之和最小,输出这个最小值。
\(2\le n\le 2\times 10^5,1\le w_i,D\le 10^9\)。时空限制:3s 1024MB
我们考虑最优解中的一条边,发现只有在这条边的两个端点分别为各自连通块内直径的端点时,才有可能成为最优解;否则,连上这条边后,总的直径长度之和不会变大,我们凭空省下了一条边的花费。
考虑从根开始往下的长链,在最优解中,这条长链会被划分为若干部分。对于每一部分,都要求长链的两个端点都是直径的端点(除了第一部分和最后一部分)。这要求,要么其直径就恰好是这条长链,要么这段长链上有偶数条边,且直径从最中间的那个点向外延伸。第一条和最后一条长链除外:他们只要求靠下或靠上的那个点的确是直径的端点。
对于第一种情况(直径恰好就是这段长链),设 \(f_u\) 表示 \(u\) 子树内的答案,设 \(s_{u,i}\) 表示 \(u\) 子树内,割掉所有和 \(u\) 距离为 \(i\) 的边之后,下面那些子树的 \(f\) 之和,与这些边的边权之和。这里 \(s\) 中没有计算 \(u\) 所在连通块的贡献。那么把长链拿出来后,设 \(s'_{k,p}\) 表示长链上的第 \(k\) 个点,割掉其轻子树内和 \(k\) 距离为 \(p\) 的边,的贡献之和。那么此时如果截出一段长链,端点分别为 \(i,j\),则转移的贡献可以写成 \(val(i,j)=\sum s'_{k,\min(k-i,j-k)}\) 的形式,含义是中间分叉出去的边不能超过这个直径。
注意这里 \(s'\) 的状态数总量是 \(O(n)\) 的,这个就是普通长剖的复杂度分析。
从大到小扫描 \(i\),考虑 \(i\to j\) 之间一个 \(s'_k\) 的贡献在 \(i\) 变为 \(i-1\) 时有什么变化。那么就是第二维 \(\min(k-i,j-k)\) 变成了 \(\min(k-i+1,j-k)\)。当 \(k-i\ge j-k\) 时,这个没有变化;当 \(k-i<j-k\) 时,只有 \(k-i<len_k\) 的才会发生变化,这里 \(len_k\) 表示长链上第 \(k\) 个点的轻子树长链 length 的最大值。
那么每次我们暴力枚举所有 \(k-i<len_k\) 的 \(k\),用它来更新所有 \(val(i,j)+f_{j+1}\) 的值,是一个后缀加的形式。使用线段树维护,可以做到 \(O(n\log n)\)。
对于第二种情况(直径的中心在外面),我们设 \(g_{u,i}\),表示 \(u\) 子树内,钦定 \(u\) 所在的连通块的最浅点为 \(u\) 的 \(i\) 级祖先的最小代价。那么转移 \(g\) 的时候,只需要讨论直径的中心在哪里:\(u\),或者某条 \(u\to v\) 的边上,或者某个儿子 \(v\) 的子树内。对于第一种,有 \(g_{u,i}\leftarrow s_{u,i}\)。对于第二种,如果直径在 \(u\to v\) 这条边的中点上,则 \(g_{u,i}\leftarrow s_{v,i}+\sum_{p\in \text{son}(u),p\neq v}s_{p,i-1}\)。对于第三种,有 \(g_{u,i}\leftarrow g_{v,i+1}+\sum_{p\in \text{son}(u)}s_{p,i-1}\)。
然后这里还有 \(D\times\) 直径的贡献,这里我们在中心处计算即可。也就是把前两种转移的权值分别加上 \(D\times 2i\) 和 \(D\times (2i+1)\)。
那么同理设 \(g'_{k,p}\) 表示长链上第 \(k\) 个点,只考虑轻儿子时的 DP 值,那么有 \(p\le len_k\)。
转移的时候,相当于 \(f_i\leftarrow g_{k,p}+f_{2k-i+1}+val'_{i,2k-i}\)。这里 \(val'\) 指的是,\([i,2k-i]\) 这段区间内的点 \(x\),除了 \(k\) 之外,也不能延伸出去超过 \(\min(x-i,2k-i-x)\) 的长度。这个 \(val'\) 的维护和上面是类似的。
对于第一段和最后一段的特殊情况,我们在这条长链的开头和结尾都添加等同于长链长度这么多个点,在转移最后一段的时候允许超出原本的长链末端,然后对于 \(f_{\text{root}}\) 我们把他对长链前面新增的点的 \(f\) 取 min 即可。
最后我们考虑 \(g\) 怎么算,不难发现只需要算出根节点的 \(g\),发现根节点的 \(g\) 几乎就是我们新增的那些点的 \(f\),只不过由于限定了中心在根节点往下的位置,所以转移的区间有一些变化。
综上,本题在 \(O(n\log n)\) 时间内解决。
注意这里做的实际上是后缀加,查询全局 min,因此,我们使用并查集维护,可以做到 \(O(n\alpha(n))\) 或者 \(O(n)\)。
#include<bits/stdc++.h>
#define ll long long
#define mk make_pair
#define fi first
#define se second
using namespace std;
inline int read(){
int x=0,f=1;char c=getchar();
for(;(c<'0'||c>'9');c=getchar()){if(c=='-')f=-1;}
for(;(c>='0'&&c<='9');c=getchar())x=x*10+(c&15);
return x*f;
}
template<typename T>void cmax(T &x,T v){x=max(x,v);}
template<typename T>void cmin(T &x,T v){x=min(x,v);}
const int N=2e5+5;
int hson[N],len[N],D,top[N],fa[N],wf[N],n;
vector<pair<int,int> >G[N];
void dfs1(int u){
for(auto [v,w]:G[u])if(v!=fa[u]){
fa[v]=u,wf[v]=w,dfs1(v);
if(len[v]>len[hson[u]]||hson[u]==0)hson[u]=v;
}
if(hson[u])len[u]=len[hson[u]]+1;
}
void dfs2(int u,int tp){
top[u]=tp;if(hson[u])dfs2(hson[u],tp);
for(auto [v,w]:G[u])if(v!=fa[u]&&v!=hson[u])dfs2(v,v);
}
void solve(vector<int>nodes);
vector<ll>dp_g[N],dp_s[N];
ll dp_f[N];
const ll INF=1e18;
void DP(int u){
vector<int>nodes;
int tu=u;
while(u)nodes.emplace_back(u),u=hson[u];
for(int p:nodes)for(auto [v,w]:G[p])if(v!=hson[p]&&v!=fa[p])DP(v);
solve(nodes);
}
struct sgt{
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
ll lz[N<<4],d[N<<4],M,k;
void pushup(int p){d[p]=min(d[ls(p)],d[rs(p)]);}
void build(int n){
n++;
M=1,k=0;while(M<n)M<<=1,k++;
for(int i=1;i<=2*M-1;i++)d[i]=INF;
}
void app(ll f,int p){d[p]+=f,lz[p]+=f;}
void pushdown(int p){app(lz[p],ls(p)),app(lz[p],rs(p)),lz[p]=0;}
void add(int l,int r,ll f){
l++,r++;
if(l>r)return ;
l+=M-1,r+=M;
for(int i=k;i>=1;i--)if(((l>>i)<<i)!=l)pushdown(l>>i);
for(int i=k;i>=1;i--)if(((r>>i)<<i)!=r)pushdown(r>>i);
int nl=l,nr=r;
while(l<r){
if(l&1)app(f,l++);
if(r&1)app(f,--r);
l>>=1,r>>=1;
}
l=nl,r=nr;
for(int i=1;i<=k;i++)if(((l>>i)<<i)!=l)pushup(l>>i);
for(int i=1;i<=k;i++)if(((r>>i)<<i)!=r)pushup(r>>i);
}
void setc(int p,ll v){
p++;p+=M-1;
for(int i=k;i>=1;i--)pushdown(p>>i);
d[p]=v;
for(int i=1;i<=k;i++)pushup(p>>i);
}
ll qval(int p){
p++;p+=M-1;
for(int i=k;i>=1;i--)pushdown(p>>i);
return d[p];
}
ll qmin(int l,int r){
l++,r++;
l+=M-1,r+=M;
for(int i=k;i>=1;i--)if(((l>>i)<<i)!=l)pushdown(l>>i);
for(int i=k;i>=1;i--)if(((r>>i)<<i)!=r)pushdown(r>>i);
ll mn=1e18;
while(l<r){
if(l&1)cmin(mn,d[l++]);
if(r&1)cmin(mn,d[--r]);
l>>=1,r>>=1;
}
return mn;
}
}T;
void solve(vector<int>nodes){
if(nodes.size()==1){
int u=nodes[0];
dp_f[u]=0,dp_g[u].resize(1,0),dp_s[u].resize(1,0);
return ;
}
int k=nodes.size()*2,rt=nodes[0];
dp_g[rt].resize(len[rt]+1,INF);
vector<ll>f(k,INF);
T.build(k+k);
vector<int>mxl(k);
vector<vector<ll> >s(k),g(k);
for(int i=0;i<k;i++){
if(i<k/2){s[i].resize(1,0),g[i].resize(1,0),mxl[i]=0;continue;}
int u=nodes[i-k/2],d=0;
for(auto [v,w]:G[u])if(v!=fa[u]&&v!=hson[u])cmax(d,len[v]+1);
mxl[i]=d,s[i].resize(d+1,0);
for(auto [v,w]:G[u])if(v!=fa[u]&&v!=hson[u]){
for(int j=1;j<=len[v]+1;j++)s[i][j]+=dp_s[v][j-1];
s[i][0]+=w+dp_f[v];
}
g[i]=s[i];
for(int j=0;j<=d;j++)g[i][j]+=2ll*j*D;
for(auto [v,w]:G[u])if(v!=fa[u]&&v!=hson[u]){
for(int j=0;j<=len[v];j++){
cmin(g[i][j],s[i][j]+dp_s[v][j]-(j>0?dp_s[v][j-1]:w+dp_f[v])+1ll*(j+j+1)*D);
if(j<len[v])cmin(g[i][j],dp_g[v][j+1]+s[i][j]-(j>0?dp_s[v][j-1]:w+dp_f[v]));
}
}
}
set<pair<int,int> >ids;
for(int i=k;i<=k+k;i++)T.setc(i,1ll*i*D+s[k-1][0]);
for(int i=k-1;i>=0;i--){
ids.insert(mk(i-mxl[i],i));
for(auto [w,j]:ids){
if(j-mxl[j]<=i){
if(j+j-i+1>k/2)cmin(f[i],g[j][j-i]+T.qval(j+j-i+1)-s[j][j-i]-1ll*(j+j-i+1)*D);
}
else break;
}
cmin(f[i],-1ll*(i+1)*D+T.qmin(i+1,k));
if(1<=i&&i<=k/2){
int r=k/2-i;
cmin(dp_g[rt][r],-1ll*(i+1)*D+T.qmin(k/2+r,k));
}
if(i>=1){
if(i>k/2)T.setc(i,f[i]+wf[nodes[i-k/2]]+1ll*i*D);
for(auto [w,j]:ids){
if(j-mxl[j]<i){
int l=j+j-i+2;
T.add(l,k+k,s[j][j-i+1]-s[j][j-i]);
}
else break;
}
T.add(i,k+k,s[i-1][0]);
}
}
for(int i=0;i<k/2;i++)cmin(f[k/2],f[i]);
dp_f[rt]=f[k/2],dp_s[rt].resize(len[rt]+1);
vector<ll>sum(k);
for(int i=k/2;i<k;i++){
for(int j=0;j<=mxl[i];j++)sum[j+i-k/2]+=s[i][j];
}
assert(len[rt]==k/2-1);
for(int i=0;i<len[rt];i++)dp_s[rt][i]=f[k/2+i+1]+wf[nodes[i+1]]+sum[i];
}
signed main(void){
n=read(),D=read();
for(int i=2;i<=n;i++){
int u=read(),v=read(),w=read();
G[u].emplace_back(mk(v,w)),G[v].emplace_back(mk(u,w));
}
dfs1(1),dfs2(1,1),DP(1);
cout<<dp_f[1]<<endl;
return 0;
}