树状数组求LIS方案数

树状数组求LIS方案数

题意

给一个序列,求它的LIS的方案数。

最长上升子序列计数(Bonus) - 题目 - Daimayuan Online Judge

思路

$ n^2$ 解只需在朴素LIS的dp上再做一个方案数的dp。

void solve()
{    
    int n; cin >> n;
    vector<int> a(n + 1);
    rep(i,1,n) cin >> a[i];
    vector<int> f(n + 1), g(n + 1);
    f[1] = g[1] = 1;
    rep(i,2,n) {
        f[i] = 1;
        rep(j,1,i - 1) {
            if(a[i] > a[j])
                f[i] = max(f[j] + 1, f[i]);
        }
        rep(j,1,i - 1) {
            if(f[i] == 1) g[i] = 1;
            else if(a[i] > a[j] and f[j] + 1 == f[i])
                g[i] = (g[j] + g[i]) % mod;
        }
    }
    int mx = *max_element(f.begin(),f.end());
    int ans = 0;
    rep(i,1,n) if(mx == f[i])
        ans = (ans + g[i]) % mod;
    cout << ans;
}

考虑LIS的优化,可以用树状数组将其优化到 \(O(nlogn)\) 。在这个过程中如何维护方案数呢。


离线树状数组插值,多用于处理下标和值的偏序问题。


树状数组维护的前缀中,我们想要的是LIS长度最大的状态的方案数的和。

所以遇到较大的LIS长度,覆盖,遇到相等的LIS长度合并。

PII operator +(const PII &A, const PII &B) {
    PII ans = {0,0};
    if(A[0] != B[0]) {
        ans = A[0] > B[0] ? A : B;
    }else {
        ans[0] = A[0];
        ans[1] = (A[1] + B[1]) % mod;
    }
    return ans;
}

完整代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<functional>
#include<cassert>
#include<random>
#include<iomanip>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3fll
#define ull unsigned long long
#define endl '\n'
#define int long long
#define SZ(x) (int)x.size()
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define dec(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
using PII = array<int,2>;
const int MAXN =10 + 4e5 ,mod=1e9 + 7;
int n;
int a[MAXN],all[MAXN];
PII f[MAXN];

PII tr[MAXN];
#define lowbit(x) x&(-x)
PII operator +(const PII &A, const PII &B) {
    PII ans = {0,0};
    if(A[0] != B[0]) {
        ans = A[0] > B[0] ? A : B;
    }else {
        ans[0] = A[0];
        ans[1] = (A[1] + B[1]) % mod;
    }
    return ans;
}
void add(int p,PII val) {
    for(;p <= n;p += lowbit(p))
        tr[p] = tr[p] + val;
}
PII query(int p) {
    PII ans = {0,1};// 方案数至少为1
    for(;p;p -= lowbit(p))
        ans = ans + tr[p];
    return ans;
}
void solve()
{    
    cin >> n;
    rep(i,1,n) cin >> a[i], all[i] = a[i];
    
    sort(all + 1, all + 1 + n);
    int len = unique(all + 1,all + 1 + n) - all - 1;
    int mx = 0;
    rep(i,1,n) {
        a[i] = lower_bound(all + 1,all + 1 + len,a[i]) - all;
        
        f[i] = query(a[i] - 1);
        f[i][0] += 1;
        add(a[i],f[i]);
        // cerr<<f[i][1]<<" \n"[i==n];
        mx = max(f[i][0], mx);
    }
    
    int ans = 0;

    rep(i,1,n) if(mx == f[i][0]) ans = (ans + f[i][1]) % mod;

    cout << ans;
}
signed main()
{
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);

    //int T;cin>>T;
    //while(T--)
        solve();

    return 0;
}
posted @ 2022-09-14 16:37  Mxrurush  阅读(74)  评论(0编辑  收藏  举报