P5331 [SNOI2019] 通信 题解

考虑使用费用流解决此题。

先设计一个简单的建图方案:

  • 从源点向每个点 \(i\ (1\le i\le n)\) 连一条边 \((1,0)\),向汇点连一条边 \((1,W)\)

  • 从每个点 \(i+n\ (1\le i\le n)\) 向汇点连一条边 \((1,0)\)

  • 从点 \(i\ (1\le i\le n)\) 向点 \(j+n\ (1\le j\le n)\) 连一条边 \((1,|a_i-a_j|)\) 当且仅当 \(j<i\)

直接在这个图上跑最小费用最大流可以得到 \(80\) 分,因为边的个数是 \(n^2\) 级别的。

考虑用分治优化建图。设当前分治的区间为 \([l,r]\),中点为 \(\text{mid}\),正在处理 \([l,\text{mid}]\)\([\text{mid}+1,r]\) 之间的连边。对 \([l,r]\) 中的每一种数新开一个点,向终点连一条边 \((\infty,0)\),相邻两数对应的点再连对应代价的双向边,最后再对 \([l,\text{mid}]\) 中的每个点向对应权值的点连边,对 \([\text{mid}+1,r]\) 中的每个点从对应权值的点向这些点连边即可。这样建图边的个数是 \(O(n\log n)\) 级别的,可以通过此题。

参考代码:

#include<bits/stdc++.h>
#define ll long long
#define mxn 20005
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define rept(i,a,b) for(int i=a;i<b;++i)
using namespace std;
int n,w,s,t,t1,tot,cnt,a[mxn],b[mxn],p[mxn],ps[mxn],vr[5000003],ed[5000003],c[5000003],nx[5000003],hd[mxn];
ll ans,d[mxn],flow[mxn];
queue<int>q;
bool v[mxn];
inline void add(int x,int y,int z,int cs){
	vr[++tot]=y,ed[tot]=z,c[tot]=cs,nx[tot]=hd[x],hd[x]=tot;
	vr[++tot]=x,ed[tot]=0,c[tot]=-cs,nx[tot]=hd[y],hd[y]=tot;
}
bool spfa(){
	memset(d,0x3f,sizeof(d));
	memset(b,0,sizeof(b));
	flow[s]=1e16,d[s]=0,v[s]=1,q.push(s);
	p[t]=-1;
	while(q.size()){
		int x=q.front();q.pop();v[x]=0;
		for(int i=hd[x],y;i;i=nx[i])if(ed[i]&&d[y=vr[i]]>d[x]+c[i]){
			d[y]=d[x]+c[i];
			flow[y]=min(flow[x],(ll)ed[i]);
			p[y]=i;
			if(!v[y])v[y]=1,q.push(y);
		}
	}
	return p[t]!=-1;
}
void update(){
	int x=t,i;
	while(x!=s){
		i=p[x];
		ed[i]-=flow[t];
		ed[i^1]+=flow[t];
		x=vr[i^1];
	}
	ans+=flow[t]*d[t];
}
void solve(int l,int r){
	if(l==r)return;
	int mid=(l+r)>>1;
	solve(l,mid);solve(mid+1,r);
	t1=0;
	rep(i,l,r)b[++t1]=a[i];
	sort(b+1,b+t1+1);
	t1=unique(b+1,b+t1+1)-b-1;
	ps[1]=++cnt;
	rep(i,2,t1){
		ps[i]=++cnt;
		add(cnt,cnt-1,1e9,b[i]-b[i-1]);
		add(cnt-1,cnt,1e9,b[i]-b[i-1]);
	}
	rep(i,l,mid){
		int p=lower_bound(b+1,b+t1+1,a[i])-b;
		add(i,ps[p],1,0);
	}
	rep(i,mid+1,r){
		int p=lower_bound(b+1,b+t1+1,a[i])-b;
		add(ps[p],i+n,1,0);
	}
}
signed main(){
	scanf("%d%d",&n,&w);
	tot=1;
	t=n+n+1;
	rep(i,1,n){
		scanf("%d",&a[i]);
		add(0,i,1,0);
		add(i,t,1,w);
		add(i+n,t,1,0);
	}
	cnt=t;
	solve(1,n);
	while(spfa())update();
	cout<<ans;
    return 0;
}
posted @ 2023-11-08 14:24  zifanwang  阅读(9)  评论(0编辑  收藏  举报  来源