bzoj 4650 & 洛谷 P1117 优秀的拆分 —— 枚举关键点+后缀数组

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=4650

https://www.luogu.org/problemnew/show/P1117

枚举每一段 a 的长度,然后分块,后缀数组求出每一块首关键点附近的可行范围;

然后用线段树区间加,区间查询;

在开头范围查询结尾,就得到长度 <= 当前长度 d 的前半部分接上长度 = d 的后半部分的答案;

在结尾范围查询开头,就得到长度 = d 的前半部分接上长度 < d 的后半部分的答案,正好组合了所有情况!

于是写了,跑出来发现比别人慢了很多;

然后发现为什么要边找边算答案...可以差分数组区间加,最后直接每个位置算答案即可...

注意多组数据,要清空 rk[n+1] 和 tp[n+1],因为求后缀数组的时候用到了超出 n 的 tp(rk) 。

代码如下:

#include<cstdio>
#include<cstring>
#include<algorithm>
#define mid ((l+r)>>1)
#define ls (x<<1)
#define rs ((x<<1)|1)
using namespace std;
typedef long long ll;
int const xn=6e4+5;//n<<1
int rk[xn],sa[xn],tax[xn],tp[xn],ht[xn][20],bin[20],bit[xn],op[xn];
ll sum[2][xn<<1],lzy[2][xn<<1];//n<<2
char s[xn];
int Min(int x,int y){return x<y?x:y;}
int Max(int x,int y){return x>y?x:y;}
void rsort(int n,int m)
{
  for(int i=1;i<=m;i++)tax[i]=0;
  for(int i=1;i<=n;i++)tax[rk[tp[i]]]++;
  for(int i=1;i<=m;i++)tax[i]+=tax[i-1];
  for(int i=n;i;i--)sa[tax[rk[tp[i]]]--]=tp[i];
}
void work(int n)
{
  int m=125; for(int i=1;i<=n;i++)rk[i]=s[i],tp[i]=i;
  rk[n+1]=0; tp[n+1]=0;//!
  rsort(n,m);
  for(int k=1;k<=n;k<<=1)
    {
      int num=0;
      for(int i=n-k+1;i<=n;i++)tp[++num]=i;
      for(int i=1;i<=n;i++)
    if(sa[i]>k)tp[++num]=sa[i]-k;
      rsort(n,m); swap(rk,tp);
      rk[sa[1]]=1; num=1; 
      for(int i=2;i<=n;i++)
    {
      int u=Min(n+1,sa[i]+k),v=Min(n+1,sa[i-1]+k);//
      rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[u]==tp[v])?num:++num;//tp[n+1]!
    }
      if(num==n)break;
      m=num;
    }
}
void geth(int n)
{
  int k=0; ht[1][0]=0;
  for(int i=1;i<=n;i++)
    {
      if(rk[i]==1)continue;
      if(k)k--; int j=sa[rk[i]-1];
      while(i+k<=n&&j+k<=n&&s[i+k]==s[j+k])k++;
      ht[rk[i]][0]=k;
    }
  for(int i=1;i<20;i++)
    for(int j=1;j+bin[i]-1<=n;j++)
      ht[j][i]=Min(ht[j][i-1],ht[j+bin[i-1]][i-1]);
}
int getlcp(int x,int y,int n)
{
  if(x==y)return n-x+1;
  x=rk[x]; y=rk[y];
  if(x>y)swap(x,y); x++;
  int r=bit[y-x+1];
  return Min(ht[x][r],ht[y-bin[r]+1][r]);
}
void psd(int t,int x,int l,int r)
{
  if(!lzy[t][x])return;
  int d=lzy[t][x]; lzy[t][x]=0;
  sum[t][ls]+=d*(mid-l+1); lzy[t][ls]+=d;
  sum[t][rs]+=d*(r-mid); lzy[t][rs]+=d;
}
void update(int t,int x,int l,int r,int L,int R)
{
  if(L>R)return;
  if(l>=L&&r<=R){sum[t][x]+=r-l+1; lzy[t][x]++; return;}
  psd(t,x,l,r);
  if(mid>=L)update(t,ls,l,mid,L,R);
  if(mid<R)update(t,rs,mid+1,r,L,R);
  sum[t][x]=sum[t][ls]+sum[t][rs];
}
ll query(int t,int x,int l,int r,int L,int R)
{
  if(L>R)return 0;
  if(l>=L&&r<=R)return sum[t][x];
  psd(t,x,l,r); ll ret=0;
  if(mid>=L)ret+=query(t,ls,l,mid,L,R);
  if(mid<R)ret+=query(t,rs,mid+1,r,L,R);
  return ret;
}
int main()
{
  int T; scanf("%d",&T);
  bin[0]=1; for(int i=1;i<20;i++)bin[i]=(bin[i-1]<<1);
  bit[1]=0; for(int i=2;i<=6e4;i++)bit[i]=bit[i>>1]+1;//6e4
  while(T--)
    {
      memset(sum,0,sizeof sum);
      memset(lzy,0,sizeof lzy);
      scanf("%s",s+1); int n=strlen(s+1);
      s[n+1]='a'-1;
      for(int i=n+2,k=n;k;i++,k--)s[i]=s[k],op[k]=i;
      work(n*2+1); geth(n*2+1); ll ans=0;
      for(int d=1;d<=n;d++)
      for(int i=1;i+d<=n;i+=d)
        {
          int j=i+d;
          int t2=Min(d,getlcp(i,j,n));
          int t1=Min(d,getlcp(op[i],op[j],n));
          if(t1+t2-1<d)continue;
          int l=i-t1+1,r=i+t2-d;//hd
          ans+=query(1,1,1,n,Max(1,l-1),r-1);//qtl
          ans+=query(0,1,1,n,l+2*d,Min(n,r+2*d));//qhd
          update(0,1,1,n,l,r); update(1,1,1,n,l+2*d-1,r+2*d-1);
        }
      printf("%lld\n",ans);
    }
  return 0;
}

 

posted @ 2019-01-28 14:09  Zinn  阅读(242)  评论(0编辑  收藏  举报