斜率优化 学习笔记


Update:文章已更新

板子题

题目传送门
题目描述
\(n\) 个任务排成一个序列在一台机器上等待完成(顺序不得改变),这 \(n\) 个任务被分成若干批,每批包含相邻的若干任务。
从零时刻开始,这些任务被分批加工,第 \(i\) 个任务单独完成所需的时间为 \(t_i\) 。在每批任务开始前,机器需要启动时间 \(s\),而完成这批任务所需的时间是各个任务需要时间的总和(同一批任务将在同一时刻完成)。
每个任务的费用是它的完成时刻乘以一个费用系数 \(f_i\) 。请确定一个分组方案,使得总费用最小。
输入格式
第一行一个正整数 \(n\)
第二行是一个整数 \(s\)
下面 \(n\) 行每行有一对数,分别为 \(t_i\)\(f_i\) 表示第 \(i\) 个任务单独完成所需的时间是 \(t_i\) 及其费用系数 \(f_i\)
输出格式
一个数,最小的总费用。
输入输出样例
输入 #1

5
1
1 3
3 2
4 3
2 3
1 4

输出 #1

153

说明/提示
【数据范围】
对于 \(100\%\) 的数据,\(1\le n \le 5000\)\(0 \le s \le 50\)\(t_i,f_i\le 100\)
【样例解释】
如果分组方案是 \(\{1,2\},\{3\},\{4,5\}\) ,则完成时间分别为 \(\{5,5,10,14,14\}\) ,费用 \(C=15+10+30+42+56\) ,总费用就是 \(153\)

算法理解

其实斜率优化是一个数学建模。
前置知识:平面直角坐标系上的直线。

算法解析

因为讲的是斜率优化,所以推DP式子的过程不再赘述了,不会的可以点这里,包教包会。
为了方便叙述,我们令第 \(i\) 个产品的时间为 \(T_i\) , 系数为 \(C_i\) ,启动时间 \(s\) 。为了方便进行两个数组区间和的运算,我们将 \(T,C\) 两个数组做一遍前缀和,得到 \(SumT,SumC\) 两个数组。
令前 \(i\) 个的答案为 \(f_i\) ,初值为 \(f_0=0\)
不难得到DP式:

\[f_i=Min_{0\le j<i}\{f_j+\left(SumC_n+SumC_j\right)\times s+SumT_i\times \left(SumC_i-SumC_j\right)\} \]

我们发现,这样我们要 \(\Theta\left(n^2\right)\) 来计算,那么是不是可以优化一下呢?
为了化简式子,我们把 Min去掉

\[f_i=f_j+\left(SumC_n+SumC_j\right)\times s+SumT_i\times \left(SumC_i-SumC_j\right) \]

我们发现,在 \(i\) 这一维是要不断枚举的,所以至少要 \(\Theta(n)\) ,在每次枚举 \(i\) 的时候,我们发现 \(j\) 是一个变量, \(i\) 是一个定值。这样我们继续化简,大家可以拿出笔算一下。
最后的结果:

\[f_j=\left(SumT_i+s\right) \times SumC_j + f_i - SumT_i\times SumC_i-s\times SumC_n \]

我们发现这就是一个直线的斜截式 \(y=kx+b\) ,而且 \(k\) 是定值,我们要求的 \(f_i\)\(b\) 的差值也是一个定值。
我们试着在平面直角坐标系上来解释它的几何含义。
几何含义:已知一条直线过点 \((x,y)\) 并且斜率为 \(k\)\(k\) 是定值并且是正数,\(k\)\(x\) 的增大而增大 ),求这组直线截距的最小值 (截距就是和 \(x\) 轴交点的纵坐标)
我们把这些点放在平面直角坐标系上面。
不难证明:
如果维护的是最小值,那么所有可能是最优解的点都形成一个下凸包,我们只要维护一个下凸包就可以了(单调队列队尾出队判断)。
当然,我们发现,如果连接距离 \(y\) 轴最近的两个点连接起来的斜率小于过最后一个点的直线的斜率的话,最近的那个点就不是最优值(单调队列队首出队判断)。
显然我们可以用单调队列来维护,这样就是 \(\Theta\left(N\right)\) 的了。
反之亦然(指求最大值)。
我们可以用单调队列来优化。
代码

#include<cstdio>
#include<cstring>
#define min(a,b) ((a)<(b)?(a):(b))
#define maxn 300039
using namespace std;
//#define debug
typedef long long Type;
typedef long long ll;
inline Type read(){
	Type sum=0;
	int flag=0;
	char c=getchar();
	while((c<'0'||c>'9')&&c!='-') c=getchar();
	if(c=='-') c=getchar(),flag=1;
	while('0'<=c&&c<='9'){
		sum=(sum<<1)+(sum<<3)+(c^48);
		c=getchar();
	}
	if(flag) return -sum;
	return sum;
}
int n;
ll t[maxn],c[maxn];
ll sumT[maxn],sumC[maxn];
ll f[maxn],s;
int que[maxn],head,tail; 
double js(int x,int y){
	return (f[y]-f[x])*1.0/(sumC[y]-sumC[x]);
}
int main(){
	//freopen("1.in","r",stdin);
	//freopen("cpp.out","w",stdout);
	memset(f,0x7f,sizeof(f));
    n=read(); s=read();
    for(int i=1;i<=n;i++){
    	t[i]=read(); c[i]=read();
    	sumT[i]=sumT[i-1]+t[i];
    	sumC[i]=sumC[i-1]+c[i];
	}
	f[0]=0;
	head=0; tail=0;
	for(int i=1;i<=n;i++){
		while(head<tail&&js(que[head],que[head+1])<=sumT[i]+s) head++;
		#define k que[head]
		f[i]=min( f[i],f[k]+(sumC[n]-sumC[k])*s+sumT[i]*(sumC[i]-sumC[k]) );//DP式 
		while(head<tail&&js(que[tail],que[tail-1])>=js(que[tail],i)) tail--;
		que[++tail]=i;
	}
	printf("%lld",f[n]);
	return 0;
}

posted @ 2021-02-15 22:54  jiangtaizhe001  阅读(66)  评论(0编辑  收藏  举报