BZOJ2119 股市的预测 字符串 SA ST表
原文链接https://www.cnblogs.com/zhouzhendong/p/9069171.html
题目传送门 - BZOJ2119
题意
给定一个股票连续$n$个时间点的价位,问有多少段股票走势在间隔$m$单位时间之后重现?
$n\leq 5\times 10^4,m\leq 10$
题解
和 http://www.cnblogs.com/zhouzhendong/p/9025092.html 此题十分类似。
这里稍微讲讲本题的不同之处。
首先相邻值求差,转换成字符串匹配。
然后,由于要间隔$m$个字符,所以我们在找关键点的对应点的时候要稍微改一改。
对于得到的$lcs$和$lcp$,在计算的时候有一点小小的注意点。
详见代码。
代码
#include <bits/stdc++.h> using namespace std; typedef long long LL; const int N=100005; int n,m,a[N],tot=0; map <int,int> mp; int SA[N],rank[N],tmp[N],height[N],tax[N]; int ST[N][20]; void Sort(int n,int m){ for (int i=0;i<=m;i++) tax[i]=0; for (int i=1;i<=n;i++) tax[rank[i]]++; for (int i=1;i<=m;i++) tax[i]+=tax[i-1]; for (int i=n;i>=1;i--) SA[tax[rank[tmp[i]]]--]=tmp[i]; } bool cmp(int rk[],int x,int y,int w){ return rk[x]==rk[y]&&rk[x+w]==rk[y+w]; } void Suffix_Array(int s[],int n){ memset(SA,0,sizeof SA); memset(tmp,0,sizeof tmp); memset(rank,0,sizeof rank); memset(height,0,sizeof height); for (int i=1;i<=n;i++) rank[i]=s[i],tmp[i]=i; int m=tot+1; Sort(n,m); for (int w=1,p=0;p<n;w<<=1,m=p){ p=0; for (int i=n-w+1;i<=n;i++) tmp[++p]=i; for (int i=1;i<=n;i++) if (SA[i]>w) tmp[++p]=SA[i]-w; Sort(n,m); swap(rank,tmp); rank[SA[1]]=p=1; for (int i=2;i<=n;i++) rank[SA[i]]=cmp(tmp,SA[i],SA[i-1],w)?p:++p; } for (int i=1,j,k=0;i<=n;height[rank[i++]]=k) for (k=max(k-1,0),j=SA[rank[i]-1];s[i+k]==s[j+k];k++); height[1]=0; } void Get_ST(int n){ memset(ST,0,sizeof ST); for (int i=1;i<=n;i++){ ST[i][0]=height[i]; for (int j=1;j<20;j++){ ST[i][j]=ST[i][j-1]; if (i-(1<<(j-1))>0) ST[i][j]=min(ST[i][j],ST[i-(1<<(j-1))][j-1]); } } } int Query(int L,int R){ int val=floor(log(R-L+1)/log(2)); return min(ST[L+(1<<val)-1][val],ST[R][val]); } int LCP(int x,int y){ x=rank[x],y=rank[y]; return Query(min(x,y)+1,max(x,y)); } int LCS(int x,int y){ return LCP(n*2+2-x,n*2+2-y); } int main(){ scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) scanf("%d",&a[i]); n--; mp.clear(); for (int i=1;i<=n;i++){ a[i]=a[i+1]-a[i]; if (mp[a[i]]==0) mp[a[i]]=++tot; a[n*2+2-i]=a[i]=mp[a[i]]; } a[n+1]=tot+1; Suffix_Array(a,n*2+1); Get_ST(n*2+1); LL ans=0; for (int L=1;L<=n;L++) for (int i=1;i+L+m<=n;i+=L){ int x=i,y=i+L+m; int lcp=min(LCP(x,y),L),lcs=min(LCS(x,y),L); int len=lcp+lcs-(lcp>0&&lcs>0); ans+=max(len-L+1,0); } printf("%lld\n",ans); return 0; }