UOJ #822 街头庆典

题目描述

给定一棵 \(n\) 个节点的无根树,树上每条边都有相同的长度 \(D\)

你可以割掉树上的若干条边,割掉第 \(i\) 条边需要花费 \(w_i\) 的代价。

把一些边割掉后,树变成了若干个连通块。你想使得每个连通块的直径长度之和加上割边付出的代价之和最小,输出这个最小值。

\(2\le n\le 2\times 10^5,1\le w_i,D\le 10^9\)。时空限制:3s 1024MB

我们考虑最优解中的一条边,发现只有在这条边的两个端点分别为各自连通块内直径的端点时,才有可能成为最优解;否则,连上这条边后,总的直径长度之和不会变大,我们凭空省下了一条边的花费。

考虑从根开始往下的长链,在最优解中,这条长链会被划分为若干部分。对于每一部分,都要求长链的两个端点都是直径的端点(除了第一部分和最后一部分)。这要求,要么其直径就恰好是这条长链,要么这段长链上有偶数条边,且直径从最中间的那个点向外延伸。第一条和最后一条长链除外:他们只要求靠下或靠上的那个点的确是直径的端点。

对于第一种情况(直径恰好就是这段长链),设 \(f_u\) 表示 \(u\) 子树内的答案,设 \(s_{u,i}\) 表示 \(u\) 子树内,割掉所有和 \(u\) 距离为 \(i\) 的边之后,下面那些子树的 \(f\) 之和,与这些边的边权之和。这里 \(s\) 中没有计算 \(u\) 所在连通块的贡献。那么把长链拿出来后,设 \(s'_{k,p}\) 表示长链上的第 \(k\) 个点,割掉其轻子树内和 \(k\) 距离为 \(p\) 的边,的贡献之和。那么此时如果截出一段长链,端点分别为 \(i,j\),则转移的贡献可以写成 \(val(i,j)=\sum s'_{k,\min(k-i,j-k)}\) 的形式,含义是中间分叉出去的边不能超过这个直径。

注意这里 \(s'\) 的状态数总量是 \(O(n)\) 的,这个就是普通长剖的复杂度分析。

从大到小扫描 \(i\),考虑 \(i\to j\) 之间一个 \(s'_k\) 的贡献在 \(i\) 变为 \(i-1\) 时有什么变化。那么就是第二维 \(\min(k-i,j-k)\) 变成了 \(\min(k-i+1,j-k)\)。当 \(k-i\ge j-k\) 时,这个没有变化;当 \(k-i<j-k\) 时,只有 \(k-i<len_k\) 的才会发生变化,这里 \(len_k\) 表示长链上第 \(k\) 个点的轻子树长链 length 的最大值。

那么每次我们暴力枚举所有 \(k-i<len_k\)\(k\),用它来更新所有 \(val(i,j)+f_{j+1}\) 的值,是一个后缀加的形式。使用线段树维护,可以做到 \(O(n\log n)\)

对于第二种情况(直径的中心在外面),我们设 \(g_{u,i}\),表示 \(u\) 子树内,钦定 \(u\) 所在的连通块的最浅点为 \(u\)\(i\) 级祖先的最小代价。那么转移 \(g\) 的时候,只需要讨论直径的中心在哪里:\(u\),或者某条 \(u\to v\) 的边上,或者某个儿子 \(v\) 的子树内。对于第一种,有 \(g_{u,i}\leftarrow s_{u,i}\)。对于第二种,如果直径在 \(u\to v\) 这条边的中点上,则 \(g_{u,i}\leftarrow s_{v,i}+\sum_{p\in \text{son}(u),p\neq v}s_{p,i-1}\)。对于第三种,有 \(g_{u,i}\leftarrow g_{v,i+1}+\sum_{p\in \text{son}(u)}s_{p,i-1}\)

然后这里还有 \(D\times\) 直径的贡献,这里我们在中心处计算即可。也就是把前两种转移的权值分别加上 \(D\times 2i\)\(D\times (2i+1)\)

那么同理设 \(g'_{k,p}\) 表示长链上第 \(k\) 个点,只考虑轻儿子时的 DP 值,那么有 \(p\le len_k\)

转移的时候,相当于 \(f_i\leftarrow g_{k,p}+f_{2k-i+1}+val'_{i,2k-i}\)。这里 \(val'\) 指的是,\([i,2k-i]\) 这段区间内的点 \(x\),除了 \(k\) 之外,也不能延伸出去超过 \(\min(x-i,2k-i-x)\) 的长度。这个 \(val'\) 的维护和上面是类似的。

对于第一段和最后一段的特殊情况,我们在这条长链的开头和结尾都添加等同于长链长度这么多个点,在转移最后一段的时候允许超出原本的长链末端,然后对于 \(f_{\text{root}}\) 我们把他对长链前面新增的点的 \(f\) 取 min 即可。

最后我们考虑 \(g\) 怎么算,不难发现只需要算出根节点的 \(g\),发现根节点的 \(g\) 几乎就是我们新增的那些点的 \(f\),只不过由于限定了中心在根节点往下的位置,所以转移的区间有一些变化。

综上,本题在 \(O(n\log n)\) 时间内解决。

注意这里做的实际上是后缀加,查询全局 min,因此,我们使用并查集维护,可以做到 \(O(n\alpha(n))\) 或者 \(O(n)\)

#include<bits/stdc++.h>

#define ll long long
#define mk make_pair
#define fi first
#define se second

using namespace std;

inline int read(){
	int x=0,f=1;char c=getchar();
	for(;(c<'0'||c>'9');c=getchar()){if(c=='-')f=-1;}
	for(;(c>='0'&&c<='9');c=getchar())x=x*10+(c&15);
	return x*f;
}

template<typename T>void cmax(T &x,T v){x=max(x,v);}
template<typename T>void cmin(T &x,T v){x=min(x,v);}

const int N=2e5+5;
int hson[N],len[N],D,top[N],fa[N],wf[N],n;
vector<pair<int,int> >G[N];
void dfs1(int u){
	for(auto [v,w]:G[u])if(v!=fa[u]){
		fa[v]=u,wf[v]=w,dfs1(v);
		if(len[v]>len[hson[u]]||hson[u]==0)hson[u]=v;
	}
	if(hson[u])len[u]=len[hson[u]]+1;
}
void dfs2(int u,int tp){
	top[u]=tp;if(hson[u])dfs2(hson[u],tp);
	for(auto [v,w]:G[u])if(v!=fa[u]&&v!=hson[u])dfs2(v,v);
}

void solve(vector<int>nodes);

vector<ll>dp_g[N],dp_s[N];
ll dp_f[N];
const ll INF=1e18;

void DP(int u){
	vector<int>nodes;
	int tu=u;
	while(u)nodes.emplace_back(u),u=hson[u];
	for(int p:nodes)for(auto [v,w]:G[p])if(v!=hson[p]&&v!=fa[p])DP(v);
	solve(nodes);
}

struct sgt{
	#define ls(p) (p<<1)
	#define rs(p) (p<<1|1)
	ll lz[N<<4],d[N<<4],M,k;
	void pushup(int p){d[p]=min(d[ls(p)],d[rs(p)]);}
	void build(int n){
		n++;
		M=1,k=0;while(M<n)M<<=1,k++;
		for(int i=1;i<=2*M-1;i++)d[i]=INF;
	}
	void app(ll f,int p){d[p]+=f,lz[p]+=f;}
	void pushdown(int p){app(lz[p],ls(p)),app(lz[p],rs(p)),lz[p]=0;}
	void add(int l,int r,ll f){
		l++,r++;
		if(l>r)return ;
		l+=M-1,r+=M;
		for(int i=k;i>=1;i--)if(((l>>i)<<i)!=l)pushdown(l>>i);
		for(int i=k;i>=1;i--)if(((r>>i)<<i)!=r)pushdown(r>>i);
		int nl=l,nr=r;
		while(l<r){
			if(l&1)app(f,l++);
			if(r&1)app(f,--r);
			l>>=1,r>>=1;
		}
		l=nl,r=nr;
		for(int i=1;i<=k;i++)if(((l>>i)<<i)!=l)pushup(l>>i);
		for(int i=1;i<=k;i++)if(((r>>i)<<i)!=r)pushup(r>>i);
	}
	void setc(int p,ll v){
		p++;p+=M-1;
		for(int i=k;i>=1;i--)pushdown(p>>i);
		d[p]=v;
		for(int i=1;i<=k;i++)pushup(p>>i);
	}
	ll qval(int p){
		p++;p+=M-1;
		for(int i=k;i>=1;i--)pushdown(p>>i);
		return d[p];
	}
	ll qmin(int l,int r){
		l++,r++;
		l+=M-1,r+=M;
		for(int i=k;i>=1;i--)if(((l>>i)<<i)!=l)pushdown(l>>i);
		for(int i=k;i>=1;i--)if(((r>>i)<<i)!=r)pushdown(r>>i);
		ll mn=1e18;
		while(l<r){
			if(l&1)cmin(mn,d[l++]);
			if(r&1)cmin(mn,d[--r]);
			l>>=1,r>>=1;
		}
		return mn;
	}
}T;

void solve(vector<int>nodes){
	if(nodes.size()==1){
		int u=nodes[0];
		dp_f[u]=0,dp_g[u].resize(1,0),dp_s[u].resize(1,0);
		return ;
	}
	
	int k=nodes.size()*2,rt=nodes[0];
	dp_g[rt].resize(len[rt]+1,INF);
	vector<ll>f(k,INF);
	T.build(k+k);
	vector<int>mxl(k);
	vector<vector<ll> >s(k),g(k);

	for(int i=0;i<k;i++){
		if(i<k/2){s[i].resize(1,0),g[i].resize(1,0),mxl[i]=0;continue;}
		int u=nodes[i-k/2],d=0;
		for(auto [v,w]:G[u])if(v!=fa[u]&&v!=hson[u])cmax(d,len[v]+1);

		mxl[i]=d,s[i].resize(d+1,0);
		for(auto [v,w]:G[u])if(v!=fa[u]&&v!=hson[u]){
			for(int j=1;j<=len[v]+1;j++)s[i][j]+=dp_s[v][j-1];
			s[i][0]+=w+dp_f[v];
		}

		g[i]=s[i];
		for(int j=0;j<=d;j++)g[i][j]+=2ll*j*D;
		for(auto [v,w]:G[u])if(v!=fa[u]&&v!=hson[u]){
			for(int j=0;j<=len[v];j++){
				cmin(g[i][j],s[i][j]+dp_s[v][j]-(j>0?dp_s[v][j-1]:w+dp_f[v])+1ll*(j+j+1)*D);
				if(j<len[v])cmin(g[i][j],dp_g[v][j+1]+s[i][j]-(j>0?dp_s[v][j-1]:w+dp_f[v]));
			}
		}
	}

	set<pair<int,int> >ids;

	for(int i=k;i<=k+k;i++)T.setc(i,1ll*i*D+s[k-1][0]);
	for(int i=k-1;i>=0;i--){
		ids.insert(mk(i-mxl[i],i));
		for(auto [w,j]:ids){
			if(j-mxl[j]<=i){
				if(j+j-i+1>k/2)cmin(f[i],g[j][j-i]+T.qval(j+j-i+1)-s[j][j-i]-1ll*(j+j-i+1)*D);
			}
			else break;
		}
		cmin(f[i],-1ll*(i+1)*D+T.qmin(i+1,k));
		if(1<=i&&i<=k/2){
			int r=k/2-i;
			cmin(dp_g[rt][r],-1ll*(i+1)*D+T.qmin(k/2+r,k));
		}

		if(i>=1){
			if(i>k/2)T.setc(i,f[i]+wf[nodes[i-k/2]]+1ll*i*D);
			for(auto [w,j]:ids){
				if(j-mxl[j]<i){
					int l=j+j-i+2;
					T.add(l,k+k,s[j][j-i+1]-s[j][j-i]);
				}
				else break;
			}
			T.add(i,k+k,s[i-1][0]);
		}
	}

	for(int i=0;i<k/2;i++)cmin(f[k/2],f[i]);
	dp_f[rt]=f[k/2],dp_s[rt].resize(len[rt]+1);

	vector<ll>sum(k);
	for(int i=k/2;i<k;i++){
		for(int j=0;j<=mxl[i];j++)sum[j+i-k/2]+=s[i][j];
	}

	assert(len[rt]==k/2-1);
	for(int i=0;i<len[rt];i++)dp_s[rt][i]=f[k/2+i+1]+wf[nodes[i+1]]+sum[i];
}

signed main(void){

	n=read(),D=read();
	for(int i=2;i<=n;i++){
		int u=read(),v=read(),w=read();
		G[u].emplace_back(mk(v,w)),G[v].emplace_back(mk(u,w));
	}
	dfs1(1),dfs2(1,1),DP(1);
	cout<<dp_f[1]<<endl;

	return 0;
}
posted @ 2024-06-20 19:52  云浅知处  阅读(63)  评论(2编辑  收藏  举报