[COGS859] 数列
该题要求满足 i<j<k 且 ai<aj>ak 的三元组(i,j,k)的个数。
对于经典的逆序对的一种求解方法是对于元素 ai 求出满足 aj>ai 且 i<j 的元素的个数,线段树,树状数组以及平衡树都可以支持这个操作,用平衡树简单清晰,只需要依次插入每个元素并求一下当前平衡树中大于 ai 的元素的个数累加进答案即可。
对于本题只不过多求一遍。
1:顺序插入每个元素,插入前求出当前元素在当前平衡树中的rank记录为 b[i] 。
2:逆序插入每个元素,插入前求出当前元素在当前平衡树中的rank记录为 c[i] 。
3:根据乘法原理,ans=Σ b[i]*c[i]。
我写这道题只是为了练一下splay双旋的板子,没想到把splay操作写挂了,而且挂的离谱,我自己看到后都震惊了,写程序的时候得想着啥才能把splay写成这样......
第一次交的时候的splay代码。
void splay(int x,int &p) { int y=f[x],z=f[y],q=f[p]; while(y!=q) { if(z==q) x==l[y]?r_rot(x):l_rot(x); else if(x==l[y]&&y==l[z]) r_rot(y),r_rot(x); else if(x==l[y]&&y==r[x]) r_rot(x),l_rot(y); else if(x==r[x]&&y==l[x]) l_rot(x),r_rot(x); else l_rot(y),l_rot(x); y=f[x],z=f[y]; } p=x; }
那是一坨什么东西......
// q.c #include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> using namespace std; const int M=50000+10; struct SplayTree { int root,cnt,l[M],r[M],f[M],v[M],s[M]; void clear() { root=cnt=0; memset(l,0,sizeof(l)); memset(r,0,sizeof(r)); memset(f,0,sizeof(f)); memset(v,0,sizeof(v)); memset(s,0,sizeof(s)); } void update(int x) { s[x]=s[l[x]]+s[r[x]]+1; } void l_rot(int x) { int y=f[x],z=f[y]; f[x]=z; if(z) y==l[z]?l[z]=x:r[z]=x; if(l[x]) f[l[x]]=y; r[y]=l[x],f[y]=x,l[x]=y; update(y),update(x); } void r_rot(int x) { int y=f[x],z=f[y]; f[x]=z; if(z) y==l[z]?l[z]=x:r[z]=x; if(r[x]) f[r[x]]=y; l[y]=r[x],f[y]=x,r[x]=y; update(y),update(x); } void splay(int x,int &p) { int y=f[x],z=f[y],q=f[p]; while(y!=q) { if(z==q) x==l[y]?r_rot(x):l_rot(x); else if(x==l[y]&&y==l[z]) r_rot(y),r_rot(x); else if(x==l[y]&&y==r[z]) r_rot(x),l_rot(x); else if(x==r[y]&&y==l[z]) l_rot(x),r_rot(x); else l_rot(y),l_rot(x); y=f[x],z=f[y]; } p=x; } void insert(int &x,int fa,int k) { if(!x) x=++cnt,f[x]=fa,v[x]=k,s[x]=1; else { if(k<=v[x]) insert(l[x],x,k); else insert(r[x],x,k); update(x); } } int query(int k) { int ans=0,x=root,px=root; while(x) { if(v[x]<k) ans+=s[l[x]]+1,px=x,x=r[x]; else x=l[x]; } splay(px,root); return ans; } }t; int n,a[M],b[M],c[M]; long long ans; int main() { freopen("queueb.in","r",stdin); freopen("queueb.out","w",stdout); scanf("%d",&n); for(int i=1;i<=n;i++) scanf("%d",&a[i]); t.clear(); for(int i=1;i<=n;i++) { b[i]=t.query(a[i]); t.insert(t.root,0,a[i]); } t.clear(); for(int i=n;i>=1;i--) { c[i]=t.query(a[i]); t.insert(t.root,0,a[i]); } for(int i=1;i<=n;i++) ans+=b[i]*c[i]; printf("%lld\n",ans); return 0; }