P6647 [CCC 2019] Tourism
\(\mathcal Solution\)
遍历一个序列,每次可以向右跳 \(1\sim k\) 个,并且获得跳过区间的 \(\max\) 的价值,求跳动次数最少的前提下的最大价值。
神仙 DP 好题!
首先不考虑跳动次数最少的限制,列出最朴素的方程,设 \(f(i)\) 表示到达点 \(i\) 时的最大价值:
\[f(i)=\max\limits_{i - k\leq j<i}\{f(j)+g(j+1,i)\}
\]
其中 \(g(l,r)\) 表示区间 \([l,r]\) 内 \(a_i\) 的最大值。
考虑加上跳动次数最小的限制,一个简单有效的操作是:
\[f(i)=\max\limits_{i - k\leq j<i}\{f(j)+g(j+1,i)\}-\inf
\]
每次跳动都要减去 \(\inf\) 的代价,这样跳动最少的点显然最先被考虑,这是个比较实用的套路。
但是这样的转移还是 \(O(nk)\) 的,需要继续优化。
看到转移式子的区间取 \(\max\) 操作,不难联想到线段树,可以考虑每次将 \(f\) 单点插入到线段树中去,每次区间查询最大值即可。
但是还是有问题的,因为每次的 \(g(j+1,i)\) 都有可能被 \(i\) 更新,所以线段树还要考虑区间加上一个值来得到真正的最大值。
但是有可能每个点加上的 \(g\) 都不同,那么还需要扫一遍去逐个单点修改,那还不如暴力优了。
但是可以发现,每个 \(g\) 都可以“管一段”,即 \(g\) 的取值是成段分布的。这个经典模型可以用单调栈来维护。
具体的,每次在单调栈中取出比当前 \(a_i\) 小的段进行区间加即可,相当于将那些段的决策点变为 \(i\)。
因为 \(g\) 对于 \(i-1\) 的区间是 \([i,i]\),比较烦,可以考虑将维护的 \(a\) 区间整体平移一个,这样决策和状态的下标就一一对应了。
因为每个 \(i\) 都进栈出栈一次,所以总共最多 \(2n\) 个区间存在过,所以总时间复杂度 \(O(n\log n)\),很优秀。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N = 1e6 + 10;
const LL INF = 1e12;
int n, k, tot, a[N], stk[N];
LL f[N], dat[N << 2], add[N << 2];
int read(){
int x = 0, f = 1; char c = getchar();
while(c < '0' || c > '9') f = (c == '-') ? -1 : 1, c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
return x * f;
}
void Push_Up(int p){
dat[p] = max(dat[p << 1], dat[p << 1 | 1]);
}
void Push_Down(int p){
if(!add[p]) return;
int l = p << 1, r = p << 1 | 1;
dat[l] += add[p], add[l] += add[p];
dat[r] += add[p], add[r] += add[p];
add[p] = 0;
}
void Modify(int p, int l, int r, int L, int R, LL v){
if(L <= l && r <= R){
dat[p] += v;
add[p] += v;
return;
}
Push_Down(p);
int mid = (l + r) >> 1;
if(L <= mid)
Modify(p << 1, l, mid, L, R, v);
if(R > mid)
Modify(p << 1 | 1, mid + 1, r, L, R, v);
Push_Up(p);
}
LL Query(int p, int l, int r, int L, int R){
if(L <= l && r <= R) return dat[p];
Push_Down(p);
LL num = - 9e18;
int mid = (l + r) >> 1;
if(L <= mid)
num = max(num, Query(p << 1, l, mid, L, R));
if(R > mid)
num = max(num, Query(p << 1 | 1, mid + 1, r, L, R));
return num;
}
int main(){
n = read(), k = read();
for(int i = 1; i <= n; i ++) a[i] = read();
for(int i = 1; i <= n; i ++){
while(tot && a[stk[tot]] <= a[i]){
if(a[stk[tot]] == a[i]) {tot --; continue;}
Modify(1, 0, n, stk[tot - 1], stk[tot] - 1, a[i] - a[stk[tot]]);
tot --;
}
stk[++ tot] = i;
Modify(1, 0, n, i - 1, i - 1, f[i - 1] + a[i]);
f[i] = Query(1, 0, n, max(i - k, 0), i - 1) - INF;
}
printf("%lld\n", f[n] + 1LL * ((n / k) + (n % k != 0)) * INF);
return 0;
}