Luogu4770 NOI2018 你的名字 SAM、主席树

传送门


UPD:发现之前被smy误导的一个细节,改过来之后就AC了……

一道比较套路的SAM题,虽然我连套路都不会……

先考虑前\(68pts\),也就是\(l=1 , r=|S|\)的情况。我们对\(S\)建好SAM,把\(T\)扔到\(S\)的SAM上匹配,如果不考虑本质不同子串的性质,那么答案就是\(\sum\limits_{i=1}^{|T|} i - l_i\),其中\(l_i\)是匹配到第\(i\)个字符时的匹配长度。

然后考虑如何去重。对\(T\)也建SAM,把\(T\)也放在\(T\)的SAM上匹配。发现在匹配到第\(i\)个字符时,以\(i\)为右端点、长度为\([1,l_i]\)的串都是不合法的,而这些不合法的串在\(T\)所在的SAM上对应的是\(parent\)树上的一条链。于是在放在\(T\)的SAM上匹配的时候不断跳父亲,直到\(Shortest_u \leq l_i\),然后在\(u\)上打上\(l_i\)的标记。最后在\(parent\)树上递推一遍把标记传一下就可以计算出不合法的串的数量。

然后考虑一般情况。唯一存在的问题是:可能存在某些情况下到达的状态在\(S[l,r]\)中没有出现。这个时候考虑对于\(S\)上的每一个节点维护它的\(endpos\)集合,此时\(S[l,r]\)中不存在当前对应的串\(\Leftrightarrow\)当前状态的\(endpos\)集合与\([l+len-1,r]\)无交,然后\(--len\),如果\(len =\)父亲状态的\(Longest\)就跳父亲。这里暴力修改\(len\)的总次数仍然是\(O(|T|)\)的所以可以接受。

对于\(endpos\)集合的维护,不难知道某一个点有的\(endpos\)它的祖先也会有。所以就变成了一个子树内某区间是否有值的问题,在\(parent\)树上跑出dfn序然后建立主席树即可。总复杂度\(O(\sum |T| log |S|)\)

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<queue>
#include<vector>
//This code is written by Itst
using namespace std;

const int MAXN = 1e6 + 7;

namespace Tree{
    struct node{
        int l , r , sz;
    }Tree[MAXN << 4];
    int rt[MAXN] , cnt;

#define mid ((l + r) >> 1)
    int insert(int x , int l , int r , int tar){
        int t = ++cnt;
        Tree[t] = Tree[x];
        ++Tree[t].sz;
        if(l == r) return t;
        if(mid >= tar)
            Tree[t].l = insert(Tree[t].l , l , mid , tar);
        else
            Tree[t].r = insert(Tree[t].r , mid + 1 , r , tar);
        return t;
    }

    bool query(int x , int y , int l , int r , int L , int R){
        if(L > R || Tree[x].sz == Tree[y].sz) return 0;
        if(l >= L && r <= R) return 1;
        if(mid >= L && query(Tree[x].l , Tree[y].l , l , mid , L , R))
            return 1;
        return mid < R && query(Tree[x].r , Tree[y].r , mid + 1 , r , L , R);
    }
}
using Tree::insert; using Tree::query; using Tree::rt;

struct SAM{
    int Lst[MAXN] , Sst[MAXN] , fa[MAXN] , trans[MAXN][26] , endpos[MAXN];
    int cnt = 1 , lst = 1 , L;

    void insert(int len , int x){
        int t = ++cnt , p = lst;
        Lst[lst = t] = endpos[t] = len;
        while(p && !trans[p][x]){
            trans[p][x] = t;
            p = fa[p];
        }
        if(!p) {Sst[t] = fa[t] = 1; return;}
        int q = trans[p][x];
        Sst[t] = Lst[p] + 2;
        if(Lst[q] == Lst[p] + 1) {fa[t] = q; return;}
        int k = ++cnt;
        memcpy(trans[k] , trans[q] , sizeof(trans[k]));
        Lst[k] = Lst[p] + 1; Sst[k] = Sst[q];
        Sst[q] = Lst[p] + 2;
        fa[k] = fa[q]; fa[q] = fa[t] = k;
        while(trans[p][x] == q){
            trans[p][x] = k;
            p = fa[p];
        }
    }
}S , T;
char s[MAXN];
int mrk[MAXN] , sz[MAXN] , dfn[MAXN] , ts , LS;
vector < int > ch[MAXN];

void clear(){
    memset(T.trans , 0 , sizeof(int) * 26 * (T.cnt + 1));
    memset(T.fa , 0 , sizeof(int) * (T.cnt + 1));
    memset(T.Lst , 0 , sizeof(int) * (T.cnt + 1));
    memset(T.Sst , 0 , sizeof(int) * (T.cnt + 1));
    memset(mrk , 0 , sizeof(int) * (T.cnt + 1));
    T.lst = T.cnt = 1;
}

void dfs(int x){
    dfn[x] = ++ts;
    sz[x] = 1;
    rt[ts] = rt[ts - 1];
    if(S.endpos[x])
        rt[ts] = insert(rt[ts] , 1 , LS , S.endpos[x]);
    for(int i = 0 ; i < ch[x].size() ; ++i){
        dfs(ch[x][i]);
        sz[x] += sz[ch[x][i]];
    }
}

void init(){
    scanf("%s" , s + 1);
    LS = strlen(s + 1);
    for(int i = 1 ; i <= LS ; ++i)
        S.insert(i , s[i] - 'a');
    for(int i = 2 ; i <= S.cnt ; ++i)
        ch[S.fa[i]].push_back(i);
    dfs(1);
}

queue < int > q;
int in[MAXN];
long long ans;
void getans(){
    for(int i = 2 ; i <= T.cnt ; ++i)
        ++in[T.fa[i]];
    for(int i = 2 ; i <= T.cnt ; ++i)
        if(!in[i]) q.push(i);
    while(!q.empty()){
        int t = q.front(); q.pop();
        if(t == 1) continue;
        if(mrk[t]){
            ans -= min(mrk[t] , T.Lst[t]) - T.Sst[t] + 1;
            mrk[T.fa[t]] = mrk[t];
        }
        if(!--in[T.fa[t]])
            q.push(T.fa[t]);
    }
}

int main(){
    #ifndef ONLINE_JUDGE
    freopen("in" , "r" , stdin);
    //freopen("out" , "w" , stdout);
    #endif
    init();
    int Q;
    for(scanf("%d" , &Q) ; Q ; --Q){
        clear();
        int l , r;
        scanf("%s %d %d" , s + 1 , &l , &r);
        int L = strlen(s + 1) , u = 1 , len = 0 , v = 1;
        for(int i = 1 ; i <= L ; ++i)
            T.insert(i , s[i] - 'a');
        for(int i = 1 ; i <= L ; ++i){
            while(u - 1 && !S.trans[u][s[i] - 'a'])
                len = S.Lst[u = S.fa[u]];
            if(S.trans[u][s[i] - 'a']){
                u = S.trans[u][s[i] - 'a'];
                ++len;
            }
            while(u != 1 && !query(rt[dfn[u] + sz[u] - 1] , rt[dfn[u] - 1] , 1 , LS , l + len - 1 , r)){
                --len;
                if(len < S.Sst[u])
                    u = S.fa[u];
            }
            v = T.trans[v][s[i] - 'a'];
            while(v - 1 && T.Sst[v] > len) v = T.fa[v];
            mrk[v] = max(mrk[v] , len);
        }
        ans = 0;
        getans();
        for(int i = 2 ; i <= T.cnt ; ++i)
            ans += T.Lst[i] - T.Sst[i] + 1;
        printf("%lld\n" , ans);
    }
    return 0;
}
posted @ 2019-02-26 21:44  cjoier_Itst  阅读(262)  评论(5编辑  收藏  举报