SAM杂题

现在是 SAM 什么都不会了的状态。所以打算搞字符串。差不多就是从头开始学 SAM,但是全是题。

我觉得得一半多都是代码。

P2408 不同子串个数

回忆一下 SAM 中每个节点的子串个数就是 \(len_x-len_{fa_x}\)。加起来就好。也可以是到达每个节点的路径数,可以在 DAG 上计数,最后减掉一个空串。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
int n;
long long ans;
char s[100010];
struct SAM{
    int cnt,last,trie[200010][26],fa[200010],len[200010];
    SAM(){cnt=1,last=1;}
    void ins(char ch){
        int p=last;last=++cnt;
        len[last]=len[p]+1;
        while(p&&!trie[p][ch])trie[p][ch]=last,p=fa[p];
        if(p==0){
            fa[last]=1;ans+=(len[last]-len[fa[last]]);return;
        }
        int q=trie[p][ch];
        if(len[p]+1==len[q]){
            fa[last]=q;ans+=(len[last]-len[fa[last]]);return;
        }
        len[++cnt]=len[p]+1;
        for(int i=0;i<26;i++)trie[cnt][i]=trie[q][i];
        fa[cnt]=fa[q];fa[q]=cnt;fa[last]=cnt;
        while(trie[p][ch]==q)trie[p][ch]=cnt,p=fa[p];
        ans+=(len[last]-len[fa[last]]);
    }
}sam;
int main(){
    scanf("%d%s",&n,s+1);
    for(int i=1;i<=n;i++)sam.ins(s[i]-'a');
    printf("%lld\n",ans);
}

P3804 【模板】后缀自动机 (SAM)

首先每个节点的 endpos 集合大小就是这个节点包含的所有子串的出现次数。然后回忆另一个结论:如果把每个前缀赋予 \(1\) 的权值,那么每个节点的 endpos 集合就是子树和。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
char s[1000010];
int n;
long long ans;
struct node{
    int v,next;
}edge[2000010];
int head[2000010],t;
void add(int u,int v){
    edge[++t].v=v;edge[t].next=head[u];head[u]=t;
}
int cnt,last,len[2000010],size[2000010],fa[2000010],trie[2000010][26];
void build(){
    cnt=last=1;
    int p;
    for(int i=1;i<=n;i++){
        p=last;last=++cnt;
        size[last]=1;len[last]=len[p]+1;
        while(p&&!trie[p][s[i]-'a'])trie[p][s[i]-'a']=cnt,p=fa[p];
        if(!p){
            fa[last]=1;continue;
        }
        int q=trie[p][s[i]-'a'];
        if(len[p]+1==len[q]){
            fa[last]=q;continue;
        }
        len[++cnt]=len[p]+1;
        for(int j=0;j<26;j++)trie[cnt][j]=trie[q][j];
        fa[cnt]=fa[q];fa[q]=cnt;fa[last]=cnt;
        for(int j=p;trie[j][s[i]-'a']==q;j=fa[j])trie[j][s[i]-'a']=cnt;
    }
}
void dfs(int x){
    for(int i=head[x];i;i=edge[i].next){
        dfs(edge[i].v);size[x]+=size[edge[i].v];
    }
    if(size[x]!=1)ans=max(ans,1ll*size[x]*len[x]);
}
int main(){
    scanf("%s",s+1);n=strlen(s+1);
    build();
    for(int i=2;i<=cnt;i++)add(fa[i],i);
    dfs(1);
    printf("%lld\n",ans);
    return 0;
}

[TJOI2015]弦论

首先我们可以类似线段树上二分的在 SAM 上跑出字符串。这样我们就需要找到每次走到一个节点,字典序比答案串小的串的个数。

在 SAM 上跑 dp,每个节点的答案是所有子节点之和。初值的问题,如果 \(t=0\),相同子串算一次,那么就仍然是 \(1\)。如果 \(t=1\),那么初值就是 endpos 集合大小。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
using namespace std;
char s[500010];
int n,k,t[1000010],size[1000010],a[1000010];
long long ans;
int cnt,last,trie[1000010][26],len[1000010],fa[1000010],sum[1000010];
void ins(char ch){
    int p=last;last=++cnt;
    len[last]=len[p]+1;size[last]=1;
    while(p&&!trie[p][ch])trie[p][ch]=cnt,p=fa[p];
    if(!p){
        fa[last]=1;return;
    }
    int q=trie[p][ch];
    if(len[p]+1==len[q]){
        fa[last]=q;return;
    }
    len[++cnt]=len[p]+1;
    for(int j=0;j<26;j++)trie[cnt][j]=trie[q][j];
    fa[cnt]=fa[q];fa[q]=cnt;fa[last]=cnt;
    while(trie[p][ch]==q)trie[p][ch]=cnt,p=fa[p];
}
void print(int x,int k){
    if(k<=size[x])return;
    k-=size[x];
    for(int i=0;i<26;i++){
        if(!trie[x][i])continue;
        if(k>sum[trie[x][i]]){
            k-=sum[trie[x][i]];continue;
        }
        putchar(i+'a');
        print(trie[x][i],k);
        return;
    }
}
int main(){
    int od;cnt=last=1;
    scanf("%s%d%d",s+1,&od,&k);n=strlen(s+1);
    for(int i=1;i<=n;i++)ins(s[i]-'a');
    for(int i=1;i<=cnt;i++)t[len[i]]++;
    for(int i=1;i<=cnt;i++)t[i]+=t[i-1];
    for(int i=1;i<=cnt;i++)a[t[len[i]]--]=i;
    for(int i=cnt;i>=1;i--)size[fa[a[i]]]+=size[a[i]];
    for(int i=2;i<=cnt;i++){
        if(od==0)size[i]=1;
        sum[i]=size[i];
    }
    size[1]=sum[1]=0;
    for(int i=cnt;i>=1;i--){
        for(int j=0;j<26;j++){
            if(trie[a[i]][j])sum[a[i]]+=sum[trie[a[i]][j]];
        }
    }
    cerr<<sum[1]<<endl;
    if(sum[1]<k){
        puts("-1");return 0;
    }
    print(1,k);
    return 0;
}

SP1812 LCS2 - Longest Common Substring II

先看看这道题的削弱版,只有两个串。我们可以建出一个的 SAM,让另一个在上面跑,处理出每个位置的最大匹配长度。答案就是它们的最大值。

这个题类似,可以对除最后一个串的所有串各建一个 SAM,然后让最后一个串分别在上面跑,把最大匹配长度取 \(\min\) 再统计答案就行了。

也可以广义 SAM,预处理出广义 SAM 的哪些节点是所有字符串共有的。这个可以建完 SAM 后把每个字符串扔到上面跑一遍,对于每个字符串分别求出 endpos 集合大小,若都不为 \(0\) 则是共有的。然后在共有的节点上随便找个字符串跑一下就行了。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
int n,m,ans,lst[100010];
char s[11][100010];
struct SAM{
    int cnt,last,trie[200010][26],fa[200010],len[200010],size[200010];
    SAM(){cnt=1,last=1;}
    void ins(char ch){
        int p=last;last=++cnt;
        len[last]=len[p]+1;size[last]=1;
        while(p&&!trie[p][ch])trie[p][ch]=last,p=fa[p];
        if(p==0){
            fa[last]=1;return;
        }
        int q=trie[p][ch];
        if(len[p]+1==len[q]){
            fa[last]=q;return;
        }
        len[++cnt]=len[p]+1;
        for(int i=0;i<26;i++)trie[cnt][i]=trie[q][i];
        fa[cnt]=fa[q];fa[q]=cnt;fa[last]=cnt;
        while(trie[p][ch]==q)trie[p][ch]=cnt,p=fa[p];
    }
    void query(char s[]){
        int p=1,l=0;
        for(int i=1;i<=n;i++){
            while(p!=1&&!trie[p][s[i]-'a'])p=fa[p],l=len[p];
            if(trie[p][s[i]-'a'])p=trie[p][s[i]-'a'],l++;
            lst[i]=min(lst[i],l);
        }
    }
}sam[11];
int main(){
    int pos=1;
    while(~scanf("%s",s[pos]+1))pos++;
    pos--;
    for(int i=1;i<pos;i++){
        int n=strlen(s[i]+1);
        for(int j=1;j<=n;j++)sam[i].ins(s[i][j]-'a');
    }
    n=strlen(s[pos]+1);
    for(int i=1;i<=n;i++)lst[i]=0x3f3f3f3f;
    for(int i=1;i<pos;i++)sam[i].query(s[pos]);
    for(int i=1;i<=n;i++)ans=max(ans,lst[i]);
    printf("%d\n",ans);
    return 0;
}

[BJOI2020] 封印

不错的题,就是作为 A 队加试有点水)

首先仍然建出 \(t\) 的 SAM,并让 \(s\) 在上面跑出最大匹配长度 \(ans\)。考虑怎么快速解决询问。

发现每个位置 \(i\) 的答案是 \(\min(i-l+1,ans_i)\),那么位置每次向右走一个, \(i-l+1\) 增加 \(1\) ,而 \(ans_i\) 至多增加 \(1\)。容易发现 \(ans_i-(i-l+1)\) 单调不增,因此可以二分零点。

找到零点 \(pos\) 之后左边的答案自然是 \(pos-l\),右边就是个区间 \(ans\) 的最大值,可以 ST 表。

二分的时候注意特判全段都小于 \(0\) 的情况,应当返回 \(r+1\)

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
int n,m,ans[200010];
char s[200010],t[200010];
int cnt=1,last=1,trie[400010][26],fa[400010],len[400010];
void ins(char ch){
    int p=last;last=++cnt;
    len[last]=len[p]+1;
    while(p&&!trie[p][ch])trie[p][ch]=last,p=fa[p];
    if(p==0){
        fa[last]=1;return;
    }
    int q=trie[p][ch];
    if(len[p]+1==len[q]){
        fa[last]=q;return;
    }
    len[++cnt]=len[p]+1;
    for(int i=0;i<26;i++)trie[cnt][i]=trie[q][i];
    fa[cnt]=fa[q];fa[q]=cnt;fa[last]=cnt;
    while(trie[p][ch]==q)trie[p][ch]=cnt,p=fa[p];
}
void qry(char s[]){
    int p=1,l=0;
    for(int i=1;i<=n;i++){
        while(p!=1&&!trie[p][s[i]-'a'])p=fa[p],l=len[p];
        if(trie[p][s[i]-'a'])p=trie[p][s[i]-'a'],l++;
        ans[i]=l;
    }
}
int st[200010][20];
int query(int l,int r){
    if(l>r)return 0;
    int k=__lg(r-l+1);
    return max(st[l][k],st[r-(1<<k)+1][k]);
}
int getans(int l,int r){
    int L=l,R=r;
    while(L<R){
        int mid=(L+R)>>1;
        if(ans[mid]<=mid-l+1)R=mid;
        else L=mid+1;
    }
    if(L-l+1<ans[L])L++;
    return L;
}
int main(){
    scanf("%s%s",s+1,t+1);
    n=strlen(s+1);m=strlen(t+1);
    for(int i=1;i<=m;i++)ins(t[i]-'a');
    qry(s);
    for(int i=1;i<=n;i++)st[i][0]=ans[i];
    for(int j=1;j<=__lg(n);j++){
        for(int i=1;i+(1<<j)-1<=n;i++){
            st[i][j]=max(st[i][j-1],st[i+(1<<(j-1))][j-1]);
        }
    }
    int q;scanf("%d",&q);
    while(q--){
        int l,r;scanf("%d%d",&l,&r);
        int pos=getans(l,r);
        printf("%d\n",max(query(pos,r),pos-l));
    }
    return 0;
}

[TJOI2019]甲苯先生和大中锋的字符串

把 endpos 找出来然后就是个区间加最后单点查,实际上差分就行。代码不上了。

CF235C Cyclical Quest

(WJMZBMR)

循环同构可以变成前面去掉一个字符后面加上一个字符。后面加上就是在 SAM 里跑转移边,前面去掉就是在 parent 树上跳父亲。相同的子串只能算一次,那么每个位置打个标记就行了。

写起来挺顺手的。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
int n,m,ans;
char s[1000010];
struct node{
    int v,next;
}edge[2000010];
int t,head[2000010];
void add(int u,int v){
    edge[++t].v=v;edge[t].next=head[u];head[u]=t;
}
int cnt=1,last=1,trie[2000010][26],fa[2000010],len[2000010],siz[2000010];
int v[2000010];
void ins(int ch){
    int p=last;last=++cnt;
    len[last]=len[p]+1;siz[last]++;
    while(p&&!trie[p][ch])trie[p][ch]=cnt,p=fa[p];
    if(!p){
        fa[last]=1;return;
    }
    int q=trie[p][ch];
    if(len[p]+1==len[q]){
        fa[last]=q;return;
    }
    len[++cnt]=len[p]+1;
    for(int i=0;i<26;i++)trie[cnt][i]=trie[q][i];
    fa[cnt]=fa[q];fa[q]=cnt;fa[last]=cnt;
    while(trie[p][ch]==q)trie[p][ch]=cnt,p=fa[p];
}
void dfs(int x){
    for(int i=head[x];i;i=edge[i].next){
        dfs(edge[i].v);siz[x]+=siz[edge[i].v];
    }
}
void query(char s[],int id){
    int p=1,l=0;
    for(int i=1;i<=n;i++){
        while(p!=1&&!trie[p][s[i]-'a'])p=fa[p],l=len[p];
        if(trie[p][s[i]-'a'])p=trie[p][s[i]-'a'],l++;
    }
    for(int i=1;i<=n;i++){
        if(l==n){
            if(v[p]!=id)ans+=siz[p];
            v[p]=id;l--;
            if(l==len[fa[p]])p=fa[p];
        }
        while(p!=1&&!trie[p][s[i]-'a'])p=fa[p],l=len[p];
        if(trie[p][s[i]-'a'])p=trie[p][s[i]-'a'],l++;
    }
}
int main(){
    scanf("%s",s+1);n=strlen(s+1);
    for(int i=1;i<=n;i++)ins(s[i]-'a');
    for(int i=2;i<=cnt;i++)add(fa[i],i);
    dfs(1);
    scanf("%d",&m);
    while(m--){
        scanf("%s",s+1);n=strlen(s+1);ans=0;
        query(s,m+1);
        printf("%d\n",ans);
    }
}

[HEOI2016/TJOI2016]字符串

之前题库暴力艹过去了,现在重新写一遍。用 SAM。

二分答案,于是变成 \(s[c,c+mid-1]\)\(s[a,b]\) 是否出现过。从 \(s[1,c+mid-1]\) 倍增跳到 \(s[c,c+mid-1]\),然后线段树合并维护每个节点的 endpos 即可。复杂度好像是 \(O(n\log^2n)\)

另外注意线段树合并的时候一定要新开节点。别的其实不是很难写。

一个类似的题是 CF666E,对 \(S\) 建 SAM 然后把所有 \(T\) 扔上去跑并标记出现次数,线段树合并就行了。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#define lson tree[rt].ls
#define rson tree[rt].rs
using namespace std;
struct node{
    int v,next;
}edge[200010];
int tot,head[200010];
void add(int u,int v){
    edge[++tot].v=v;edge[tot].next=head[u];head[u]=tot;
}
int n,m,t,rt[200010];
char s[100010];
struct tr{
    int ls,rs,sum;
}tree[200010<<5];
void pushup(int rt){
    tree[rt].sum=tree[lson].sum+tree[rson].sum;
}
void update(int &rt,int L,int R,int pos){
    if(!rt)rt=++t;
    if(L==R){
        tree[rt].sum++;return;
    }
    int mid=(L+R)>>1;
    if(pos<=mid)update(lson,L,mid,pos);
    else update(rson,mid+1,R,pos);
    pushup(rt);
}
int query(int rt,int L,int R,int l,int r){
    if(l<=L&&R<=r)return tree[rt].sum;
    int mid=(L+R)>>1,val=0;
    if(l<=mid)val+=query(lson,L,mid,l,r);
    if(mid<r)val+=query(rson,mid+1,R,l,r);
    return val;
}
int merge(int x,int y,int l,int r){
    if(!x||!y)return x|y;
    int rt=++t;
    if(l==r){
        tree[rt].sum=tree[x].sum+tree[y].sum;
        return rt;
    }
    int mid=(l+r)>>1;
    tree[rt].ls=merge(tree[x].ls,tree[y].ls,l,mid);
    tree[rt].rs=merge(tree[x].rs,tree[y].rs,mid+1,r);
    pushup(rt);
    return rt;
}
struct SAM{
    int cnt,last,trie[200010][26],fa[200010],len[200010];
    SAM(){cnt=last=1;}
    void ins(int ch){
        int p=last;last=++cnt;
        len[last]=len[p]+1;
        while(p&&!trie[p][ch])trie[p][ch]=cnt,p=fa[p];
        if(!p){
            fa[last]=1;return;
        }
        int q=trie[p][ch];
        if(len[p]+1==len[q]){
            fa[last]=q;return;
        }
        len[++cnt]=len[p]+1;
        for(int i=0;i<26;i++)trie[cnt][i]=trie[q][i];
        fa[cnt]=fa[q];fa[q]=cnt;fa[last]=cnt;
        while(trie[p][ch]==q)trie[p][ch]=cnt,p=fa[p];
    }
}sam;
int fa[200010][20],id[100010];
void dfs(int x){
    for(int i=head[x];i;i=edge[i].next){
        fa[edge[i].v][0]=x;
        for(int j=1;j<=__lg(sam.cnt);j++)fa[edge[i].v][j]=fa[fa[edge[i].v][j-1]][j-1];
        dfs(edge[i].v);
        rt[x]=merge(rt[x],rt[edge[i].v],1,n);
    }
}
bool check(int mid,int a,int b,int c,int d){
    int x=id[c+mid-1];
    for(int i=__lg(sam.cnt);i>=0;i--){
        if(sam.len[fa[x][i]]>=mid)x=fa[x][i];
    }
    return query(rt[x],1,n,a+mid-1,b);
}
int main(){
    scanf("%d%d%s",&n,&m,s+1);
    id[0]=1;
    for(int i=1;i<=n;i++){
        sam.ins(s[i]-'a');
        id[i]=sam.last;
        update(rt[id[i]],1,n,i);
    }
    for(int i=2;i<=sam.cnt;i++)add(sam.fa[i],i);
    dfs(1);
    while(m--){
        int a,b,c,d;scanf("%d%d%d%d",&a,&b,&c,&d);
        int l=0,r=min(b-a+1,d-c+1);
        while(l<r){
            int mid=(l+r+1)>>1;
            if(check(mid,a,b,c,d))l=mid;
            else r=mid-1;
        }
        printf("%d\n",l);
    }
    return 0;
}

[CTSC2012]熟悉的文章

首先显然可以二分答案。考虑怎么验证。

一个 dp:设 \(dp_i\) 是到 \(i\) 能匹配的最大长度,\(l_i\) 为到 \(i\) 在广义 SAM 上能匹配到的长度。那么有:

\[dp_i=\max(dp_{i-1},dp_j+i-j,j\in[i-l_i,i-mid]) \]

发现 \(i-mid\) 单增,\(i-l_i\) 单调不降,单调队列。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
char s[1100010];
int n,m,len,dp[1100010],match[1100010];
struct Sam{
    int cnt,fa[2200010],len[2200010],trie[2200010][2];
    Sam(){cnt=1;}
    int ins(int ch,int last){
        if(trie[last][ch]){
            int p=last,q=trie[p][ch];
            if(len[p]+1==len[q])return q;
            else{
                len[++cnt]=len[p]+1;
                for(int i=0;i<2;i++)trie[cnt][i]=trie[q][i];
                fa[cnt]=fa[q];fa[q]=cnt;
                while(trie[p][ch]==q)trie[p][ch]=cnt,p=fa[p];
                return cnt;
            }
        }
        int p=last;last=++cnt;
        len[last]=len[p]+1;
        while(p&&!trie[p][ch])trie[p][ch]=cnt,p=fa[p];
        if(!p){
            fa[last]=1;return last;
        }
        int q=trie[p][ch];
        if(len[p]+1==len[q]){
            fa[last]=q;return last;
        }
        len[++cnt]=len[p]+1;
        for(int i=0;i<2;i++)trie[cnt][i]=trie[q][i];
        fa[cnt]=fa[q];fa[q]=cnt;fa[last]=cnt;
        while(trie[p][ch]==q)trie[p][ch]=cnt,p=fa[p];
        return last;
    }
    void getmatch(char s[]){
        int slen=strlen(s+1),p=1,l=0;
        for(int i=1;i<=slen;i++){
            while(p&&!trie[p][s[i]-'0'])p=fa[p],l=len[p];
            if(p)p=trie[p][s[i]-'0'],l++;
            else p=1,l=0;
            match[i]=l;
        }
    }
}SAM;
int q[1100010],L,R,last;
bool check(int mid){
    L=1,R=0;
    for(int i=1;i<mid;i++)dp[i]=0;
    for(int i=mid;i<=len;i++){
        dp[i]=dp[i-1];
        while(L<=R&&dp[q[R]]-q[R]<dp[i-mid]-(i-mid))R--;
        q[++R]=i-mid;
        while(L<=R&&q[L]<i-match[i])L++;
        if(L<=R)dp[i]=max(dp[i],dp[q[L]]-q[L]+i);
    }
    return dp[len]*10>=len*9;
}
signed main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=m;i++){
        scanf("%s",s+1);
        last=1;len=strlen(s+1);
        for(int j=1;j<=len;j++)last=SAM.ins(s[j]-'0',last);
    }
    while(n--){
        scanf("%s",s+1);
        int l=1,r=len=strlen(s+1);
        SAM.getmatch(s);
        while(l<r){
            int mid=(l+r+1)>>1;
            if(check(mid))l=mid;
            else r=mid-1;
        }
        printf("%d\n",l);
    }
    return 0;
}

[NOI2018] 你的名字

好强。

首先题意翻译成人话是找到 \(T\) 的所有本质不同子串中不是 \(S[l,r]\) 子串的个数。

先考虑一下不带 \(l,r\) 的情况。考虑 \(T\) 的 SAM 上每个节点的贡献。对于这个节点的一个 endpos \(i\),那么假设 \(T\)\(x\) 位置能和 \(S\) 匹配的最长长度是 \(match_x\),那么这个节点所有大于 \(match_i\) 的子串都不能与 \(S\) 匹配。为了方便直接找 firstpos 即可。于是一个节点的贡献就 \(\max(0,len_x-\max(len_{fa_x},match_{firstpos_x}))\)

然后带上 \(l,r\) 的限制,那么只要找这一段最前面的 \(endpos\) ,求 \(match\) 的时候只找 endpos 在范围内的即可。

注意求 \(match\) 的时候不能直接跳 \(fa\),因为匹配长度减小之后 endpos 的合法范围增大,还要继续判断,一直到当前的长度不在该节点的 \(len\) 范围内再往上跳。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#define lson tree[rt].ls
#define rson tree[rt].rs
using namespace std;
struct node{
    int v,next;
}edge[2000010];
int tot,head[2000010];
void add(int u,int v){
    edge[++tot].v=v;edge[tot].next=head[u];head[u]=tot;
}
int n,m,t,rt[2000010];
char s[1000010];
struct tr{
    int ls,rs,sum;
}tree[2000010<<5];
void pushup(int rt){
    tree[rt].sum=tree[lson].sum+tree[rson].sum;
}
void update(int &rt,int L,int R,int pos){
    if(!rt)rt=++t;
    if(L==R){
        tree[rt].sum++;return;
    }
    int mid=(L+R)>>1;
    if(pos<=mid)update(lson,L,mid,pos);
    else update(rson,mid+1,R,pos);
    pushup(rt);
}
int query(int rt,int L,int R,int l,int r){
    if(!rt)return 0;
    if(l<=L&&R<=r)return tree[rt].sum;
    int mid=(L+R)>>1,val=0;
    if(l<=mid)val+=query(lson,L,mid,l,r);
    if(mid<r)val+=query(rson,mid+1,R,l,r);
    return val;
}
int merge(int x,int y,int l,int r){
    if(!x||!y)return x|y;
    int rt=++t;
    if(l==r){
        tree[rt].sum=tree[x].sum+tree[y].sum;
        return rt;
    }
    int mid=(l+r)>>1;
    tree[rt].ls=merge(tree[x].ls,tree[y].ls,l,mid);
    tree[rt].rs=merge(tree[x].rs,tree[y].rs,mid+1,r);
    pushup(rt);
    return rt;
}
int match[200010];
struct SAM{
    int cnt,last,trie[2000010][26],fa[2000010],len[2000010],firstpos[2000010];
    SAM(){cnt=last=1;}
    void ins(int ch,int id){
        int p=last;last=++cnt;
        len[last]=len[p]+1;firstpos[last]=id;
        while(p&&!trie[p][ch])trie[p][ch]=cnt,p=fa[p];
        if(!p){
            fa[last]=1;return;
        }
        int q=trie[p][ch];
        if(len[p]+1==len[q]){
            fa[last]=q;return;
        }
        len[++cnt]=len[p]+1;firstpos[cnt]=firstpos[q];
        for(int i=0;i<26;i++)trie[cnt][i]=trie[q][i];
        fa[cnt]=fa[q];fa[q]=cnt;fa[last]=cnt;
        while(trie[p][ch]==q)trie[p][ch]=cnt,p=fa[p];
    }
}sam,sam2;
void dfs(int x){
    for(int i=head[x];i;i=edge[i].next){
        dfs(edge[i].v);rt[x]=merge(rt[x],rt[edge[i].v],1,n);
    }
}
void get(int l,int r){
    int p=1,L=0;
    for(int i=1;i<=m;i++){
        while(1){
            if(sam.trie[p][s[i]-'a']&&query(rt[sam.trie[p][s[i]-'a']],1,n,l+L,r)){
                p=sam.trie[p][s[i]-'a'];L++;
                break;
            }
            if(!L)break;
            L--;
            if(L==sam.len[sam.fa[p]])p=sam.fa[p];
        }
        match[i]=L;
    }
}
int main(){
    scanf("%s",s+1);n=strlen(s+1);
    for(int i=1;i<=n;i++){
        sam.ins(s[i]-'a',i);
        update(rt[sam.last],1,n,i);
    }
    for(int i=2;i<=sam.cnt;i++)add(sam.fa[i],i);
    dfs(1);
    int q;scanf("%d",&q);
    while(q--){
        int l,r;scanf("%s%d%d",s+1,&l,&r);
        m=strlen(s+1);
        long long ans=0;
        for(int i=1;i<=m;i++)sam2.ins(s[i]-'a',i);
        get(l,r);
        for(int i=2;i<=sam2.cnt;i++)ans+=max(0,sam2.len[i]-max(match[sam2.firstpos[i]],sam2.len[sam2.fa[i]]));
        printf("%lld\n",ans);
        for(int i=1;i<=sam2.cnt;i++){
            sam2.len[i]=sam2.fa[i]=sam2.firstpos[i]=0;
            for(int j=0;j<26;j++)sam2.trie[i][j]=0;
        }
        sam2.cnt=sam2.last=1;
    }
    return 0;
}

先就这样发出来。

posted @ 2023-02-01 19:13  gtm1514  阅读(21)  评论(0编辑  收藏  举报