[lnsyoj2210/luoguP5069]纵使日薄西山
来源
原题链接
2024.7.25 校内测验 T3
题意
给定序列 \(a\),\(m\) 次查询,每次查询修改一个数,然后查询:每次操作选定最大且下标最小的数 \(a_i\),使 \(a_{i-1},a_i,a_{i+1}\) 的值都减 \(1\),查询将整个序列变为全非正数序列的操作次数.
赛时 50pts
由于每次都会连带着相邻两个元素一起减 \(1\),当选定一个元素后,相邻两个元素永远不会大于该元素,因此当选定一个元素后,该元素一定会一直被操作直到减为 \(0\)。因此我们开一个堆,每次找出符合要求的元素,然后将该元素及相邻元素置为0,并累加答案。
时间复杂度 \(O(mn\log n)\)。
sol
顺着赛时的思路去想,由于需要进行修改,我们考虑使用线段树来操作。
在 pushup 时,我们需要将两区间合并,此时可能会有四种情况:
.X.
+(.)X.
直接合并.X
+X.
此时如果删掉中间的其中一个,另一个也会随之被删除,因此,只能选择更大且更靠左的值,另外一边只能不选X
,重新计算答案
因此,我们需要维护四个值 \(sum,suml,sumr,sumlr\)(分别表示最终答案、不选左端点的答案、不选右端点的答案、不选两端点的答案)以及四个标记 \(cl, cr, clr, crl\)(分别表示最终答案选不选左端点、最终答案选不选右端点、不选左端点的答案选不选右端点、不选右端点的答案选不选左端点)。
在合并时,若左儿子答案取了右端点,并且右儿子答案取了左端点,则需要处理冲突,否则直接相加。
处理冲突的方法为:
比较左儿子的右端点和右儿子的左端点的大小,若相同则取左儿子。
如果取左儿子,则将左儿子的答案和右儿子不取左端点的答案相加,然后处理标记
否则,就讲右儿子的答案和左儿子不取右端点的答案相加,然后处理标记
另外三个值的处理方式同理。
代码
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long LL;
const int N = 100005;
struct Node{
LL sum;
LL suml, sumr, sumlr;
bool cl, cr;
bool clr, crl;
}tr[N * 4];
int n, m;
LL a[N];
void pushup(int u, int l, int r){
int mid = l + r >> 1;
if (tr[u << 1].cr && tr[u << 1 | 1].cl){
if (a[mid] >= a[mid + 1]){
tr[u].cl = tr[u << 1].cl;
tr[u].cr = tr[u << 1 | 1].clr;
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].suml;
}
else {
tr[u].cl = tr[u << 1].crl;
tr[u].cr = tr[u << 1 | 1].cr;
tr[u].sum = tr[u << 1].sumr + tr[u << 1 | 1].sum;
}
}
else {
tr[u].cl = tr[u << 1].cl;
tr[u].cr = tr[u << 1 | 1].cr;
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
if (tr[u << 1].clr && tr[u << 1 | 1].cl){
if (a[mid] >= a[mid + 1]){
tr[u].clr = tr[u << 1 | 1].clr;
tr[u].suml = tr[u << 1].suml + tr[u << 1 | 1].suml;
}
else {
tr[u].clr = tr[u << 1 | 1].cr;
tr[u].suml = tr[u << 1].sumlr + tr[u << 1 | 1].sum;
}
}
else {
tr[u].clr = tr[u << 1 | 1].cr;
tr[u].suml = tr[u << 1].suml + tr[u << 1 | 1].sum;
}
if (tr[u << 1 | 1].crl && tr[u << 1].cr){
if (a[mid] < a[mid + 1]){
tr[u].crl = tr[u << 1].crl;
tr[u].sumr = tr[u << 1 | 1].sumr + tr[u << 1].sumr;
}
else {
tr[u].crl = tr[u << 1].cl;
tr[u].sumr = tr[u << 1].sum + tr[u << 1 | 1].sumlr;
}
}
else {
tr[u].crl = tr[u << 1].cl;
tr[u].sumr = tr[u << 1 | 1].sumr + tr[u << 1].sum;
}
if (tr[u << 1 | 1].crl && tr[u << 1].clr){
if (a[mid] >= a[mid + 1])
tr[u].sumlr = tr[u << 1].suml + tr[u << 1 | 1].sumlr;
else
tr[u].sumlr = tr[u << 1].sumlr + tr[u << 1 | 1].sumr;
}
else
tr[u].sumlr = tr[u << 1 | 1].sumr + tr[u << 1].suml;
}
void build(int u, int l, int r){
if (l == r){
tr[u].sum = a[l];
tr[u].suml = tr[u].sumr = tr[u].sumlr = 0;
tr[u].cl = tr[u].cr = true;
tr[u].clr = tr[u].crl = false;
return ;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u, l, r);
}
void update(int u, int l, int r, int x){
if (l == r){
tr[u].sum = a[l];
return ;
}
int mid = l + r >> 1;
if (x <= mid) update(u << 1, l, mid, x);
else update(u << 1 | 1, mid + 1, r, x);
pushup(u, l, r);
}
int main(){
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i ++ ) scanf("%lld", &a[i]);
build(1, 1, n);
while (m -- ){
int x, y;
scanf("%d%d", &x, &y);
a[x] = y;
update(1, 1, n, x);
printf("%lld\n", tr[1].sum);
}
return 0;
}