daimayuan#884. 最长上升子序列计数 (线段树优化dp)

http://oj.daimayuan.top/problem/884
image

  • f[i] 表示以a[i]结尾的最长上升子序列的长度,cnt[i]表示以a[i]结尾的最长上升子序列的个数(初值 f[i] = 1, cnt[i] = 1)。 可以 O(n²) 转移: 对每个满足 j < i 且 a[j] < a[i] 的 j, 若 f[j] + 1 > f[i], 则 f[i] = f[j] + 1, cnt[i] = cnt[j]; 若 f[j] + 1 == f[i], 则 cnt[i] += cnt[j]
  • 发现f的转移就是在i之前、值小于a[i]的位置j中找最大的f[j](同时累加取到该最大值的cnt[j]),这个可以用权值线段树(对值离散化后按值建树)处理
  • 单点修改,区间查询,只需要考虑如何合并两个结点的信息(长度len, 个数cnt): 长度不同取较大者, 长度相同则个数相加; len为0表示空, 直接返回另一侧, 避免两个空结点的个数被相加
// Merge two summaries {len, cnt}: the longest increasing-subsequence length
// and how many subsequences reach it.
// A summary with len == 0 is "empty": the other operand is returned unchanged,
// so two empty summaries never have their counts added together.
// Otherwise the longer length wins; on a tie the counts are added modulo mod.
Tr comp( Tr a, Tr b) {
    if( a.len == 0 ) return b;
    if( b.len == 0 ) return a;
    if( a.len != b.len ) return a.len > b.len ? a : b;
    return Tr{ a.len, (a.cnt + b.cnt) % mod };
}
#include<bits/stdc++.h>
using namespace std;
// Faster iostream I/O: stop syncing with C stdio and untie cin/cout.
// The expansion already ends in ';', so it is written as a bare statement (`IOS`).
#define IOS ios::sync_with_stdio(false) ,cin.tie(0), cout.tie(0);
//#pragma GCC optimize(3,"Ofast","inline")
// Shorthand for 64-bit integers (a macro, not a type alias).
#define ll long long
//#define int long long
// Capacity of the per-element arrays; the segment tree gets N << 2 nodes.
// NOTE(review): assumes n (and hence the number of distinct values) stays below N - confirm against the problem limits.
const int N = 4e5 + 6;
// Template constants M, P, INF, LNF: not referenced by the code shown here.
const int M = 2e6 + 6;
const ll P = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const ll LNF = 0x3f3f3f3f3f3f3f3f;
// Modulus applied when subsequence counts are summed.
const int mod = 1e9 + 7;
// Template constant; not referenced by the code shown here.
const double PI = acos(-1.0);
// Summary of increasing subsequences: `len` is the longest length and
// `cnt` is how many subsequences achieve it (reduced modulo `mod` where summed).
struct Tr {
    ll len;
    ll cnt;
};
// Segment tree over compressed value ranks, in 1-based heap layout.
Tr tree[N << 2];
// f[i]: summary of increasing subsequences that end at position i.
Tr f[N];
// Merge two summaries {len, cnt}: the longest increasing-subsequence length
// and how many subsequences reach it.
// A summary with len == 0 is "empty": the other operand is returned unchanged,
// so two empty summaries never have their counts added together.
// Otherwise the longer length wins; on a tie the counts are added modulo mod.
Tr comp( Tr a, Tr b) {
    if( a.len == 0 ) return b;
    if( b.len == 0 ) return a;
    if( a.len != b.len ) return a.len > b.len ? a : b;
    return Tr{ a.len, (a.cnt + b.cnt) % mod };
}
// Build the segment tree: node rt covers value ranks [l, r], and its children
// are rt*2 and rt*2+1. Every leaf starts as {0, 1} (no element ends there yet;
// the empty subsequence counted once), and internal nodes merge their children.
void build( int l, int r, int rt ) {
    if( l != r ) {
        const int mid = (l + r) >> 1;
        const int leftChild = rt << 1;
        const int rightChild = leftChild | 1;
        build( l, mid, leftChild );
        build( mid + 1, r, rightChild );
        tree[rt] = comp( tree[leftChild], tree[rightChild] );
    } else {
        tree[rt] = Tr{0, 1};
    }
}
// Point update: merge summary `val` into the leaf for value rank `pos`, then
// recompute every ancestor's summary on the way back up.
// The leaf is merged with comp(), so it keeps the larger length and adds counts
// on a tie. A shorter `val` can no longer overwrite a longer stored length,
// so the tree stays a max-merge even if callers do not supply non-decreasing
// lengths for the same position.
void modify ( int pos, Tr val, int l, int r, int rt ) {
    if( l == r ) {
        tree[rt] = comp( tree[rt], val );
        return;
    }
    int mid = (l + r) >> 1;
    if( pos <= mid ) modify( pos, val, l, mid, rt << 1 );
    else modify( pos, val, mid + 1, r, rt << 1 | 1 );
    tree[rt] = comp( tree[rt << 1], tree[rt << 1 | 1] );
}
// Range query: merged summary of leaves in [a, b] intersected with this node's
// range [l, r]. A node disjoint from [a, b] contributes the empty summary {0, 1}.
Tr query( int a, int b, int l, int r, int rt ) {
    const bool disjoint = a > r || b < l;
    if( disjoint ) return Tr{0, 1};
    const bool covered = a <= l && r <= b;
    if( covered ) return tree[rt];
    const int mid = (l + r) >> 1;
    const Tr fromLeft = query( a, b, l, mid, rt << 1 );
    const Tr fromRight = query( a, b, mid + 1, r, rt << 1 | 1 );
    return comp( fromLeft, fromRight );
}
int a[N];
int main() {
 IOS
 int n; cin >> n;
 vector<int> ve;
 for ( int i = 1; i <= n; ++ i ) {
   cin >> a[i]; ve.push_back(a[i]);
 }
 sort(ve.begin(), ve.end());
 ve.erase( unique(ve.begin(), ve.end()), ve.end());
 for ( int i = 1; i <= n ;++ i ) {
   a[i] = lower_bound( ve.begin(), ve.end(), a[i]) - ve.begin() + 1;
 }
 int limit = ve.size();
 build( 0, limit, 1 );
 for ( int i = 1; i <= n; ++ i ) {
   f[i] = query( 0, a[i] - 1, 0, limit, 1); f[i].len += 1;
   Tr de1 = f[i];
   modify( a[i], f[i], 0, limit, 1);
 }
 int mx = 1, cnt = 0 ;
 for ( int i = 1; i <= n; ++ i ) {
   if( f[i].len > mx) {
     mx = f[i].len, cnt = f[i].cnt;
   } else if( f[i].len == mx) {
     cnt += f[i].cnt; cnt %= mod;
   }
 }
 cout << cnt << '\n';
 return 0;
}
posted @ 2022-05-19 11:44  qingyanng  阅读(172)  评论(0)  编辑  收藏  举报