【题解】CF1621G Weighted Increasing Subsequences
常规,但不常规。
思路来自 @gyh.
思路
BIT 优化计数。
本来考虑的是对 LIS 进行计数,得到一个对 \([]\) 形式的值套三层求和的方式,然后再瞪眼找优化方法,但是没有发现什么好的处理方法,于是只能考虑转换计数方法。
考虑通过每个位置对答案的贡献计数。假设某个位置 \(x\) 被一个合法的子序列 \(i_1, \cdots i_k\) 包含,考虑此时需要满足的限制。
其实很简单,考虑最后一个满足题目限制的位置,令 \(y\) 为最靠后的满足 \(a_y > a_x\) 的位置。只需限制 \(i_k < y\),就可以通过位置 \(y\) 满足题目的限制。同时容易发现 \(y\) 以及之后的位置都不可能满足限制,所以这个条件是充要的。
对于每个位置,考虑在统计以其为开头的上升子序列时计算它对答案的贡献。换言之,对于 \(1 \leq x \leq n\),统计 \([x, y - 1]\) 中包含 \(x\) 的上升子序列的个数。
通过树状数组优化朴素的 dp 做法,可以达到 \(O(n^2 \log n)\) 的复杂度。
同时可以观察得到合法的 \(y\) 一定是原序列的后缀最大值,可以进一步减小常数。
优化发现没法再转化统计方式,于是考虑通过容斥一类的方式做一些手脚。先假定所有包含 \(x\) 的上升子序列均符合限制,再容斥减去不符合限制的子序列。
分类讨论。当子序列以 \([x, y - 1]\) 中的位置结尾时,序列一定符合条件;当子序列以 \(y\) 结尾时,不符合题目限制;当子序列以 \([y + 1, n]\) 中的位置结尾时,因为 \(y\) 的定义,一定不可能从中找出比 \(a_y\) 大的值,所以这种情况矛盾,不可能统计。
这意味着我们只需要先预处理出以 \(x\) 为开头的上升子序列数量,再计算以 \(x\) 开头且 \(y\) 为结尾的上升子序列数量。
同时题目有进一步的性质:令 \(z\) 为满足 \(z > y\) 且 \(a_z\) 最大的位置。对于某个 \(y\),其对应的所有 \(x\) 都应当满足 \(a_z \leq a_x < a_y\),也就是对于每个 \(y\),其对应的 \(a_x\) 一定在某个区间内,统计的时候只需要考虑这些位置,直接通过定义二分预处理出来。
又因为 \(x\) 到 \(y\) 可以看成是一种单射关系,也就是每个 \(x\) 所在的被 \(y\) 确定的区间是唯一的,所以均摊的复杂度是 \(O(n \log n)\).
两次转化算是比较常规,就是瞪不出来题目的性质,跟做初中几何一个样。
代码
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int maxn = 2e5 + 5;
const int mod = 1e9 + 7;
int n;
int top, stk[maxn];
int a[maxn], p[maxn];
int pre[maxn], suf[maxn], f[maxn];
vector<int> seq[maxn];
namespace BIT
{
int c[maxn];
void init() { for (int i = 1; i <= n; i++) c[i] = 0; }
int lowbit(int x) { return x & (-x); }
void update(int p, int w) { for (int i = p; i <= n; i += lowbit(i)) c[i] = (c[i] + w) % mod; }
int query(int p)
{
int res = 0;
for (int i = p; i; i -= lowbit(i)) res = (res + c[i]) % mod;
return res;
}
}
using namespace BIT;
bool cmp(int x, int y) { return (a[x] == a[y] ? x > y : a[x] < a[y]); }
void solve()
{
scanf("%d", &n);
top = 0;
for (int i = 1; i <= n; i++) seq[i].clear();
for (int i = 1; i <= n; i++) scanf("%d", &a[i]), p[i] = i;
sort(p + 1, p + n + 1, cmp);
for (int i = 1; i <= n; i++) a[p[i]] = i;
init(); for (int i = 1; i <= n; i++) update(a[i], pre[i] = query(a[i] - 1) + 1);
init(); for (int i = n; i >= 1; i--) update(n - a[i] + 1, suf[i] = query(n - a[i]) + 1);
for (int i = n; i >= 1; i--) if (a[i] > a[stk[top]]) stk[++top] = i;
for (int i = n; i >= 1; i--)
{
int l = 1, r = top;
while (l < r)
{
int mid = (l + r) >> 1;
if (a[i] <= a[stk[mid]]) r = mid;
else l = mid + 1;
}
if (i != stk[l]) seq[stk[l]].push_back(i);
}
init();
for (int i = 1; i <= top; i++)
{
update(n - a[stk[i]] + 1, f[stk[i]] = 1);
for (int p : seq[stk[i]]) update(n - a[p] + 1, f[p] = query(n - a[p]));
for (int p : seq[stk[i]]) update(n - a[p] + 1, -f[p]);
update(n - a[stk[i]] + 1, -1);
}
int ans = 0;
for (int i = 1; i <= n; i++) ans = (ans + 1ll * (suf[i] - f[i] + mod) % mod * pre[i] % mod) % mod;
printf("%d\n", ans);
}
int main()
{
int t;
scanf("%d", &t);
while (t--) solve();
return 0;
}