树状数组求LIS方案数
树状数组求LIS方案数
题意
给一个序列,求它的LIS的方案数。
最长上升子序列计数(Bonus) - 题目 - Daimayuan Online Judge
思路
$ n^2$ 解只需在朴素LIS的dp上再做一个方案数的dp。
void solve()
{
int n; cin >> n;
vector<int> a(n + 1);
rep(i,1,n) cin >> a[i];
vector<int> f(n + 1), g(n + 1);
f[1] = g[1] = 1;
rep(i,2,n) {
f[i] = 1;
rep(j,1,i - 1) {
if(a[i] > a[j])
f[i] = max(f[j] + 1, f[i]);
}
rep(j,1,i - 1) {
if(f[i] == 1) g[i] = 1;
else if(a[i] > a[j] and f[j] + 1 == f[i])
g[i] = (g[j] + g[i]) % mod;
}
}
int mx = *max_element(f.begin(),f.end());
int ans = 0;
rep(i,1,n) if(mx == f[i])
ans = (ans + g[i]) % mod;
cout << ans;
}
考虑LIS的优化,可以用树状数组将其优化到 \(O(nlogn)\) 。在这个过程中如何维护方案数呢。
离线树状数组插值,多用于处理下标和值的偏序问题。
树状数组维护的前缀中,我们想要的是LIS长度最大的状态的方案数的和。
所以遇到较大的LIS长度,覆盖,遇到相等的LIS长度合并。
PII operator +(const PII &A, const PII &B) {
PII ans = {0,0};
if(A[0] != B[0]) {
ans = A[0] > B[0] ? A : B;
}else {
ans[0] = A[0];
ans[1] = (A[1] + B[1]) % mod;
}
return ans;
}
完整代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<functional>
#include<cassert>
#include<random>
#include<iomanip>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3fll
#define ull unsigned long long
#define endl '\n'
#define int long long
#define SZ(x) (int)x.size()
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define dec(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
using PII = array<int,2>;
const int MAXN =10 + 4e5 ,mod=1e9 + 7;
int n;
int a[MAXN],all[MAXN];
PII f[MAXN];
PII tr[MAXN];
#define lowbit(x) x&(-x)
PII operator +(const PII &A, const PII &B) {
PII ans = {0,0};
if(A[0] != B[0]) {
ans = A[0] > B[0] ? A : B;
}else {
ans[0] = A[0];
ans[1] = (A[1] + B[1]) % mod;
}
return ans;
}
void add(int p,PII val) {
for(;p <= n;p += lowbit(p))
tr[p] = tr[p] + val;
}
PII query(int p) {
PII ans = {0,1};// 方案数至少为1
for(;p;p -= lowbit(p))
ans = ans + tr[p];
return ans;
}
void solve()
{
cin >> n;
rep(i,1,n) cin >> a[i], all[i] = a[i];
sort(all + 1, all + 1 + n);
int len = unique(all + 1,all + 1 + n) - all - 1;
int mx = 0;
rep(i,1,n) {
a[i] = lower_bound(all + 1,all + 1 + len,a[i]) - all;
f[i] = query(a[i] - 1);
f[i][0] += 1;
add(a[i],f[i]);
// cerr<<f[i][1]<<" \n"[i==n];
mx = max(f[i][0], mx);
}
int ans = 0;
rep(i,1,n) if(mx == f[i][0]) ans = (ans + f[i][1]) % mod;
cout << ans;
}
signed main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//int T;cin>>T;
//while(T--)
solve();
return 0;
}