bzoj 4650 & 洛谷 P1117 优秀的拆分 —— 枚举关键点+后缀数组
题目:https://www.lydsy.com/JudgeOnline/problem.php?id=4650
https://www.luogu.org/problemnew/show/P1117
枚举每一段 a 的长度,然后分块,后缀数组求出每一块首关键点附近的可行范围;
然后用线段树区间加,区间查询;
在开头范围查询结尾,就得到长度 <= 当前长度 d 的前半部分接上长度 = d 的后半部分的答案;
在结尾范围查询开头,就得到长度 = d 的前半部分接上长度 < d 的后半部分的答案,正好组合了所有情况!
于是写了,跑出来发现比别人慢了很多;
然后发现为什么要边找边算答案...可以差分数组区间加,最后直接每个位置算答案即可...
注意多组数据,要清空 rk[n+1] 和 tp[n+1],因为求后缀数组的时候用到了超出 n 的 tp(rk) 。
代码如下:
#include<cstdio> #include<cstring> #include<algorithm> #define mid ((l+r)>>1) #define ls (x<<1) #define rs ((x<<1)|1) using namespace std; typedef long long ll; int const xn=6e4+5;//n<<1 int rk[xn],sa[xn],tax[xn],tp[xn],ht[xn][20],bin[20],bit[xn],op[xn]; ll sum[2][xn<<1],lzy[2][xn<<1];//n<<2 char s[xn]; int Min(int x,int y){return x<y?x:y;} int Max(int x,int y){return x>y?x:y;} void rsort(int n,int m) { for(int i=1;i<=m;i++)tax[i]=0; for(int i=1;i<=n;i++)tax[rk[tp[i]]]++; for(int i=1;i<=m;i++)tax[i]+=tax[i-1]; for(int i=n;i;i--)sa[tax[rk[tp[i]]]--]=tp[i]; } void work(int n) { int m=125; for(int i=1;i<=n;i++)rk[i]=s[i],tp[i]=i; rk[n+1]=0; tp[n+1]=0;//! rsort(n,m); for(int k=1;k<=n;k<<=1) { int num=0; for(int i=n-k+1;i<=n;i++)tp[++num]=i; for(int i=1;i<=n;i++) if(sa[i]>k)tp[++num]=sa[i]-k; rsort(n,m); swap(rk,tp); rk[sa[1]]=1; num=1; for(int i=2;i<=n;i++) { int u=Min(n+1,sa[i]+k),v=Min(n+1,sa[i-1]+k);// rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[u]==tp[v])?num:++num;//tp[n+1]! } if(num==n)break; m=num; } } void geth(int n) { int k=0; ht[1][0]=0; for(int i=1;i<=n;i++) { if(rk[i]==1)continue; if(k)k--; int j=sa[rk[i]-1]; while(i+k<=n&&j+k<=n&&s[i+k]==s[j+k])k++; ht[rk[i]][0]=k; } for(int i=1;i<20;i++) for(int j=1;j+bin[i]-1<=n;j++) ht[j][i]=Min(ht[j][i-1],ht[j+bin[i-1]][i-1]); } int getlcp(int x,int y,int n) { if(x==y)return n-x+1; x=rk[x]; y=rk[y]; if(x>y)swap(x,y); x++; int r=bit[y-x+1]; return Min(ht[x][r],ht[y-bin[r]+1][r]); } void psd(int t,int x,int l,int r) { if(!lzy[t][x])return; int d=lzy[t][x]; lzy[t][x]=0; sum[t][ls]+=d*(mid-l+1); lzy[t][ls]+=d; sum[t][rs]+=d*(r-mid); lzy[t][rs]+=d; } void update(int t,int x,int l,int r,int L,int R) { if(L>R)return; if(l>=L&&r<=R){sum[t][x]+=r-l+1; lzy[t][x]++; return;} psd(t,x,l,r); if(mid>=L)update(t,ls,l,mid,L,R); if(mid<R)update(t,rs,mid+1,r,L,R); sum[t][x]=sum[t][ls]+sum[t][rs]; } ll query(int t,int x,int l,int r,int L,int R) { if(L>R)return 0; if(l>=L&&r<=R)return sum[t][x]; psd(t,x,l,r); ll ret=0; if(mid>=L)ret+=query(t,ls,l,mid,L,R); if(mid<R)ret+=query(t,rs,mid+1,r,L,R); return ret; } int main() { int T; scanf("%d",&T); bin[0]=1; for(int i=1;i<20;i++)bin[i]=(bin[i-1]<<1); bit[1]=0; for(int i=2;i<=6e4;i++)bit[i]=bit[i>>1]+1;//6e4 while(T--) { memset(sum,0,sizeof sum); memset(lzy,0,sizeof lzy); scanf("%s",s+1); int n=strlen(s+1); s[n+1]='a'-1; for(int i=n+2,k=n;k;i++,k--)s[i]=s[k],op[k]=i; work(n*2+1); geth(n*2+1); ll ans=0; for(int d=1;d<=n;d++) for(int i=1;i+d<=n;i+=d) { int j=i+d; int t2=Min(d,getlcp(i,j,n)); int t1=Min(d,getlcp(op[i],op[j],n)); if(t1+t2-1<d)continue; int l=i-t1+1,r=i+t2-d;//hd ans+=query(1,1,1,n,Max(1,l-1),r-1);//qtl ans+=query(0,1,1,n,l+2*d,Min(n,r+2*d));//qhd update(0,1,1,n,l,r); update(1,1,1,n,l+2*d-1,r+2*d-1); } printf("%lld\n",ans); } return 0; }