CF1647F Madoka and Laziness
CF1647F Madoka and Laziness
题目大意
如果序列 \(a_1, a_2, \dots, a_n\) 中存在一个位置 \(i\) 满足 \(a_1 < a_2 < \dots < a_i > a_{i + 1} > a_{i + 2} > \dots > a_{n}\),称这个序列为一座小山,称 \(i\) 为山顶。
现在给你一个长度为 \(n\) 的,所有元素互不相同的序列 \(a\)。考虑所有把 \(a\) 划分为两个子序列的方案,使得每个元素恰好在 \(1\) 个子序列中,且两个子序列都是小山。问所有划分方案中,共有多少个可能的无序“山顶对”。如:山顶对 \((3, 5)\) 和 \((5, 3)\) 算同一个。
数据范围: \(1\leq n\leq 5\cdot 10^5\),\(1\leq a_i\leq 10^9\)。
本题题解
记整个序列里最大的元素的位置为 \(p\)。显然,无论怎么划分,\(p\) 一定是其中一座小山的山顶。也就是说,任何一个满足条件的山顶对里必有一个元素是 \(p\)。我们只需要考虑另一个元素是谁,设它为 \(q\),答案就是合法的 \(q\) 的数量。不妨先假设 \(q > p\),然后把序列翻转过来再求一次,就能算到所有情况。
设 \(p, q\) 所在的小山分别为“子序列一”和“子序列二”。整个序列被分为了三个部分:
- 在 \([1, p]\) 中,子序列一和子序列二都递增。
- 在 \((p, q]\) 中,子序列一递减,子序列二递增。
- 在 \((q, n]\) 中,子序列一个子序列二都递减。
如果两个子序列增减性相同,我们很容易用一个 DP 来判断是否有解。以同时递增为例:设 \(\mathrm{dp}_1(i)\) 表示只考虑前 \(i\) 个元素,元素 \(i\) 在某个子序列中时,另一个子序列的最后一个元素最小是几(最小是因为子序列要递增,最后一个元素越小,在后面越可能有解)。初始时,\(\mathrm{dp}_1(1) = -\inf\),其他所有 \(\mathrm{dp}_1(i)\) 赋值为正无穷。转移时,\(i\) 前面有以 \(a_{i - 1}\) 结尾的和以 \(\mathrm{dp}_1(i - 1)\) 结尾的两个子序列,考虑 \(a_i\) 加入哪一个:
- 若 \(a_i > a_{i - 1}\),\(a_i\) 可以加入以 \(a_{i - 1}\) 结尾的子序列,\(\mathrm{dp}_1(i)\) 对 \(\mathrm{dp}_1(i - 1)\) 取 \(\min\)。
- 若 \(a_i > \mathrm{dp}_1(i - 1)\),\(a_i\) 可以加入以 \(\mathrm{dp}_1(i - 1)\) 结尾的子序列,\(\mathrm{dp}_1(i)\) 对 \(a_{i - 1}\) 取 \(\min\)。
同理,对于 \((q, n]\),两个子序列都递减的部分,只需要从后往前做一遍一模一样的 DP 即可。记为 \(\mathrm{dp}_3\)。
比较复杂的是中间部分:子序列一递减,子序列二递增。设 \(\mathrm{dpIncMin}(i)\) 表示元素 \(i\) 在递减子序列中时,递增子序列最后一个元素最小是几;设 \(\mathrm{dpDecMax}(i)\) 表示元素 \(i\) 在递增子序列中时,递减子序列的最后一个元素最大是几。因为 \(p\) 只能在递减子序列(子序列一)中,所以初始时,\(\mathrm{dpDecMax}(p) = -\inf\)(表示无解);根据 \(\mathrm{dp}_1\) 的定义可知,\(\mathrm{dpIncMin}(p) = \mathrm{dp}_1(p)\)。对于其他位置,初始时 \(\mathrm{dpDecMax}(i) = -\inf\),\(\mathrm{dpIncMin}(i) = \inf\)。转移时,根据 \(a_i\) 是在递增子序列还是递减子序列中,以及 \(a_{i - 1}\) 是在递增子序列还是递减子序列中,共有四种可能:
- 若 \(a_i\) 在递减子序列里,且 \(a_{i - 1}\) 也在递减子序列里,此时需要满足:\(a_i < a_{i - 1}\)。那么 \(\mathrm{dpIncMin}(i)\) 可以对 \(\mathrm{dpIncMin}(i - 1)\) 取 \(\min\)。
- 若 \(a_i\) 在递减子序列里,\(a_{i - 1}\) 在递增子序列里,此时需要满足:\(a_{i} < \mathrm{dpDecMax}(i - 1)\)。那么 \(\mathrm{dpIncMin}(i)\) 可以对 \(a_{i - 1}\) 取 \(\min\)。
- 若 \(a_i\) 在递增子序列里,且 \(a_{i - 1}\) 也在递增子序列里,此时需要满足:\(a_i > a_{i - 1}\)。那么 \(\mathrm{dpDecMax}(i)\) 可以对 \(\mathrm{dpDecMax}(i - 1)\) 取 \(\max\)。
- 若 \(a_i\) 在递增子序列里,\(a_{i - 1}\) 在递减子序列里,此时需要满足:\(a_i > \mathrm{dpIncMin}(i - 1)\)。那么 \(\mathrm{dpDecMax}(i)\) 可以对 \(a_{i - 1}\) 取 \(\max\)。
最后,对于一个 \(i > p\) 的位置 \(i\),若满足 \(\mathrm{dpDecMax}(i) > \mathrm{dp}_3(i)\),说明中间部分可以顺利过渡到后半部分,那么 \(i\) 就可以是一个山顶。合法的 \(i\) 的总数,就是答案了。
时间复杂度 \(\mathcal{O}(n\log n)\)。
参考代码
// problem: CF1647F
#include <bits/stdc++.h>
using namespace std;
#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 5e5;
const int INF = 1e9 + 1;
int n, a[MAXN + 5], mxp;
int dp1[MAXN + 5], dp3[MAXN + 5], dp2_incMin[MAXN + 5], dp2_decMax[MAXN + 5];
int ans;
void solve() {
dp1[1] = 0;
for (int i = 2; i <= mxp; ++i) {
dp1[i] = INF;
if (a[i] > a[i - 1]) {
ckmin(dp1[i], dp1[i - 1]);
}
if (a[i] > dp1[i - 1]) {
ckmin(dp1[i], a[i - 1]);
}
}
dp3[n] = 0;
for (int i = n - 1; i > mxp; --i) {
dp3[i] = INF;
if (a[i] > a[i + 1]) {
ckmin(dp3[i], dp3[i + 1]);
}
if (a[i] > dp3[i + 1]) {
ckmin(dp3[i], a[i + 1]);
}
}
dp2_incMin[mxp] = dp1[mxp]; // 把 mxp 放到递减序列里,递增序列的末尾最小(incMin)会是多少
dp2_decMax[mxp] = -INF; // mxp 不可能被放到递增序列里
for (int i = mxp + 1; i <= n; ++i) {
dp2_incMin[i] = INF;
dp2_decMax[i] = -INF; // 初始化
// 把 i 放到递减序列里:
if (a[i] < a[i - 1]) { // i - 1 也在递减序列里
ckmin(dp2_incMin[i], dp2_incMin[i - 1]);
}
if (a[i] < dp2_decMax[i - 1]) { // i - 1 在递增序列里
ckmin(dp2_incMin[i], a[i - 1]);
}
// 把 i 放到递增序列里:
if (a[i] > a[i - 1]) { // i - 1 也在递增序列里
ckmax(dp2_decMax[i], dp2_decMax[i - 1]);
}
if (a[i] > dp2_incMin[i - 1]) { // i - 1 在递减序列里
ckmax(dp2_decMax[i], a[i - 1]);
}
}
for (int i = mxp + 1; i <= n; ++i) {
ans += (dp2_decMax[i] > dp3[i]);
}
}
int main() {
cin >> n;
mxp = 0;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
if (a[mxp] < a[i]) {
mxp = i;
}
}
ans = 0;
solve();
reverse(a + 1, a + n + 1);
mxp = n - mxp + 1;
solve();
cout << ans << endl;
return 0;
}