Luogu4770 [NOI2018]你的名字

Luogu4770 [NOI2018]你的名字

\(SAM+LCT\)

\(update2020.11.16:\)更新了一只\(\log\)解法。

好歹自己切了一道字符串黑题,这几天字符串没白颓。

观察原问题,如果模式串不是区间形式的话,很容易想到一个做法,就是对于输入串每一个\(r\)位置,除去\(r\)所在的后缀中不满足条件的后缀,显然不满足条件的后缀一定是一段区间,所以从\(r\)后缀在\(SAM\)\(parent\)树上的匹配的最深位置\(x\)\(root\)路径中所有的后缀都需要砍掉。

然而模式串是一段区间,可以发现,不满足条件的串仍然是\(x \rightarrow root\)路径上一个位置\(y\)\(root\)的路径,因为新的不满足条件的后缀一定是原问题的子集。

那么我们考虑找到那个节点,也就是说,该节点代表的集合中存在串\(s\),使得它在\(r\)之前最后出现的位置的左端点\(L \ge l\)

参考Luogu6292 区间本质不同子串个数的做法,将\(r\)指针一位一位扫过去,同时用\(LCT\)更新最右的右端点。

同时,在\(parent\)树上,节点的\(l\)端点在\(x \rightarrow root\)上一定单调递增,因为对于最右\(r\)端点,显然祖先一定不小于子孙,对于长度,祖先又比子孙小。

既然具有了单调性,我们就可以考虑倍增了,向\(L<l\)的最浅节点跳。在细节上,注意我们最终跳到的节点可能会存在一部分子串满足题意,这些必须统计。

还有一个问题需要处理,就是对于输入的串,即使除去了不满足题意的串,其自身的串依然会存在重复情况,对此,我们对输入串仍需要建立\(SAM\)

那么我们在\(parent\)树上打标记,我们已经找出了一个后缀长度在\([1,t]\)范围内是不合法的,这在新的\(SAM\)上仍然对应一个节点到根的路径,再次倍增。

最后一次\(DFS\)统计答案即可。

时间复杂度:\(O(\sum \lvert T_i \rvert \log^2 n)\)

\(Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<vector>
#define pr pair<int,int>
#define mp make_pair
#define ll long long
#define IT vector< pr > :: iterator
#define N 500005
#define qN 100005
using namespace std;
int l,r;
int lth[qN],g[N << 1];
ll ans[qN];
char s[N];
string T[qN];
vector< pr >e[N];
int n,q;
struct edge
{
    int nxt,v;
    edge (int Nxt=0,int V=0)
    {
        nxt=Nxt,v=V;
    }
}E[N << 1];
int tot,fr[N << 1];
void add(int x,int y)
{
    ++tot;
    E[tot]=edge(fr[x],y),fr[x]=tot;
}
struct SAM
{
    int lst=1,cnt=1,tr[N << 1][26],pre[N << 1],len[N << 1];
    int f[N << 1][22];
    void ins(int c)
    {
        int p=lst,q,np;
        lst=np=++cnt;
        len[np]=len[p]+1;
        for (;p && !tr[p][c];p=pre[p])
            tr[p][c]=np;
        if (!p) 
            pre[np]=1; else
            {
                q=tr[p][c];
                if (len[p]+1==len[q])
                    pre[np]=q; else
                    {
                        int g=++cnt;
                        memcpy(tr[g],tr[q],sizeof(tr[q]));
                        len[g]=len[p]+1,pre[g]=pre[q];
                        for (;p && tr[p][c]==q;p=pre[p])
                            tr[p][c]=g;
                        pre[np]=pre[q]=g;
                    }
            }
    }
    void Do_st()
    {
        for (int i=1;i<=cnt;++i)
            f[i][0]=pre[i];
        for (int j=1;j<=20;++j)
            for (int i=1;i<=cnt;++i)
                f[i][j]=f[f[i][j-1]][j-1];
    }
    void build()
    {
        for (int i=2;i<=cnt;++i)
            add(pre[i],i);
    }
    void Clear()
    {
        memset(tr,0,26*(cnt+1)*sizeof(int));
        memset(g,0,(cnt+1)*sizeof(int));
        memset(fr,0,(cnt+1)*sizeof(int));
        tot=0,cnt=lst=1;
    }
}S1,S2;
#define ls(x) a[x].ch[0]
#define rs(x) a[x].ch[1]
#define fa(x) a[x].f
#define tag(x) a[x].cltg
#define col(x) a[x].cl
struct LCT
{
    int ch[2],f,cltg,cl;
}a[N << 1];
int Q[N << 1];
int id(int x)
{
    return ls(fa(x))==x?0:1;
}
bool isrt(int x)
{
    return ls(fa(x))!=x && rs(fa(x))!=x;
}
void connect(int x,int F,int son)
{
    fa(x)=F;                           
    a[F].ch[son]=x;
}
void rot(int x)
{
    int y=fa(x),r=fa(y);
    int yson=id(x),rson=id(y);
    if (isrt(y))
        fa(x)=r; else
        connect(x,r,rson);
    connect(a[x].ch[yson^1],y,yson);
    connect(y,x,yson^1);
}
void push_tag(int x,int z)
{
    if (!x)
        return;
    tag(x)=col(x)=z;
}
void push_down(int x)
{
    if (tag(x))
    {
        push_tag(ls(x),tag(x));
        push_tag(rs(x),tag(x));
        tag(x)=0;
    }
}
void splay(int x)
{
    int g=x,k=0;
    Q[++k]=x;
    while (!isrt(g))
        g=fa(g),Q[++k]=g;
    while (k)
        push_down(Q[k--]);
    while (!isrt(x))
    {
        int y=fa(x);
        if (isrt(y))
            rot(x); else
        if (id(x)==id(y))
            rot(y),rot(x); else
            rot(x),rot(x);
    }
}
void access(int x,int r)
{
    int y;
    for (y=0;x;y=x,x=fa(x))
    {
        splay(x);
        rs(x)=y;
    }
    push_tag(y,r);
}
int Col(int x)
{
    splay(x);
    return col(x);
}
void Dfs(int u,int w)
{
    for (int i=fr[u];i;i=E[i].nxt)
    {
        int v=E[i].v;
        Dfs(v,w);
        g[u]=max(g[u],g[v]);
    }
    ans[w]+=S2.len[u]-max(S2.len[S2.pre[u]],min(g[u],S2.len[u]));
}
int main()
{
    scanf("%s",s+1);
    n=strlen(s+1);
    for (int i=1;i<=n;++i)
        S1.ins(s[i]-'a');
    S1.Do_st();
    for (int i=2;i<=S1.cnt;++i)
        fa(i)=S1.pre[i];
    scanf("%d",&q);
    for (int i=1;i<=q;++i)
    {
        cin >> T[i];
        lth[i]=T[i].length();
        scanf("%d%d",&l,&r);
        e[r].push_back(mp(l,i));
    }
    int st=1;
    for (int i=1;i<=n;++i)
    {
        st=S1.tr[st][s[i]-'a'];
        access(st,i);
        for (IT it=e[i].begin();it!=e[i].end();++it)
        {
            S2.Clear();
            int l=it->first,w=it->second,s0=1,st0=1,nlen=0;
            for (int j=0;j<lth[w];++j)
                S2.ins(T[w][j]-'a');
            S2.Do_st();
            for (int j=0;j<lth[w];++j)
            {
                int c=T[w][j]-'a';
                st0=S2.tr[st0][c];
                if (!S1.tr[1][c])
                    s0=1,nlen=0; else
                    {
                        while (!S1.tr[s0][c])
                            s0=S1.pre[s0],nlen=S1.len[s0];
                        s0=S1.tr[s0][c];
                        ++nlen;
                        int tans;
                        if (Col(s0)-nlen+1>=l)
                            tans=nlen; else
                            {
                                int F=s0;
                                for (int j=20;j>=0;--j)
                                    if (S1.f[F][j] && Col(S1.f[F][j])-S1.len[S1.f[F][j]]+1<l)
                                        F=S1.f[F][j];
                                tans=max(Col(F)-l+1,S1.len[S1.f[F][0]]);
                            }
                        int k=st0;
                        for (int j=20;j>=0;--j)
                            if (S2.f[k][j] && S2.len[S2.f[k][j]]>=tans)
                                k=S2.f[k][j];
                        g[k]=max(g[k],tans);
                    }
            }
            S2.build();
            Dfs(1,w);
        }
    }
    for (int i=1;i<=q;++i)
        printf("%lld\n",ans[i]);
    return 0;
}

根据线段树合并解法的启发,发现倍增根本不需要,我们在需要跳祖先的地方跳祖先即可,这样可以省去一只\(\log\)

时间复杂度:\(O(\sum \lvert T_i \rvert \log n)\)

\(Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<vector>
#define pr pair<int,int>
#define mp make_pair
#define ll long long
#define IT vector< pr > :: iterator
#define N 500005
#define qN 100005
using namespace std;
int l,r;
int lth[qN],g[N];
ll ans[qN];
char s[N];
string T[qN];
vector< pr >e[N];
int n,q;
struct SAM
{
    int lst=1,cnt=1,tr[N << 1][26],pre[N << 1],len[N << 1],lc[N << 1];
    void ins(int c)
    {
        int p=lst,q,np;
        lst=np=++cnt;
        len[np]=len[p]+1,lc[np]=len[p];
        for (;p && !tr[p][c];p=pre[p])
            tr[p][c]=np;
        if (!p) 
            pre[np]=1; else
            {
                q=tr[p][c];
                if (len[p]+1==len[q])
                    pre[np]=q; else
                    {
                        int g=++cnt;
                        memcpy(tr[g],tr[q],sizeof(tr[q]));
                        len[g]=len[p]+1,pre[g]=pre[q],lc[g]=lc[q];
                        for (;p && tr[p][c]==q;p=pre[p])
                            tr[p][c]=g;
                        pre[np]=pre[q]=g;
                    }
            }
    }
    void Clear()
    {
        memset(tr,0,26*(cnt+1)*sizeof(int));
        cnt=lst=1;
    }
}S1,S2;
#define ls(x) a[x].ch[0]
#define rs(x) a[x].ch[1]
#define fa(x) a[x].f
#define tag(x) a[x].cltg
#define col(x) a[x].cl
struct LCT
{
    int ch[2],f,cltg,cl;
}a[N << 1];
int Q[N << 1];
int id(int x)
{
    return ls(fa(x))==x?0:1;
}
bool isrt(int x)
{
    return ls(fa(x))!=x && rs(fa(x))!=x;
}
void connect(int x,int F,int son)
{
    fa(x)=F;                           
    a[F].ch[son]=x;
}
void rot(int x)
{
    int y=fa(x),r=fa(y);
    int yson=id(x),rson=id(y);
    if (isrt(y))
        fa(x)=r; else
        connect(x,r,rson);
    connect(a[x].ch[yson^1],y,yson);
    connect(y,x,yson^1);
}
void push_tag(int x,int z)
{
    if (!x)
        return;
    tag(x)=col(x)=z;
}
void push_down(int x)
{
    if (tag(x))
    {
        push_tag(ls(x),tag(x));
        push_tag(rs(x),tag(x));
        tag(x)=0;
    }
}
void splay(int x)
{
    int g=x,k=0;
    Q[++k]=x;
    while (!isrt(g))
        g=fa(g),Q[++k]=g;
    while (k)
        push_down(Q[k--]);
    while (!isrt(x))
    {
        int y=fa(x);
        if (isrt(y))
            rot(x); else
        if (id(x)==id(y))
            rot(y),rot(x); else
            rot(x),rot(x);
    }
}
void access(int x,int r)
{
    int y;
    for (y=0;x;y=x,x=fa(x))
    {
        splay(x);
        rs(x)=y;
    }
    push_tag(y,r);
}
int Col(int x)
{
    splay(x);
    return col(x);
}
int main()
{
    scanf("%s",s+1);
    n=strlen(s+1);
    for (int i=1;i<=n;++i)
        S1.ins(s[i]-'a');
    for (int i=2;i<=S1.cnt;++i)
        fa(i)=S1.pre[i];
    scanf("%d",&q);
    for (int i=1;i<=q;++i)
    {
        cin >> T[i];
        lth[i]=T[i].length();
        scanf("%d%d",&l,&r);
        e[r].push_back(mp(l,i));
    }
    int st=1;
    for (int i=1;i<=n;++i)
    {
        st=S1.tr[st][s[i]-'a'];
        access(st,i);
        for (IT it=e[i].begin();it!=e[i].end();++it)
        {
            S2.Clear();
            int l=it->first,w=it->second,s0=1,st0=1,nlen=0;
            for (int j=0;j<lth[w];++j)
                S2.ins(T[w][j]-'a');
            for (int j=0;j<lth[w];++j)
            {
                int c=T[w][j]-'a';
                st0=S2.tr[st0][c];
                if (!S1.tr[1][c] || Col(S1.tr[1][c])<l)
                    s0=1,nlen=0; else
                    {
                        while (!S1.tr[s0][c])
                            s0=S1.pre[s0],nlen=S1.len[s0];
                        s0=S1.tr[s0][c];
                        ++nlen;
                        if (Col(s0)-nlen+1<l)
                        {
                            while (Col(s0)-S1.len[S1.pre[s0]]<l)
                                s0=S1.pre[s0];
                            nlen=min(S1.len[s0],Col(s0)-l+1);
                        }
                    }
                g[j]=nlen;
            }
            for (int j=2;j<=S2.cnt;++j)
                ans[w]+=max(S2.len[j]-max(g[S2.lc[j]],S2.len[S2.pre[j]]),0);
        }
    }
    for (int i=1;i<=q;++i)
        printf("%lld\n",ans[i]);
    return 0;
}
posted @ 2020-11-14 12:58  GK0328  阅读(51)  评论(0编辑  收藏  举报