【BZOJ2119】股市的预测 后缀数组+分块
【BZOJ2119】股市的预测
Description
墨墨的妈妈热爱炒股,她要求墨墨为她编写一个软件,预测某只股票未来的走势。股票折线图是研究股票的必备工具,它通过一张时间与股票的价位的函数图像清晰地展示了股票的走势情况。经过长时间的观测,墨墨发现很多股票都有如下的规律:之前的走势很可能在短时间内重现!如图可以看到这只股票A部分的股价和C部分的股价的走势如出一辙。通过这个观测,墨墨认为他可能找到了一个预测股票未来走势的方法。进一步的研究可是难住了墨墨,他本想试图统计B部分的长度与发生这种情况的概率关系,不过由于数据量过于庞大,依赖人脑的力量难以完成,于是墨墨找到了善于编程的你,请你帮他找一找给定重现的间隔(B部分的长度),有多少个时间段满足首尾部分的走势完全相同呢?当然,首尾部分的长度不能为零。
Input
输入的第一行包含两个整数N、M,分别表示需要统计的总时间以及重现的间隔(B部分的长度)。接下来N行,每行一个整数,代表每一个时间点的股价。
Output
输出一个整数,表示满足条件的时间段的个数
Sample Input
1 2 3 4 8 9 1 2 3 4 8 9
Sample Output
【样例说明】
6个时间段分别是:3-9、2-10、2-8、1-9、3-11、4-12。
HINT
对于100%的数据,4≤N≤50000 1≤M≤10 M≤N 所有出现的整数均不超过32位含符号整数。
题解:判断走势的话一定要用到差分,看起来数据规模较大,所以我们再离散化一下。
设B段长度为M,首先我们可以枚举A段的长度L,然后我们每隔连续的L个时间就选择一个关键点,这样做的用意何在?因为A段的长度就是L,所以每一个合法的A段都包含且仅包含一个关键点,所以我们可以枚举关键点,看一下有多少个合法的A段包含它,这样就能保证不重不漏。
具体地,对于关键点i,我们如何知道它被那些合法的A段所包含呢?我们令j=i+L+M,那么我们从i,j向左右拓展,找到最长的相同的段,我们设向左拓展了l,向右拓展了r,即[i-l...i+r]=[j-l,j+r],如果r+l+1大于L,那么答案就加上r+l+2-L。拓展的时候需要将原串的正串和反串都求一遍后缀数组,然后分别求LCP。
但是,为了满足上面黑字的那条性质,我们向左右拓展的时候不能拓展到其它的关键点,即保证l,r<L
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn=50010; typedef long long ll; int n,m,L,ans; int Log[maxn],v[maxn]; struct SA { int r[maxn],ra[maxn],rb[maxn],st[maxn],sa[maxn],f[maxn][20],h[maxn],rank[maxn]; void getsa() { int i,j,k,p,*x=ra,*y=rb; for(i=0;i<n;i++) st[x[i]=r[i]]++; for(i=1;i<m;i++) st[i]+=st[i-1]; for(i=n-1;i>=0;i--) sa[--st[x[i]]]=i; for(j=p=1;p<n;j<<=1,m=p) { for(p=0,i=n-j;i<n;i++) y[p++]=i; for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j; for(i=0;i<m;i++) st[i]=0; for(i=0;i<n;i++) st[x[y[i]]]++; for(i=1;i<m;i++) st[i]+=st[i-1]; for(i=n-1;i>=0;i--) sa[--st[x[y[i]]]]=y[i]; for(swap(x,y),x[sa[0]]=0,i=p=1;i<n;i++) x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+j]==y[sa[i-1]+j])?p-1:p++; } for(i=1;i<n;i++) rank[sa[i]]=i; for(i=k=0;i<n-1;h[rank[i++]]=k) for(k?k--:0,j=sa[rank[i]-1];r[i+k]==r[j+k];k++); } void getf() { int i,j; for(i=1;i<=n;i++) f[i][0]=h[i]; for(j=1;(1<<j)<=n;j++) for(i=0;i+(1<<j)-1<=n;i++) f[i][j]=min(f[i][j-1],f[i+(1<<j-1)][j-1]); } int getlcp(int a,int b) { if(b>n||a<0) return 0; a=rank[a],b=rank[b]; if(a>b) swap(a,b); a++; int k=Log[b-a+1]; return min(f[a][k],f[b-(1<<k)+1][k]); } }A,B; struct node { ll val; int org; }num[maxn]; bool cmp(node a,node b) { return a.val<b.val; } int rd() { int ret=0,f=1; char gc=getchar(); while(gc<'0'||gc>'9') {if(gc=='-')f=-f; gc=getchar();} while(gc>='0'&&gc<='9') ret=ret*10+gc-'0',gc=getchar(); return ret*f; } void solve(int siz) { int i,j,a,b,l,r; for(i=0;i+L+siz<n;i+=siz) { j=i+siz+L; if(A.r[i]!=A.r[j]) continue; a=min(siz,A.getlcp(i,j)),b=min(siz-1,B.getlcp(n-i,n-j)); if(a+b>=siz) ans+=a+b+1-siz; } } int main() { n=rd(),L=rd(); int i; v[0]=rd(); for(i=1;i<n;i++) v[i]=rd(),num[i].org=i,num[i].val=v[i]-v[i-1]; sort(num+1,num+n,cmp); num[0].val=-1ll<<60; for(i=1;i<n;i++) { if(num[i].val>num[i-1].val) m++; A.r[num[i].org-1]=m,B.r[n-num[i].org-1]=m; } m++; for(i=2;i<=n;i++) Log[i]=Log[i>>1]+1; A.getsa(),B.getsa(),A.getf(),B.getf(),n--; for(i=1;i*2+L<=n;i++) solve(i); printf("%d",ans); return 0; }
| 欢迎来原网站坐坐! >原文链接<