P1484 种树 Sol / wqs 二分

这个知乎专栏挺 nb 的

不会反悔贪心,所以用 wqs 二分解决了问题。

首先可以发现对于函数 \(f(i)\)(表示强制选 \(i\) 个点时的最大点权和)来说,它呈上凸壳状。

原因是每一次都应当加入使得答案变得最大的一个点。

否则可以交换两次操作使得答案更优,不符合题意。

到了某个 \(i\) 之后,你被迫选择不那么优的点(本题中为负点权),所以答案会下降。

所以 \(f(i)\) 呈现上凸壳状。


wqs 二分的思想是二分斜率。对于当前二分到的斜率 \(mid\),可以确定一条直线 \(f(i)=mid \cdot i+b\)

由于我们无法知道 \(f(i)\) 的值(不然为什么还要二分呢对吧),考虑用 \(mid \cdot i+b\) 来反推 \(f(i)\)

由于 \(mid,i\) 确定,使得 \(f(i)\) 最大,就要使得 \(b\) 最大(即直线截距最大)。

考虑 \(b=f(i)-mid\cdot i\),可以想到,它等效于每一次花费 \(mid\) 的代价来进行一次答案更新。

那么如何求出 \(b\) 呢?考虑将每一个原点权 \(-mid\),跑一遍 DP,求出最大点权和以及最大点权和选择的点数。

如果当前选择的点数 \(\le k\),意味着当前的斜率 \(mid\) 是可行的,可以继续减小,记录答案。

否则,当前的 \(mid\) 不可行,斜率需要继续增大。

由于你求出的点数其实就是 \(i\),所以我们可以将求出来的答案 \(+mid\times i\) 以获得真实答案。

当然,在这之前,分情况考虑。

由于题目说“至多 \(k\) 次”,所以对于 \(k>\) \(f(i)_{\max}\)对应的 \(i\) 时,直接输出 \(f(i)_{\max}\) 即可。

也就是说对于在 \(f(i)\) 单调减的区间内的 \(k\),我们不如直接选在这之前的最大值。

接下来只需要考虑 \(f(i)\) 单调递增的情况,像前面那样处理即可。

但是有一个问题,就是存在特殊情况,点权可以重复,即一段斜率可能相等。

这里我们就可以简单粗暴地处理问题,将求出来的答案 \(+mid\times k\)

虽然说对于和 \(f(k)\) 处斜率不一致的 \(f(i)\) 求出来的答案是错误的,但是我们只关心最大值。

在这段区间内,\(i\) 越大,答案越大,\(i=k\) 时自然最优;且点数单调非降,不会对最大值产生影响,答案自然正确。

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int N = 5e5 + 10, inf = 1e18;
int n, k, a[N], dp[N][2], g[N][2];

inline pair <int, int> check(int mid) {
  for (int i = 1; i <= n; ++i) {
    dp[i][0] = dp[i - 1][0], g[i][0] = g[i - 1][0];
    if (dp[i - 1][1] >= dp[i][0]) {
      if (dp[i][0] == dp[i - 1][1]) g[i][0] = min(g[i][0], g[i - 1][1]);
      else g[i][0] = g[i - 1][1], dp[i][0] = dp[i - 1][1];
    }
    dp[i][1] = dp[i - 1][0] + a[i] - mid, g[i][1] = g[i - 1][0] + 1;
  }
  if (dp[n][0] > dp[n][1]) return {dp[n][0], g[n][0]};
  if (dp[n][0] == dp[n][1]) return {dp[n][0], min(g[n][0], g[n][1])};
  return {dp[n][1], g[n][1]};
}

signed main() {
  ios_base::sync_with_stdio(false); cin.tie(0), cout.tie(0);
  cin >> n >> k; for (int i = 1; i <= n; ++i) cin >> a[i];
  auto tmp = check(0); if (tmp.second <= k) return cout << tmp.first << endl, 0;
  int l = 1, r = *max_element(a + 1, a + 1 + n), res;
  while (l <= r) {
    int mid = (l + r) >> 1; auto tmp = check(mid);
    int val = tmp.first, num = tmp.second;
    if (num <= k) res = val + k * mid, r = mid - 1;
    else l = mid + 1;
  }
  cout << res << endl;
  return 0;
}
posted @ 2022-11-02 19:53  MistZero  阅读(44)  评论(0编辑  收藏  举报