哈尔滨理工大学第12届程序设计竞赛(同步赛)A 割韭菜 线段树
思路如下:
因为每颗韭菜最开始高度都为0,韭菜的生长速度有大有小,生长速度的大的韭菜无论什么时候都是不会比生长速度小的韭菜高度更低的。所以我们考虑对生长速度进行排序,那么需要切割的韭菜其实是一个连续的区间,也就是说如果进行切割操作的话,如果有韭菜被切割,那么被切割的韭菜一定是总排完序之后的最右边到数组某一个地方为止都需要切割的,然后小于这个地方的为位置都不需要切割。
所以我们考虑线段树上每个结点维护这么几个信息,上次切割完结点韭菜时结点最低韭菜高度的mnh,上次切割完结点韭菜时结点最高韭菜高度mxh,结点生长速度总和sumv,上次切割完结点韭菜时结点韭菜的高度总和sum,韭菜最低生长速度mnv,韭菜最快生长速度mxv,上次割完结点韭菜时的高度懒标记下传用的lazyh,还有上次割韭菜的时间lastd。(其实mnv和mxv不用在线段树上维护,每个结点最左边的就是最慢的,最右边的就是最快的。)
我们知道了这些信息之后如何判断线段树上一个结点u内的韭菜能否能被全部减掉呢。我们知道上次割完这个区间韭菜的时间,知道上次割完时候的最低韭菜高度,知道这次的时间d和割韭菜的高度limit,那么我们用这么一个计算公式:height = mnh + 1ll * (d - lastd) * mnv。这样我们算出了结点u内的最低韭菜在第d天的高度是height,那么右边的韭菜肯定会比他更高。所以这一个区间的韭菜都可以割掉。那么如何计算贡献?
我们知道上次割韭菜的时候的区间总韭菜高度sum,和区间总生长速度sumv,那么算出在第d天的韭菜总高度all = sum + 1ll * (lastd - d) * sumv,区间需要保留的总韭菜高度tot = 1ll * (r - l + 1) * limit。所以贡献为all - tot。然后给区间打上一个lazyh = limit的标记用来下次遍历时pushdown,再标记一下区间上次被割完的时候lastd = d,区间最低最高韭菜高度改变为为mnh = mxh = limit。
pushdown的时候要注意切割韭菜高度可能为0,所以用来下传的lazyh不能用0来表示懒标记下传成功了,我们换成lazyh = -1来表示lazyh下传成功。如果有懒标记进行下传,需要修改当前节点u的左儿子ls和右儿子rs的韭菜高度总和sum(ls(u)) = 1ll * (tr[ls(u)].r - tr[ls(u)].l + 1) * lazyh(u);sum(rs(u)) = 1ll * (tr[rs(u)].r - tr[rs(u)].l + 1) * lazyh(u);,以及时间lastd(ls(u)) = lastd(rs(u)) = lastd(u)。还有最低韭菜和最高韭菜高度mnh(ls(u)) = mxh(ls(u)) = mnh(rs(u)) = mxh(rs(u)) = lazyh(u)。最后懒标记下传到左右儿子lazyh(ls(u)) = lazyh(rs(u)) = lazyh(u)。别忘了清空u的懒标记lazyh(u) = -1。
pushup的时候要在当前结点u推平两边的lastd。因为当前区间维护的lastd是基于左右儿子都为lastd时的信息维护。考虑如何推平贡献。首先因为右儿子的生长速度快,所以右儿子的韭菜不可能低于左儿子的韭菜,所以右儿子的lastd肯定是大于等于左儿子的lastd。左儿子为ls,右儿子为rs,当前结点维护最大的lastd,也就是lastd(u) = lastd(rs(u))在维护结点u的mnh的时候mnh(u) = mnh(ls(u)) + 1ll * (lastd(rs(u)) - lastd(ls(u))) * mnv(u)。 这样一来在u基于lastd(rs)时的最低韭菜高度就是正确的了(只考虑结点u上的信息推平,不是直直接将左儿子更改,更改左儿子更花费时间。)。然后维护sum(u) = sum(ls(u)) + sum(rs(u)) + 1ll * (lastd(u) - lastd(ls(u))) * sumv(ls(u))。维护u基于lastd时间的的最高韭菜高度mxh (u) = mxh(rs)。
代码实现如下:
#include <iostream>
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <stack>
#include <queue>
#include <numeric>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <vector>
#include <unordered_set>
#include <cmath>
#include <map>
#include <unordered_map>
#include <set>
#include <deque>
#include <tuple>
#include <array>
#define all(a) a.begin(), a.end()
#define cnt0(x) __builtin_ctz(x)
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define cntone(x) __builtin_popcount(x)
#define db double
#define fs first
#define se second
#define AC main(void)
#define HYS std::ios::sync_with_stdio(false);std::cin.tie(0);std::cout.tie(0);
typedef std::pair<int, int > PII;
typedef std::pair<int, std::pair<int, int>> PIII;
typedef std::pair<ll, ll> Pll;
typedef std::pair<double, double> PDD;
using ld = double long;
#define rs(u) u << 1 | 1
#define ls(u) u << 1
#define lasth(u) tr[u].lasth
#define lastd(u) tr[u].lastd
#define mxv(u) tr[u].mxv
#define mnv(u) tr[u].mnv
#define mnh(u) tr[u].mnh
#define mxh(u) tr[u].mxh
#define sum(u) tr[u].sum
#define lazyh(u) tr[u].lazyh
#define sumv(u) tr[u].sumv
#define lazyd(u) tr[u].lazyd
const long double eps = 1e-9;
const int INF = 0x3f3f3f3f;
const int N = 1e5 + 10, M = 6e2 + 10;
int n, m, p;
#define int long long
int d1[] = {0, 0, 1, -1};
int d2[] = {1, -1, 0, 0};
int a[N];
struct node{
int l, r;
int mxv, mnv;
ll sumv, sum;
ll mnh, mxh;
int lastd;
int lazyh;
}tr[N << 2];
inline void pushupstart(int u){
mxv(u) = mxv(rs(u));
mnv(u) = mnv(ls(u));
sumv(u) = sumv(ls(u)) + sumv(rs(u));
}
inline void pushup(int u){
lastd(u) = lastd(rs(u));
sum(u) = sum(ls(u)) + sum(rs(u)) + 1ll * (lastd(u) - lastd(ls(u))) * sumv(ls(u));
mnh(u) = mnh(ls(u)) + 1ll * (lastd(u) - lastd(ls(u))) * mnv(u);
mxh(u) = mxh(rs(u));
}
inline void pushdown(int u){
if(lazyh(u) != -1){
lazyh(ls(u)) = lazyh(rs(u)) = lazyh(u);
mnh(ls(u)) = mxh(ls(u)) = mnh(rs(u)) = mxh(rs(u)) = lazyh(u);
sum(ls(u)) = 1ll * (tr[ls(u)].r - tr[ls(u)].l + 1) * lazyh(u);
sum(rs(u)) = 1ll * (tr[rs(u)].r - tr[rs(u)].l + 1) * lazyh(u);
lastd(ls(u)) = lastd(rs(u)) = lastd(u);
lazyh(u) = -1;
}
}
inline void build(int u, int l, int r){
tr[u] = {l, r};
if(l == r){
sumv(u) = mxv(u) = mnv(u) = a[l];
return ;
}
lazyh(u) = -1;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushupstart(u);
}
inline ll query(int u, int d, int limit){
ll res = 0;
if(mnh(u) + 1ll * (d - lastd(u)) * mnv(u) > limit){
ll tot = 1ll * limit * (tr[u].r - tr[u].l + 1);
res = sumv(u) * (d - lastd(u)) + sum(u) - tot;
sum(u) = tot;
mnh(u) = mxh(u) = limit;
lazyh(u) = limit;
lastd(u) = d;
return res;
}
if(mxh(u) + 1ll * (d - lastd(u)) * mxv(u) <= limit) return 0;
pushdown(u);
res += query(ls(u), d, limit);
res += query(rs(u), d, limit);
pushup(u);
return res;
}
inline void solve(){
std::cin >> n >> m;
for(int i = 1; i <= n; i ++) std::cin >> a[i];
std::sort(a + 1, a + n + 1);
build(1, 1, n);
while(m --){
int d, h;
std::cin >> d >> h;
std::cout << query(1, d, h) << '\n';
}
}
signed AC{
HYS
int _ = 1;
std::cin >> _;
while (_ --)
solve();
return 0;
}