bzoj3473: 字符串 && bzoj3277串

3473: 字符串

Time Limit: 20 Sec  Memory Limit: 256 MB
Submit: 121  Solved: 53
[Submit][Status][Discuss]

Description

给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?

Input

第一行两个整数n,k。
接下来n行每行一个字符串。

Output

一行n个整数,第i个整数表示第i个字符串的答案。

Sample Input

3 1
abc
a
ab

Sample Output

6 1 3

HINT



对于 100% 的数据,1<=n,k<=10^5,所有字符串总长不超过10^5,字符串只包含小写字母。

Source

Adera 1 杯冬令营模拟赛

 

很久之前做的题今天一看竟然不会做了。。。于是补篇题解。

首先把所有串连起来做一遍SA,求出hight,然后在后缀数组上从前往后扫。

那么现在要求的就是当前这个后缀有多少前缀是至少k个串的子串,这些前缀一定是连续的一段,因为如果Sx出现了k次,那么S也一定出现了k次。

设当前位是i,我们现在拥有后缀数组上一位的答案lastans,那么把它与hight[i]取一个min得到x,那么这位的答案至少是x。

然后考虑这位新出现的子串,那些包含这些子串的位置一定在i下面,那么维护一个指针使当前区间内刚好包含k个不同串的任意一个后缀,当i++时指针往后扫。

那么指针的位置与i用ST表求个区间RMQ,用x与这个区间最小值取max就是当前位的答案。

复杂度$nlogn$

感觉后缀自动机的做法很不科学

 

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 300005
#define ll long long
using namespace std;
int n,k;
int sa[N],rank[N],wb[N],sum[N],c[N];
void getsa(int n,int m)
{
    int *x=rank,*y=wb;
    for(int i=0;i<m;i++)sum[i]=0;
    for(int i=0;i<n;i++)sum[x[i]=c[i]]++;
    for(int i=1;i<m;i++)sum[i]+=sum[i-1];
    for(int i=n-1;i>=0;i--)sa[--sum[x[i]]]=i;
    int p=1;
    for(int j=1;p<n;j<<=1,m=p)
    {
        p=0;
        for(int i=n-j;i<n;i++)y[p++]=i;
        for(int i=0;i<n;i++)if(sa[i]>=j)y[p++]=sa[i]-j;   
        for(int i=0;i<m;i++)sum[i]=0;
        for(int i=0;i<n;i++)sum[x[i]]++;
        for(int i=1;i<m;i++)sum[i]+=sum[i-1];
        for(int i=n-1;i>=0;i--)sa[--sum[x[y[i]]]]=y[i];
        swap(x,y);x[sa[0]]=0;p=1;
        for(int i=1;i<n;i++)
          x[sa[i]]=y[sa[i]]==y[sa[i-1]]&&y[sa[i]+j]==y[sa[i-1]+j]?p-1:p++;
    }
}
int h[N];
void calh(int n)
{
    for(int i=1;i<=n;i++)rank[sa[i]]=i;
    int kk=0;
    for(int i=0;i<n;i++)
    {
        if(kk)kk--;
        int j=sa[rank[i]-1];
        while(c[i+kk]==c[j+kk])kk++;
        h[rank[i]]=kk;
    }
    return ;
}
int mn[N][20],lg[N];
void ST()
{
    lg[0]=-1;
    for(int i=1;i<=n;i++)lg[i]=lg[i>>1]+1;
    for(int i=1;i<=n;i++)mn[i][0]=h[i];
    for(int i=1;i<=19;i++)
    {
        for(int j=1;j<=n;j++)
        {
            if(j+(1<<(i-1))<=n)mn[j][i]=min(mn[j][i-1],mn[j+(1<<(i-1))][i-1]);
            else mn[j][i]=mn[j][i-1];
        }
    }return ;
}
int qur(int l,int r)
{
    int k=lg[r-l+1];
    return min(mn[l][k],mn[r-(1<<k)+1][k]);
}
int be[N];
ll ans[N];
int len[N],sz[N];
int now[N],nw;
void solve()
{
    int l=0;int tmp=0;
    for(int i=1;i<=n;i++)
    {
        if(i!=1&&be[sa[i-1]]!=0)
        {
            now[be[sa[i-1]]]--;
            if(!now[be[sa[i-1]]])nw--;
        }
        while(l!=n&&nw<k)
        {
            l++;
            if(be[sa[l]]!=0)
            {
                now[be[sa[l]]]++;
                if(now[be[sa[l]]]==1)nw++;
            }
        }
        tmp=min(tmp,h[i]);
        if(nw==k)
        {
            if(be[sa[i]])
            {
                int num;
                if(l!=i)num=qur(i+1,l);
                else num=sz[sa[i]];
                tmp=max(tmp,num);
            }
        }
        ans[be[sa[i]]]+=tmp;
    }
    return ;
}
char s[N];
int main()
{
    int cnt;
    scanf("%d%d",&cnt,&k);
    int m=256;n=-1;
    for(int i=1;i<=cnt;i++)
    {
        scanf("%s",s+1);len[i]=strlen(s+1);
        for(int j=1;j<=len[i];j++)
        {
            c[++n]=s[j];
            be[n]=i;
            sz[n]=len[i]-j+1;
        }
        if(i!=cnt)c[++n]=m++;
    }n++;
    getsa(n+1,m);calh(n);
    ST();
    solve();
    for(int i=1;i<cnt;i++)printf("%lld ",ans[i]);
    printf("%lld",ans[cnt]);
    return 0;
}

 

  

 

  

posted @ 2017-04-19 21:44  SD_le  阅读(265)  评论(0编辑  收藏  举报
重置按钮