Codeforces #816 1715 C
题面
假设我们有一个函数 $ g(1, n) $ 表示 $ i = 1 \sim n - 1 $ 中满足 $ a_i \neq a_{i + 1} $ 的 $ i $ 的数量。
现在有 $ m $ 个询问,每个询问将会让 $ x \rightarrow a_i $ 。
你需要在每次询问后求出 $ \sum_{l = 1}^{n} \sum_{r = l}^{n} g(l, r) $ 。
思路
着手于 $ a_i \neq a_{i + 1} $ 这个条件。
我们把满足 $ a_i \neq a_{i + 1} $ 的 $ i $ 右边,$ i + 1 $ 左边的空隙称为“断点”。显而易见,“断点”的数量 $ + 1 $ 就是 $ g(1, n) $ 。
对于一个断点,它被包含在子串里,则子串的 $ g + 1 $ 。
那么我们对于一个断点,统计它左边有多少个元素,右边有多少个元素,一乘,就能得到这个断点对答案的“贡献”。
这样就能求开始时的 $ \sum_{l = 1}^{n} \sum_{r = l}^{n} g(l, r) $ 了。
现在我们需要改变元素。
显然,一个元素的更改,只会影响它左边和右边的断点(如果有)。
(也就是 $ i - 1 $ 和 $ i $)
那就先减掉,更改完后,再加回来。
记得要加 $ \frac{n(n + 1)}{2} $ ,即 $ 1 + 2 + \cdots + n $ 。
这样做的原因是……
是……
是 $ + 1 $ !
那么我们统计所有的 $ [l, r] $:
- $ l = 1 $ ,有 $ n $ 条;
- $ l = 2 $ ,有 $ n - 1 $ 条;
- ……
- $ l = n $ ,有 $ 1 $ 条。
所以要 $ 1 + 2 + \cdots + n $ 。
代码
#include <bits/stdc++.h>
using namespace std;
int a[100005];
bool joint[100005];
int main() {
int n, m;
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
for (int i = 1; i <= n; i++) {
if (a[i] != a[i + 1]) {
joint[i] = 1;
}
}
long long ans = 0;
for (int i = 1; i <= n; i++) {
ans += 1ll * joint[i] * (i) * (n - i);
}
while (m--) {
int i, x;
scanf("%d %d", &i, &x);
ans -= 1ll * joint[i - 1] * (i - 1) * (n - i + 1);
ans -= 1ll * joint[i] * (i) * (n - i);
a[i] = x;
joint[i - 1] = a[i - 1] != a[i];
joint[i] = a[i] != a[i + 1];
ans += 1ll * joint[i - 1] * (i - 1) * (n - i + 1);
ans += 1ll * joint[i] * (i) * (n - i);
printf("%lld\n", ans + 1ll * n * (n + 1) / 2);
}
return 0;
}