[TJOI2013]最长上升子序列 (splay)
题目描述
给定一个序列,初始为空。现在我们将1到N的数字插入到序列中,每次将一个数字插入到一个特定的位置。每插入一个数字,我们都想知道此时最长上升子序列长度是多少?
输入格式:
第一行一个整数N,表示我们要将1到N插入序列中。
接下是N个数字,第k个数字Xk,表示我们将k插入到位置Xk(0<=Xk<=k-1,1<=k<=N)
输出格式:
N行,第i行表示i插入Xi位置后序列的最长上升子序列的长度是多少。
输入样例:
6
0 0 2 3 4 1
输出样例:
1
1
2
3
4
4
题解:
- 在第i次操作将i加入,那么在当前数列中i是最大的。
- 如果答案被更新,那么新的LIS以i结尾。
- 以i为结尾的LIS不会再被更新,因为后面插入的数都比i大。
所以我们对最终的序列求一遍LIS,就可以得到以i结尾的LIS长度,那么ans[i]=max(dp[i],dp[j]),j<i;
那么现在的问题就是如何求序列,可以用splay维护中序遍历,先放一个最大值和最小值进去为了避免玄学数组越界。
插入时,将x+1旋到根,x+2旋到根的右儿子,然后再把要插入的数接到x+2的左儿子即可。
注意插入后要更新x+1和x+2的信息。
#include<bits/stdc++.h> using namespace std; const int maxn=100005; const int oo=0x3f3f3f; int n,num,root,mx; int a[maxn],f[maxn]; struct Splay{ int v,size,fa,s[2]; }tr[maxn]; struct answer{ int id,cx; }ans[maxn]; template<class T>inline void read(T &x){ x=0;char ch=getchar(); while(!isdigit(ch)) ch=getchar(); while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} } bool cmp(answer a,answer b){ return a.id<b.id; } int identify(int x){ return x==tr[tr[x].fa].s[1]; } void connect(int x,int y,int d){ tr[x].fa=y;tr[y].s[d]=x; } void update(int x){ tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+1; } void rotate(int x){ int f=tr[x].fa,ff=tr[f].fa; int d1=identify(x),d2=identify(f); int cs=tr[x].s[d1^1]; connect(cs,f,d1); connect(f,x,d1^1); connect(x,ff,d2); update(x);update(f); } void splay(int x,int go){ if(go==root) root=x; go=tr[go].fa; while(tr[x].fa!=go){ int f=tr[x].fa; if(tr[f].fa==go) rotate(x); else if(identify(x)==identify(f)){rotate(f);rotate(x);} else {rotate(x);rotate(x);} } } int find(int x){ int now=root; while(1){ if(tr[tr[now].s[0]].size>=x) now=tr[now].s[0]; else{ x-=tr[tr[now].s[0]].size; if(x==1) return now; x-=1; now=tr[now].s[1]; } } } void insert(int go,int val){ int x=find(go),y=find(go+1); splay(x,root);splay(y,tr[root].s[1]); tr[++num]=(Splay){val,1,y,{0,0}}; tr[y].s[0]=num; update(y);update(x); splay(num,root); } void nice(int x){ if(tr[x].s[0]) nice(tr[x].s[0]); if(tr[x].v!=oo&&tr[x].v!=-oo) a[++num]=tr[x].v; if(tr[x].s[1]) nice(tr[x].s[1]); } int divi(int x){ int l=0,r=n; while(l<=r){ int mid=(l+r)>>1; if(f[mid]>=x) r=mid-1; else l=mid+1; } return r; } int main(){ read(n); tr[++num]=(Splay){-oo,2,0,{0,2}}; tr[++num]=(Splay){oo,1,1,{0,0}}; root=1; for(int i=1;i<=n;i++){ int x;read(x); insert(x+1,i); } num=0; nice(root); //for(int i=1;i<=n;i++) printf("%d ",a[i]); //putchar(10); for(int i=1;i<=n;i++) f[i]=oo; for(int i=1;i<=n;i++){ int k=divi(a[i])+1; f[k]=a[i]; ans[i].cx=k; ans[i].id=a[i]; } //for(int i=1;i<=n;i++) printf("%d ",ans[i]); sort(ans+1,ans+n+1,cmp); for(int i=1;i<=n;i++) printf("%d\n",mx=max(mx,ans[i].cx)); }