【题解】p6918 Branch Assignment

Branch Assignment

大家应该都可以想到,一个点对另一个点传递信息的价值为它到总部的最短路加上总部到另一个点的最短路,在下文,我们设一个点到总部的最短路加上总部到它的最短路为\(d_i\)

对于分到一个子集的点,若子集的大小为m,那么这个子集的总代价是:

\((m-1)\sum_{i=1}^md_i\)

于是,我们得到一个明显的贪心思想,那就是,d越大的值,它所在的子集就应该越小。然后,我们便可以把那些小的放一个大一点的子集,大的放一个小一点的子集。

我们把这些点按照\(d\)的大小排序,然后可以得到一个明显的转移方程式:

\(f_{i,j}\)为当前是第\(i\)个点,放了\(j\)个子集。

\((r-l-1)\sum_{i=l}^rd_i\)这个为\(cost_{l,r}\)

\(f_{i,j}=min_{k=0}^{i-1}(f_{k,j-1}+cost_{k,i})\)

上面这个方程式明显是\(O(n^3)\)的,通不过此题。

能不能快一点?

可以,我们发现后面的\(cost\)满足四边形不等式,于是可以用四边形不等式优化dp

按理来说,这就已经可以通过了,但是,我在https://www.luogu.com.cn/problem/UVA1737看到了这么一句话:

\(nlog^2\)级的算法

能不能再快一点?

观察题目,注意到正好建立s个集合

想到带权二分,我们通打表证明发现若我们给选的集合数带上代价,是一个凸函数。

然后我们干掉dp的一维,不管集合数那一维,列出方程式:

\(f_{i}=min_{k=0}^{i-1}(f_{k}+cost_{k,i})\)

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#define maxn 50005
#define maxr 200005
#define eps 1e-8

using namespace std;

struct node {
	long long val;
	int c;
	node operator + (node x) {
		return (node){val + x.val,c + x.c};
	}
}f[maxn];

struct dot{
	int c,l,r;
}qu[maxn];

node mins(node q,node p) {
	if (q.val < p.val) return q;
	else return p;
}

int n,b,s,r,nex[maxr],to[maxr],head[maxr],tot = 0,front,tail;
long long edge[maxr],dis[maxn],vis[maxn],sum[maxn],from[maxn],go[maxn],len[maxn],sum1[maxn],d[maxn];

std::priority_queue <std::pair<long long ,long long > > q;

void add(long long x,long long y,long long z) {
	to[++tot] = y; nex[tot] = head[x]; edge[tot] = z; head[x] = tot;
}

void dijkstra(long long st) {
	for (int i = 1;i <= n;i++) dis[i] = 9999999999999999;
	memset(vis,0,sizeof(vis));
	dis[st] = 0;
	q.push(std::make_pair(0,st));
	while (!q.empty()) {
		int x = q.top().second;
		q.pop();
		if (vis[x]) continue;
		vis[x] = 1;
		for (int i = head[x];i;i = nex[i]) {
			int y = to[i],z = edge[i];
			if (dis[y] > dis[x] + z) {
				dis[y] = dis[x] + z;
				q.push(std::make_pair(-dis[y],y));
			}
		}
	}
}

long long cost(int i,int j,long long num) {
	return f[j].val + (long long)(sum[i] - sum[j]) * (long long)(i - j - 1) + num;
}

void run(long long num) {
	qu[front = tail = 1] = {0,1,b};
	for (int i = 1;i <= b;i++) {
		f[i] = {999999999999,0};
		for (int j = 0;j < i;j++) {
			f[i] = mins(f[i],(node){cost(i,j,num),f[j].c + 1});
		}
	}
}

void half() {
	long long l = 0,r = 9999999999999999,ans = 0;
	while (r >= l) {
		long long  mid = (r + l) / 2;
		run(mid);
		if (f[b].c > s) l = mid + 1;
		else if (f[b].c < s) r = (ans = mid) - 1;
		else {printf("%lld\n",f[b].val - mid * s);return;}
	}
	l--;
	run(l);
	printf("%lld\n",f[b].val - l * s);
}

int main() {
	scanf("%d%d%d%d",&n,&b,&s,&r);
	for (int i = 1;i <= r;i++) {
		int x,y;
		long long z;
		scanf("%lld%lld%lld",&from[i],&go[i],&len[i]);
		add(from[i],go[i],len[i]);
	}
	dijkstra(b + 1);
	for (int i = 1;i <= b;i++) d[i] = dis[i];
	memset(head,0,sizeof(head)); tot = 0;
	for (int i = 1;i <= r;i++) add(go[i],from[i],len[i]);
	dijkstra(b + 1);
	for (int i = 1;i <= b;i++) d[i] += dis[i];
	std::sort(d + 1,d + b + 1);
	for (int i = 1;i <= b;i++) sum[i] = sum[i - 1] + d[i];
	for (int i = 1;i <= b;i++) sum[i] += sum1[i];
	half();
	return 0;
}

复杂度:\(O(n^2logn)\)

复杂度不减反增。。。

能不能再快一点?

好了,最后一个优化。

思考一下,我们的d数组排序后是不断递增的,并且,越大的数所在的集合应该越小。

那么我们在dp过程中,越靠后的的\(d_i\)会选取的决策点离它就越近,同时,这也证明了这个dp满足决策单调性

然后我们就可以果断地使用决策单调性优化dp,来把又一个n化成\(logn\)

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#define maxn 50005
#define maxr 200005
#define eps 1e-8

using namespace std;

struct node {
	long long val;
	int c;
	node operator + (node x) {
		return (node){val + x.val,c + x.c};
	}
}f[maxn];

struct dot{
	int c,l,r;
}qu[maxn];

node mins(node q,node p) {
	if (q.val < p.val) return q;
	else return p;
}

int n,b,s,r,nex[maxr],to[maxr],head[maxr],tot = 0,front,tail;
long long edge[maxr],dis[maxn],vis[maxn],sum[maxn],from[maxn],go[maxn],len[maxn],sum1[maxn],d[maxn];

std::priority_queue <std::pair<long long ,long long > > q;

void add(long long x,long long y,long long z) {
	to[++tot] = y; nex[tot] = head[x]; edge[tot] = z; head[x] = tot;
}

void dijkstra(long long st) {
	for (int i = 1;i <= n;i++) dis[i] = 9999999999999999;
	memset(vis,0,sizeof(vis));
	dis[st] = 0;
	q.push(std::make_pair(0,st));
	while (!q.empty()) {
		int x = q.top().second;
		q.pop();
		if (vis[x]) continue;
		vis[x] = 1;
		for (int i = head[x];i;i = nex[i]) {
			int y = to[i],z = edge[i];
			if (dis[y] > dis[x] + z) {
				dis[y] = dis[x] + z;
				q.push(std::make_pair(-dis[y],y));
			}
		}
	}
}

long long cost(int i,int j,long long num) {
	return f[j].val + (long long)(sum[i] - sum[j]) * (long long)(i - j - 1) + num;
}

void search(int i,long long num){
    int now = qu[tail].c,ll = qu[tail].l,rr = qu[tail].r;
    int ret = qu[tail].r + 1;
    while(ll <= rr){
        int mid = (ll + rr) / 2;
        if(cost(mid,i,num) <= cost(mid,now,num)) rr = mid - 1,ret = mid;
        else ll = mid + 1;
    }
    if(ret != qu[tail].l) qu[tail].r = ret - 1;
    else tail--;
    if(ret <= b) qu[++tail] = (dot){i,ret,b};
}

void run(long long num) {
	qu[front = tail = 1] = {0,1,b};
	memset(f,0,sizeof(f));
	for (int i = 1;i <= b;i++) {
		while (front < tail && qu[front].r < i) front++;
		qu[front].l++;
		int j = qu[front].c;
		f[i] = node{cost(i,j,num),f[j].c + 1};
		while (front < tail && cost(qu[tail].l,qu[tail].c,num) >= cost(qu[tail].l,i,num)) tail--;
		search(i,num); 
	}
}

void half() {
	long long l = 0,r = 9999999999999999,ans = 0;
	while (r >= l) {
		long long  mid = (r + l) / 2;
		run(mid);
		if (f[b].c > s) l = mid + 1;
		else if (f[b].c < s) r = (ans = mid) - 1;
		else {printf("%lld\n",f[b].val - mid * s);return;}
	}
	l--;
	run(l);
	printf("%lld\n",f[b].val - l * s);
}

int main() {
	scanf("%d%d%d%d",&n,&b,&s,&r);
	for (int i = 1;i <= r;i++) {
		int x,y;
		long long z;
		scanf("%lld%lld%lld",&from[i],&go[i],&len[i]);
		add(from[i],go[i],len[i]);
	}
	dijkstra(b + 1);
	for (int i = 1;i <= b;i++) d[i] = dis[i];
	memset(head,0,sizeof(head)); tot = 0;
	for (int i = 1;i <= r;i++) add(go[i],from[i],len[i]);
	dijkstra(b + 1);
	for (int i = 1;i <= b;i++) d[i] += dis[i];
	std::sort(d + 1,d + b + 1);
	for (int i = 1;i <= b;i++) sum[i] = sum[i - 1] + d[i];
	for (int i = 1;i <= b;i++) sum[i] += sum1[i];
	half();
	return 0;
}
posted @ 2021-02-24 21:31  I11usi0ns  阅读(72)  评论(0编辑  收藏  举报