「NOI2010」超级钢琴

知识点:RMQ,堆,小技巧

原题面

有趣的套路。

简述

给定一长度为 \(n\) 的数列 \(a\),给定参数 \(k,L,R\),求最大的 \(k\) 个长度在 \([L,R]\) 间的子段之和。
\(1\le n,k\le 5\times 10^5\)\(0\le |A_i|\le 10^3\)\(1\le L\le R\le n\),保证有解。
2S,512MB。

分析

一种显然的暴力是枚举所有长度在 \([L,R]\) 内的区间,大力求和并放入大根堆中,取最优的前 \(k\) 个即可。复杂度 \(O\left(\left(n^2 + k\right)\log n^2\right)\) 级别。期望得分 20pts。
\(n,k\) 同阶,上述做法的问题主要在于堆中有许多无贡献的元素。考虑如何缩小堆的大小,但又能保证每次访问堆顶时取出的元素是最优的。本题提供了一种解决此类前 \(k\) 优问题的思路。

首先求得 \(a\) 的前缀和 \(\operatorname{sum}\),任一区间都可以表示成两个前缀相减的形式,则题目所求即为最大的 \(k\)\(\operatorname{sum}_r - \operatorname{sum}_{l-1}\) 的值。若左端点 \(l-1\) 固定,最优的右端点 \(r\) 可以通过对 \(\operatorname{sum}\) 建 ST 表查询区间 \([l-1 + L, l-1 + R]\) 的最大值 \(O(1)\) 求得。
对于上述过程,定义状态 \((v, l-1, pos, x, y)\) 表示左端点为 \(l-1\),查询区间为 \([x, y]\),在此区间内找到的最优的 \(r\)\(pos\)\(\operatorname{sum}_r - \operatorname{sum}_{l-1}\) 的值是 \(v\)

初始时枚举所有左端点 \(l-1\),按照上述 ST 表做法构造初始状态 \((v, l - 1, pos, l - 1 +L, l - 1 + R)\) 并放入以状态中 \(v\) 为关键字的大根堆中。每次取出堆顶的状态,统计其贡献 \(v\),然后构造新状态 \((v', l - 1, pos', x, pos - 1)\)\((v'', l - 1, pos'', pos +1, y)\) 放入堆中。可以发现这样能够保证枚举到所有有贡献的状态。

堆中初始有 \(n\) 个状态,每次查询分裂状态都会使堆的大小 \(+1\),总复杂度 \(O((n+k)\log (n +k))\) 级别。
实现时注意一些边界细节。

代码

//知识点:RMQ,堆,小技巧
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <queue>
#define pr std::pair
#define mp std::make_pair
#define LL long long
const int kN = 5e5 + 10;
//=============================================================
struct Data {
  LL val;
  int i, pos, l, r;
  bool operator < (const Data &sec_) const {
    return val < sec_.val;
  }
};
int n, k, l, r;
LL ans, sum[kN];
std::priority_queue <Data> q;
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir, int sec) {
  if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
  if (sec < fir) fir = sec;
}
namespace ST {
  const int kMaxn = kN;
  const int kMaxLog = 20 + 1;
  LL mx[kMaxn][kMaxLog];
  int Log2[kMaxn], pos[kMaxn][kMaxLog];
  void Build() {
    for (int i = 0; i <= n; ++ i) {
      if (i > 1) Log2[i] = Log2[i >> 1] + 1;
      mx[i][0] = sum[i];
      pos[i][0] = i;
    }
    for (int i = 1; i <= kMaxLog; ++ i) {
      for (int j = 0; j + (1 << i) - 1 <= n; ++ j) {
        if (mx[j][i - 1] >= mx[j + (1 << (i - 1))][i - 1]) {
          mx[j][i] = mx[j][i - 1];
          pos[j][i] = pos[j][i - 1];
        } else {
          mx[j][i] = mx[j + (1 << (i - 1))][i - 1];
          pos[j][i] = pos[j + (1 << (i - 1))][i - 1];
        }
      }
    }
  }
  int Query(int L_, int R_) {
    int lth = Log2[R_ - L_ + 1];
    return mx[L_][lth] > mx[R_ - (1 << lth) + 1][lth] ? 
           pos[L_][lth]:
           pos[R_ - (1 << lth) + 1][lth];
  }
}
//=============================================================
int main() { 
  n = read(), k = read(), l = read(), r = read();
  for (int i = 1; i <= n; ++ i) sum[i] = sum[i - 1] + read();
  ST::Build();
  for (int i = 0; i <= n; ++ i) {
    if (i + l > n) break; //无贡献
    int pos = ST::Query(i + l, std::min(n, i + r));
    q.push((Data) {sum[pos] - sum[i], i, pos, i + l, std::min(n, i + r)}); //注意取 min
  }
  while (k --) {
    Data t = q.top(); q.pop();
    ans += t.val;
    int i = t.i, pos = t.pos, l_ = t.l, r_ = t.r;
    
    int new_pos = ST::Query(l_, pos - 1);
    if (l_ <= pos - 1) q.push((Data) {sum[new_pos] - sum[i], i, new_pos, l_, pos - 1});
    new_pos = ST::Query(pos + 1, r_);
    if (pos + 1 <= r_) q.push((Data) {sum[new_pos] - sum[i], i, new_pos, pos + 1, r_});
  }
  printf("%lld\n", ans);
  return 0;
}
posted @ 2021-01-27 08:27  Luckyblock  阅读(56)  评论(0编辑  收藏  举报