洛谷 P5069 [Ynoi2015] 纵使日薄西山
珂朵莉想让你维护一个长度为 \(n\) 的正整数序列 \(a_1,a_2,\ldots,a_n\),支持修改序列中某个位置的值。
每次修改后问对序列重复进行以下操作,需要进行几次操作才能使序列变为全 \(0\)(询问后序列和询问前相同,不会变为全 \(0\)):
选出序列中最大值的出现位置,若有多个最大值则选位置标号最小的一个,设位置为 \(x\),则将 \(a_{x-1},a_x,a_{x+1}\) 的值减 \(1\),如果序列中存在小于 \(0\) 的数,则把对应的数改为 \(0\)。
\(1\leq n,q\leq 10^5\),\(1\leq x_i\leq n\),\(1\leq a_i,y_i\leq 10^9\)。
濒临退役了啊……
可以发现进行一次操作后这个数再若干次操作后还会被选成最大值,而且这之中的操作不会影响这个数,因为这个数左右的数都被减了,左右的数一定不会进行操作。
那么我们可以直接统计每个会被操作的数的和,就是答案。
再进一步,对于一个单调上升的区间 \([l,r]\) ,那么一定是选 \(r,r-2,r-4...\) ,也就是和 \(r\) 奇偶性相同的数。
于是我们用set维护出相邻单调性不同的区间的交点,再用区间和就可以统计每个区间的和。
修改的时候只会影响最多 \([l-2,r+2]\) 个交点,暴力修改这几个点的答案。
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <set>
const int N = 1e5;
using namespace std;
int n,a[N + 5],m;
long long ans;
set <int> s;
set <int>::iterator it,itp,itn;
int lowbit(int x)
{
return x & (-x);
}
struct Bit
{
long long c[N + 5];
void add(int x,int v)
{
for (;x <= n;x += lowbit(x))
c[x] += v;
}
long long query(int x)
{
long long ans = 0;
for (;x;x -= lowbit(x))
ans += c[x];
return ans;
}
}c[2];
long long calc(int x)
{
if (*s.lower_bound(x) != x || x <= 0 || x > n)
return 0;
//cout<<x<<endl;
long long ans = 0;
if (x == 1)
{
int nxt = *s.upper_bound(x);
if (a[x] < a[nxt])
ans = c[nxt % 2].query(nxt - 1);
}
else
if (x == n)
{
int pre = *(--s.lower_bound(x));
if (a[x] <= a[pre])
ans = c[pre % 2].query(x) - c[pre % 2].query(pre - 1);
else
ans = a[x];
}
else
{
int nxt = *s.upper_bound(x),pre = *(--s.lower_bound(x));
if (a[x] <= a[pre] && a[x] < a[nxt])
{
//cout<<x<<" "<<pre<<" "<<nxt<<endl;
ans = c[pre % 2].query(x - 1) - c[pre % 2].query(pre - 1) + c[nxt % 2].query(nxt - 1) - c[nxt % 2].query(x);
if (x % 2 == pre % 2 && x % 2 == nxt % 2)
ans += a[x];
}
}
//cout<<x<<" "<<ans<<endl;
return ans;
}
int main()
{
scanf("%d",&n);
for (int i = 1;i <= n;i++)
scanf("%d",&a[i]);
for (int i = 1;i <= n;i++)
c[i % 2].add(i,a[i]);
s.insert(0);s.insert(-1);s.insert(-2);s.insert(1);
for (int i = 2;i < n;i++)
if ((a[i] > a[i - 1] && a[i] >= a[i + 1]) || (a[i] <= a[i - 1] && a[i] < a[i + 1]))
s.insert(i);
s.insert(n);s.insert(n + 1);s.insert(n + 2);s.insert(n + 3);
for (it = s.begin();it != s.end();it++)
ans += calc(*it);
scanf("%d",&m);
int x,y;
while (m--)
{
scanf("%d%d",&x,&y);
itp = s.lower_bound(x);itn = s.upper_bound(x);
ans -= calc(x);
itp--;
ans -= calc(*itp) + calc(*itn);
itp--;
itn++;
ans -= calc(*itp) + calc(*itn);
s.erase(x);s.erase(x - 1);s.erase(x + 1);
s.insert(n);s.insert(1);
c[x % 2].add(x,-a[x]);
a[x] = y;
c[x % 2].add(x,a[x]);
for (int i = max(1,x - 1);i <= min(n,x + 1);i++)
{
if ((a[i] > a[i - 1] && a[i] >= a[i + 1]) || (a[i] <= a[i - 1] && a[i] < a[i + 1]))
s.insert(i);
}
while (itp != itn)
ans += calc(*itp),itp++;
ans += calc(*itp);
printf("%lld\n",ans);
}
return 0;
}