P3970 [TJOI2014]上升子序列
题意:
给定一个长度为\(n\)的序列,它的上升子序列的定义为:
-
是原序列的一个子序列。
-
长度至少为2。
-
序列内元素严格单调递增。
如果有两个子序列元素相同,那么只会计算一次,求整个序列的所有上升子序列个数,并对\(1e9+7\)取模。
解题思路:
首先离散化。
通过题目中的定义,我们可以得到一个比较显然的\(dp\)转移方程,设\(f_i\)为\(1\)至\(i\)中以\(a_i\)为结尾的上升子序列个数,那么有:
\[f_i=\sum_{j=1}^{j<i}(f_j +1)\ \ \ \ (a_j<a_i)
\]
每个以\(a_i\)为结尾的子序列都可以由前面已有的且结尾元素小于\(a_i\)的上升子序列转移过来,同时还要加上\(a_i\)与\(a_j\)所产生的长度为\(2\)的子序列,所以我们得到了上式。
可悲的是,暴力\(O(n^2)\)只可过\(30pts\),考虑用数据结构优化为\(O(nlogn)\)。
为了好写方便去重,咱选择了线段树√。
稍微思考一下,容易发现,若\(a_i=a_j,i>j\),那么这时在\(i\)位置统计可能会统计到比\(j\)更多也更完整的结果,所以我们可以直接抛弃掉前面已经计算过的结果,选择位置靠后的这个并在值域上进行单点修改,这样做是不用另开数组去重的,而且思维量也会相应地减少很多。
最后的答案即为:\(\sum_{i=1}^{n} f_i\)。
代码:
#include <bits/stdc++.h>
#define Reg register
#define lson (rt<<1)
#define rson (rt<<1|1)
using namespace std;
const int maxn=100010,mod=1e9+7;
int n,ans,a[maxn],ark[maxn],tot;
struct segmenttree{
int l,r,val,cnt;
//维护f[i]和已经存在的元素个数
}tr[maxn<<2];
inline int read(){
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-') w=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
s=(s<<3)+(s<<1)+(ch^48),
ch=getchar();
return s*w;
}
inline void build(int rt,int l,int r){
tr[rt].l=l;
tr[rt].r=r;
if(l==r) return;
int mid=(l+r)>>1;
build(lson,l,mid);
build(rson,mid+1,r);
}
inline void pushup(int rt){
tr[rt].val=1ll*(tr[lson].val+tr[rson].val)%mod;
tr[rt].cnt=tr[lson].cnt+tr[rson].cnt;
}
inline void updateseg(int rt,int pos,int val){
if(tr[rt].l==tr[rt].r){
tr[rt].val=val,tr[rt].cnt=1;
return;
}
int mid=(tr[rt].l+tr[rt].r)>>1;
if(pos<=mid) updateseg(lson,pos,val);
else updateseg(rson,pos,val);
pushup(rt);
}
inline int queryseg(int rt,int l,int r){
if(l<=tr[rt].l&&tr[rt].r<=r) return (tr[rt].val+tr[rt].cnt)%mod;//加上新产生的子序列
int mid=(tr[rt].l+tr[rt].r)>>1;
int sum=0;
if(l<=mid) sum=1ll*(sum+queryseg(lson,l,r))%mod;
if(r>mid) sum=1ll*(sum+queryseg(rson,l,r))%mod;
return sum;
}
int main(){
n=read();
for(int i=1;i<=n;++i) ark[i]=a[i]=read();
sort(ark+1,ark+1+n);
tot=unique(ark+1,ark+1+n)-ark-1;
build(1,1,tot);
for(int i=1;i<=n;++i){
a[i]=lower_bound(ark+1,ark+1+tot,a[i])-ark;
int val=queryseg(1,1,a[i]-1);
updateseg(1,a[i],val);
}
printf("%d\n",((queryseg(1,1,tot)-tot)%mod+mod)%mod);
//注意这里:最后的结果不会产生新序列,所以需要减去总个数
return 0;
}