P1637 三元上升子序列
三元上升子序列的题解,利用线段树和离散化,或者用dp求解答案
P1637 三元上升子序列
简要题意,在一个序列中寻找长度为三的上升子序列
思路
有两种思路
直接法
一种是对于一个树,算一个数左边比他小的数,算右边比他大的数,然后相乘即是该该点处值
算比他大的数,和比他小的数,用树状数组或线段树即皆可
根据题目数据范围需要离散化
CODE
#include<bits/stdc++.h>
using namespace std;
#define x first
#define y second
#define ll long long
int n;
const int maxn=1e5+10;
pair<int,int>m[maxn];
int t[maxn<<2];
int num[maxn];
int j[maxn<<2];
int sml[maxn];//比它小的数
int smx[maxn];//比它大的数
void push_up(int p){
t[p]=t[p<<1]+t[p<<1|1];
}
void update(int p,int l,int r,int nl,int nr){
if(l==r && nl==l){
++t[p];
return ;
}
int mid=(l+r)>>1;
if(nl<=mid) update(p<<1,l,mid,nl,nr);
if(nr>mid) update(p<<1|1,mid+1,r,nl,nr);
push_up(p);
}
ll query(int p,int l,int r,int nl,int nr){
ll res=0;
if(nl<=l && r<=nr){
return t[p];
}
int mid=(l+r)>>1;
if(nl<=mid) res+=query(p<<1,l,mid,nl,nr);
if(nr>mid) res+=query(p<<1|1,mid+1,r,nl,nr);
push_up(p);
return res;
}
int main(){
cin>>n;
for(int i=1;i<=n;++i){
int x;cin>>m[i].x;
m[i].y=i;
}
sort(m+1,m+1+n);
int cnt=0;//cnt是离散后的大小
for(int i=1;i<=n;++i){
if(m[i].x>m[i-1].x) ++cnt;
num[m[i].y]=cnt;//离散化
}
for(int i=1;i<=n;++i){
if(num[i]>1) sml[i]=query(1,1,n,1,num[i]-1);
update(1,1,n,num[i],num[i]);
}
memset(t,0,sizeof(t));
for(int i=n;i>=1;--i){
if(num[i]<n) smx[i]=query(1,1,n,num[i]+1,n);
update(1,1,n,num[i],num[i]);
}
ll ans=0;
for(int i=1;i<=n;++i) ans+=(smx[i]*sml[i]);
cout<<ans<<endl;
return 0;
}
DP
上升子序列,其实可以让我们很容易相到,最长上升子序列的求法,只需稍加修改即可
令f[i][j]是以a[j]为结尾长度为i的上升子序列
(该思路可以求i元上升子序列)
\(f[i][j]=\sum_{k<j,a[k]<a[j]}f[i-1][k]\)
利用桶排序的思想,储存f[i][j],在第a[j]个点,这样在转移只需要求
小于a[j]的和即可
如何高效的统计这个和,则会用到树状数组或线段树
遍历i时
树状数组存的应是f[i-1][k]的和(因此遍历i之后需要清空树状数组)
遍历第j+1个序列时,前j个序列 第a[j]点的位置加上f[i-1][j]
寻找满足状态转移方程则 f[i][j]+=sum(a[j]-1)
#include<bits/stdc++.h>
using namespace std;
#define x first
#define y second
#define ll long long
int n;
const int maxn=1e5+10;
pair<int,int> m[maxn];
int a[maxn];
int f[5][maxn];
int lowbit(int x){
return x&(-x);
}
int t[maxn];
void add(int x,int k){
while(x<=n){
t[x]+=k;
x+=lowbit(x);
}
}
ll sum(int x){
ll sum=0;
while(x){
sum+=t[x];
x-=lowbit(x);
}
return sum;
}
int main(){
cin>>n;
for(int i=1;i<=n;++i) {
cin>>m[i].x;
m[i].y=i;
}
int num=0;
sort(m+1,m+1+n);
for(int i=1;i<=n;++i){
if(m[i].x>m[i-1].x) ++num;
a[m[i].y]=num;
}
for(int i=1;i<=n;++i) f[1][i]=1;
for(int i=2;i<=3;++i){
memset(t,0,sizeof(t));
for(int j=1;j<=n;++j){
f[i][j]=sum(a[j]-1);
add(a[j],f[i-1][j]);
}
}
ll ans=0;
for(int i=1;i<=n;++i) ans+=f[3][i];
cout<<ans<<endl;
return 0;
}