bzoj2119 股市的预测
题意:给你一个字符串,将其差分之后,问有多少个子串满足长度为m+2i且前i个字符组成的子串和后i个字符组成的子串相同.(m在题中给出,i可以为任意值)
分析:考虑子串[l,r]什么情况下满足条件.首先r-l+1==m+2i,并且lcp(l,r-i+1)>=i.那么我们考虑按height合并所有后缀,在这个过程中考虑每一对后缀及其lcp,那么我们考虑每一对后缀能否作为合法的l和r-i+1即可.显然,一对后缀最多对答案贡献1.假设有一对后缀a,b(a<b)那么这对后缀能够作为合法的(l,r-i+1)的条件是lcp(a,b)+m>=j-i且j-i-1>=m,那么我们使用平衡树启发式合并,统计答案的时候也使用类似启发式合并的思想枚举较小一棵树的后缀在另一棵树中查询排名即可,复杂度为O(nlog^2n).
#include<cstdio> #include<cstdlib> #include<algorithm> using namespace std; const int maxn=50005; int tmp[2][maxn],sum[maxn],key[maxn],sa[maxn],rank[maxn],height[maxn]; int str[maxn]; bool cmp(const int &a,const int &b){ return str[a]<str[b]; } void getsa(int n,int m){ int i,j,k,p,*rk=tmp[0],*res=tmp[1]; for(i=0;i<n;++i)sa[i]=i; sort(sa,sa+n,cmp); for(rk[sa[0]]=0,m=1,i=1;i<n;++i) rk[sa[i]]=(str[sa[i]]==str[sa[i-1]])?m-1:m++; for(p=0,j=1;p<n;j<<=1,m=p){ for(i=0;i<m;++i)sum[i]=0; for(p=0,i=n-j;i<n;++i)res[p++]=i; for(i=0;i<n;++i)if(sa[i]>=j)res[p++]=sa[i]-j; for(i=0;i<n;++i)sum[key[i]=rk[res[i]]]++; for(i=1;i<m;++i)sum[i]+=sum[i-1]; for(i=n-1;i>=0;--i)sa[--sum[key[i]]]=res[i]; for(res[sa[0]]=0,p=1,i=1;i<n;++i) res[sa[i]]=(rk[sa[i]]==rk[sa[i-1]]&&rk[sa[i]+j]==rk[sa[i-1]+j])?p-1:p++; swap(res,rk); } for(i=1;i<n;++i)rank[sa[i]]=i; for(i=0,k=0;i<n-1;height[rank[i++]]=k) for(k?k--:0,j=sa[rank[i]-1];str[i+k]==str[j+k];++k); } struct node{ node* ch[2]; int ord,key,sz; node(int x){ key=x;ord=rand();sz=1;ch[0]=ch[1]=0; } void update(){ sz=1; if(ch[0])sz+=ch[0]->sz; if(ch[1])sz+=ch[1]->sz; } }; void rot(node* &rt,int t){ node* c=rt->ch[t];rt->ch[t]=c->ch[t^1];c->ch[t^1]=rt;rt=c; rt->ch[t^1]->update();rt->update(); } int Rank(node* rt,int x){//cnt how many y<x if(!rt)return 0; int lsz=(rt->ch[0])?rt->ch[0]->sz:0; if(x<rt->key)return Rank(rt->ch[0],x); return lsz+1+Rank(rt->ch[1],x); } void Insert(node* &rt,int x){ if(!rt)rt=new node(x); else{ int t=x>rt->key; Insert(rt->ch[t],x); rt->update(); if(rt->ch[t]->ord>rt->ord)rot(rt,t); } } int n,m; long long ans=0; node* root[maxn]; int seq[maxn],ufs[maxn]; int find(int x){ return ufs[x]==x?x:ufs[x]=find(ufs[x]); } bool cmp2(const int &a,const int &b){ return height[a]>height[b]; } void traverse(node* rt1,node* &rt2,int lcp){ if(!rt1)return; ans+=Rank(rt2,rt1->key+lcp+m)-Rank(rt2,rt1->key+m); ans+=Rank(rt2,rt1->key-m-1)-Rank(rt2,rt1->key-m-lcp-1); traverse(rt1->ch[0],rt2,lcp);traverse(rt1->ch[1],rt2,lcp); } void Merge(node* &rt1,node* &rt2){ if(!rt1)return; Merge(rt1->ch[0],rt2);Merge(rt1->ch[1],rt2); Insert(rt2,rt1->key); delete rt1;rt1=0; } void merge(int a,int b,int lcp){ a=find(a);b=find(b); if(root[a]->sz>root[b]->sz)swap(a,b); traverse(root[a],root[b],lcp); Merge(root[a],root[b]); ufs[a]=b; } int main(){ scanf("%d%d",&n,&m); for(int i=0;i<n;++i)scanf("%d",&str[i]); for(int i=0;i<n;++i)str[i]-=str[i+1]; str[n-1]=0x80808080; getsa(n,0); n--; for(int i=1;i<=n;++i)root[i]=new node(sa[i]); for(int i=2;i<=n;++i)seq[i]=i; for(int i=1;i<=n;++i)ufs[i]=i; sort(seq+2,seq+n+1,cmp2); for(int i=2;i<=n;++i){ merge(seq[i],seq[i]-1,height[seq[i]]); } printf("%lld\n",ans); return 0; }