AtCoder Beginner Contest 354 F - Useless for LIS
题意
给定 T
组数据,对于每组数据,给定一个长度为 n
的序列 a
,问能组成最长上升子序列(LIS)的数有哪些,输出它们的下标。
解法
我们先考虑找出LIS的长度,这里我们可以用dp去做。
定义 dp[i]
表示以 a[i]
结尾的LIS长度,那么转移就是 \(dp[i] = max_{j \lt i \& a[j] < a[i]}(dp[j] + 1)\)
我们可以用树状数组来存储前缀最大值。注意,如果是求任意区间的最大值,那么是无法用树状数组去求的,但是前缀最大值是没问题的。
那么如果找出所有能组成LIS的数呢?我们考虑逆序遍历,遍历到 i
的时候,判断 i
之后是否存在 a[j] > a[i]
且 dp[j] = dp[i] + 1
,如果满足,就是能够组成LIS的数。我们可以使用一个数组 seq
来存储当前LIS中各个位置上的数,具体如代码所示。
#include <bits/stdc++.h>
#define endl '\n'
#define ls u << 1
#define rs u << 1 | 1
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
const int INF = 0x3f3f3f3f, N = 1e5 + 10;
const int MOD = 1e9 + 7;
const double eps = 1e-6;
const double PI = acos(-1);
inline int lowbit(int x) {return x & (-x);}
inline void solve() {
int n; cin >> n;
vector<int> a(n + 1), v;
for (int i = 1; i <= n; i ++ ) {
cin >> a[i];
v.push_back(a[i]);
}
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());
vector<int> tr(n + 1), dp(n + 1);
function<int(int)> getIdx = [&](int x) {
return lower_bound(v.begin(), v.end(), x) - v.begin() + 1;
};
function<void(int, int)> update = [&](int u, int val) {
for (int i = u; i <= n; i += lowbit(i)) tr[i] = max(tr[i], val);
};
function<int(int)> pre = [&](int u) {
int sum = 0;
for (int i = u; i; i -= lowbit(i)) sum = max(sum, tr[i]);
return sum;
};
for (int i = 1; i <= n; i ++ ) {
int x = getIdx(a[i]);
dp[i] = pre(x - 1) + 1;
update(x, dp[i]);
}
int ans = pre(n);
vector<int> seq(ans + 1);
vector<int> res;
for (int i = n; i; i -- ) {
if (dp[i] == ans) {
seq[ans] = max(seq[ans], a[i]);
res.push_back(i);
} else {
if (seq[dp[i] + 1] <= a[i]) continue;
seq[dp[i]] = max(seq[dp[i]], a[i]);
res.push_back(i);
}
}
cout << res.size() << endl;
reverse(res.begin(), res.end());
for (auto ite : res) cout << ite << ' '; cout << endl;
}
signed main() {
#ifdef DEBUG
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
auto now = clock();
#endif
ios::sync_with_stdio(false), cin.tie(nullptr);
cout << fixed << setprecision(2);
int _ = 1;
cin >> _;
while (_ -- )
solve();
#ifdef DEBUG
cout << "============================" << endl;
cout << "Program run for " << (clock() - now) / (double)CLOCKS_PER_SEC * 1000 << " ms." << endl;
#endif
return 0;
}