Codeforces Round #146 (Div. 1) C - Cyclical Quest 后缀自动机+最小循环节

#include<bits/stdc++.h>
// Shorthand macros. They are textual substitutions, so they are kept as
// macros rather than type aliases.
#define LL long long
#define fi first
#define se second
#define mk make_pair
#define PII pair<int, int>
#define PLI pair<LL, int>
#define ull unsigned long long
using namespace std;

// Capacity used for the global buffers below and for the automaton's arrays.
const int N = 2000007;
// "Infinity" sentinels for int and long long.
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3fLL;
// Modulus and floating-point tolerance constants.
const int mod = 1000000007;
const double eps = 1e-8;

// n: length of the string currently held in s (set by the solver).
int n;
// q: number of queries still to be processed.
int q;
// nx: prefix-function table filled by getNext().
int nx[N];
// s: shared input buffer; the solver reads strings into s + 1.
char s[N];

// Fills the global table nx with the prefix function of s[0..n-1]:
// nx[i] is the length of the longest proper prefix of s[0..i] that is also
// a suffix of s[0..i].
//
// s : pointer to the first character of the pattern (0-based here).
// n : number of characters to process.
//
// Only nx[1..n-1] are written; nx[0] is left untouched.
void getNext(char *s, int n) {
    int border = 0;  // length of the border carried over from the previous position
    int pos = 1;
    while (pos < n) {
        const char here = s[pos];
        // Shrink the candidate border until it can be extended by `here`
        // or there is nothing left to shrink.
        while (border != 0 && s[border] != here) {
            border = nx[border - 1];
        }
        if (s[border] == here) {
            ++border;
        }
        nx[pos] = border;
        ++pos;
    }
}

// Suffix automaton over the lowercase alphabet 'a'..'z', plus the driver that
// answers "how many substrings of the text are cyclic shifts of the query".
//
// State 1 is the root; index 0 doubles as "no state" / "no transition".
struct SuffixAutomaton {
    // last : declared but not used by any method below.
    // cur  : id of the state most recently created as the "new suffix" state.
    // cnt  : number of states allocated so far.
    // ch   : transitions, ch[state][letter] (0 = absent).
    // fa   : suffix links.
    // dis  : length of the longest string represented by each state.
    // sz   : after getSize(), number of end positions (occurrences) of each state.
    // c,id : counting-sort buckets (indexed by dis) and the resulting state order.
    int last, cur, cnt, ch[N<<1][26], id[N<<1], fa[N<<1], dis[N<<1], sz[N<<1], c[N];
    SuffixAutomaton() {cur = cnt = 1;}

    // Clears every state allocated so far and returns to a root-only automaton.
    void init() {
        for(int i = 1; i <= cnt; i++) {
            memset(ch[i], 0, sizeof(ch[i]));
            sz[i] = dis[i] = fa[i] = 0;
            // c has only N slots while state ids can run up to cnt (< 2N),
            // so guard the write instead of indexing c with a state id blindly.
            if(i < N) c[i] = 0;
        }
        cur = cnt = 1;
    }

    // Appends letter c (0..25) after the state p that represents the whole
    // current text and returns the state representing the extended text.
    int extend(int p, int c) {
        cur = ++cnt; dis[cur] = dis[p]+1;
        // Add the missing c-transitions along the suffix-link chain of p.
        for(; p && !ch[p][c]; p = fa[p]) ch[p][c] = cur;
        if(!p) fa[cur] = 1;
        else {
            int q = ch[p][c];
            if(dis[q] == dis[p]+1) fa[cur] = q;
            else {
                // q also holds longer strings: split off a clone of length dis[p]+1.
                int nt = ++cnt; dis[nt] = dis[p]+1;
                memcpy(ch[nt], ch[q], sizeof(ch[q]));
                fa[nt] = fa[q]; fa[q] = fa[cur] = nt;
                // Redirect transitions that pointed at q to the clone.
                for(; ch[p][c]==q; p=fa[p]) ch[p][c] = nt;
            }
        }
        // Only non-clone states start with one end position.
        sz[cur] = 1;
        return cur;
    }

    // Orders the states by dis with a counting sort (n = maximum dis) and
    // accumulates sz from longer states into their suffix-link parents.
    // Relies on c being all zero on entry.
    void getSize(int n) {
        for(int i = 1; i <= cnt; i++) c[dis[i]]++;
        for(int i = 1; i <= n; i++) c[i] += c[i-1];
        for(int i = cnt; i >= 1; i--) id[c[dis[i]]--] = i;
        // id[] is sorted by increasing dis, so walk it backwards.
        for(int i = cnt; i >= 1; i--) sz[fa[id[i]]] += sz[id[i]];
    }

    // Reads the text, builds the automaton, then reads q queries and prints
    // one answer per line. Stops quietly if the input ends early.
    void solve() {
        if(scanf("%s", s + 1) != 1) return;
        n = strlen(s + 1);
        for(int i = 1, tail = 1; i <= n; i++)
            tail = extend(tail, s[i]-'a');
        getSize(n);
        if(scanf("%d", &q) != 1) return;
        while(q--) {
            if(scanf("%s", s + 1) != 1) break;
            n = strlen(s + 1);
            int tar = n;  // window length: every cyclic shift has the query's length
            getNext(s + 1, n);
            // Number of distinct cyclic shifts: the smallest period when it
            // divides the length, otherwise the full length.
            int len = n%(n-nx[n-1]) ? n : (n-nx[n-1]);
            // Append the first len-1 characters so that each distinct shift
            // appears exactly once as a window of length tar.
            // NOTE(review): this writes past the query inside s; assumes the
            // query is at most about half of the buffer - confirm against limits.
            for(int i = 1; i < len; i++) s[++n] = s[i];
            len = 0;  // from here on: length of the currently matched suffix
            LL ans = 0;
            for(int i = 1, p = 1; i <= n; i++) {
                while(p!=1 && !ch[p][s[i]-'a']) p = fa[p];
                if(ch[p][s[i]-'a']) {
                    len = min(dis[p], len) + 1;
                    p = ch[p][s[i]-'a'];
                }
                else len = 0;
                if(len >= tar) {
                    // Move to the state that contains the length-tar suffix and
                    // keep tracking from there. Dropping the excess match length
                    // does not change any later window test, and it makes the
                    // total number of suffix-link steps linear in the query.
                    while(p != 1 && dis[fa[p]] >= tar) p = fa[p];
                    len = min(len, dis[p]);
                    if(p != 1) ans += sz[p];
                }
            }
            printf("%lld\n", ans);
        }
    }
} sam;

// Entry point: the global automaton object reads all input and prints all
// answers itself.
int main()
{
    sam.solve();
}

/*
*/

 

posted @ 2018-10-21 19:40  NotNight  阅读(197)  评论(0)  编辑  收藏  举报