「解题报告」CF1621G Weighted Increasing Subsequences
比较套路的拆贡献题。
考虑直接枚举那个 \(j\),求有多少包含 \(j\) 的上升子序列满足这个子序列最后一个数的后面有大于 \(a_j\) 的数。
首先对于 \(j\) 前面的选择方案是没有影响的,可以直接拿树状数组 DP 一遍得到。后面的过程我们可以找到从后往前第一个大于 \(a_j\) 的数的位置 \(x\),那么后面的方案就是 \(\lbrack j, x)\) 中包含 \(j\) 的上升子序列数。这个东西直接求不好求,考虑任意一个以 \(j\) 开头的上升子序列,由于 \(x\) 是第一个大于 \(a_j\) 的数,说明这个数后面的数都比 \(a_j\) 小,那么以 \(j\) 开头的上升子序列不可能包含这些数。那么实际上只有两种序列:要不然不包含 \(x\),要不然以 \(x\) 结尾。我们容斥一下,求所有以 \(x\) 结尾的上升子序列数。发现,我们要求的就是以 \(j\) 开头,以 \(x\) 结尾的上升子序列数。看起来还是不可做,但是考虑到 \(x\) 表示的数一定是第一个大于 \(a_j\) 的数,也就是说假如我们把后缀最大值写成一个序列 \(b_1 < b_2 < \cdots < b_m\),那么在这个子序列中的数值域一定在 \((b_{x - 1}, b_x\rbrack\),容易发现这样的值域总数是 \(O(n)\) 的。那么我们直接拿树状数组跑 DP,就也是 \(O(n \log n)\) 的了。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005, P = 1000000007;
int T, n, m, a[MAXN], b[MAXN];
struct BinaryIndexTree {
int a[MAXN];
#define lowbit(x) (x & (-x))
void init() { for (int i = 1; i <= n; i++) a[i] = 0; }
void add(int d, int v) {
while (d <= n) {
a[d] = (a[d] + v) % P;
d += lowbit(d);
}
}
int query(int d) {
if (!d) return 0;
int ret = 0;
while (d) {
ret = (ret + a[d]) % P;
d -= lowbit(d);
}
return ret;
}
} bit;
int f[MAXN], g[MAXN], h[MAXN];
pair<int, int> q[MAXN];
int main() {
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]), b[i] = a[i];
sort(b + 1, b + 1 + n);
int m = unique(b + 1, b + 1 + n) - b - 1;
for (int i = 1; i <= n; i++) a[i] = lower_bound(b + 1, b + 1 + n, a[i]) - b;
bit.init();
for (int i = 1; i <= n; i++) {
f[i] = (1 + bit.query(a[i] - 1)) % P;
bit.add(a[i], f[i]);
}
bit.init();
for (int i = n; i >= 1; i--) {
g[i] = (1 + bit.query(n) - bit.query(a[i]) + P) % P;
bit.add(a[i], g[i]);
}
m = 0;
for (int i = n; i >= 1; i--) {
if (!m || q[m].first < a[i]) q[++m] = { a[i], i }, b[i] = -1;
}
bit.init();
for (int i = n; i >= 1; i--) {
auto it = lower_bound(q + 1, q + 1 + m, make_pair(a[i], INT_MAX));
if (it != q + m + 1 && it->second > i) {
h[i] = (bit.query(a[it->second]) - bit.query(a[i]) + P) % P;
auto it2 = lower_bound(q + 1, q + 1 + m, make_pair(a[i], 0));
if (it2->first != a[i]) bit.add(a[i], h[i]);
} else if (b[i] == -1) {
h[i] = 1;
bit.add(a[i], h[i]);
} else {
h[i] = 0;
}
}
int ans = 0;
for (int i = 1; i <= n; i++) {
auto it = lower_bound(q + 1, q + 1 + m, make_pair(a[i], INT_MAX));
if (it != q + m + 1 && it->second > i) {
// printf("%d: %d %d %d\n", i, f[i], g[i], h[i]);
ans = (ans + 1ll * f[i] * (g[i] - h[i] + P)) % P;
}
}
printf("%d\n", ans);
}
return 0;
}