Good Article Good sentence HDU - 4416 (后缀自动机)

Good Article Good sentence

\[Time Limit: 3000 ms\quad Memory Limit: 32768 kB \]

题意

给出一个 \(S\) 串,在给出 \(n\)\(T\) 串,求出 \(S\) 串中有多少子串没有在任意一个 \(T\) 串中出现过

思路

\(\quad\) 首先可以对 \(S\) 串构建后缀自动机,然后在插入 \(n\)\(T\) 串,每两个串之间用 \(27\) 隔开,然后可以求出这个自动机上每个节点出现的最左位置 \(left\) 和最右位置 \(right\),然后判断 \(right\)\(Slen\) 内时,这个节点包含的子串就是一个答案!但是!\(MLE\) 了,哭了😭,所以我也不知道这个做法能不能过,应该可以
\(\quad\) 正确的做法,先对 \(S\) 串构建后缀自动机,结构体内多开一个变量 \(maxlen\) 表示在 \(i\) 个节点上,和任意一个 \(T\) 串的最长连续公共子串,可以枚举 \(T\) 串,每一个 \(T\) 串用与 \(SPOJ-LCS\) 类似的做法来做。在注意一下子串往 \(father\) 更新的过程,原理与 \(SPOJ-LCS2\) 类似。
\(\quad\) 求出每个节点的 \(maxlen\) 后,就可以开始计算答案了,可以分成两种情况:

  1. maxlen = 0,这个节点内的所有子串都满足条件,满足条件的子串有 \(node[i].len - node[father].len\) 个。
  2. 长度在 [maxlen+1,node[i].len] 区间内的子串满足条件,满足条件的子串有 \(node[i].len - maxlen\) 个。
#include <map>
#include <set>
#include <list>
#include <ctime>
#include <cmath>
#include <stack>
#include <queue>
#include <cfloat>
#include <string>
#include <vector>
#include <cstdio>
#include <bitset>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define  lowbit(x)  x & (-x)
#define  mes(a, b)  memset(a, b, sizeof a)
#define  fi         first
#define  se         second
#define  pii        pair<int, int>
#define  INOPEN     freopen("in.txt", "r", stdin)
#define  OUTOPEN    freopen("out.txt", "w", stdout)

typedef unsigned long long int ull;
typedef long long int ll;
const int    maxn = 1e5 + 10;
const int    maxm = 1e5 + 10;
const ll     mod  = 1e9 + 7;
const ll     INF  = 1e18 + 100;
const int    inf  = 0x3f3f3f3f;
const double pi   = acos(-1.0);
const double eps  = 1e-8;
using namespace std;

int n, m;
int cas, tol, T;

struct SAM {
    struct Node {
        int next[27];
        int fa, len;
        int maxlen;
        void init() {
            mes(next, 0);
            fa = len = maxlen = 0;
        }
    } node[maxn<<1];
    vector<int> vv[maxn<<1];
    bool vis[maxn<<1];
    int last, sz;
    void init() {
        last = sz = 1;
        node[sz].init();
    }
    void insert(int k) {
        int p = last, np = last = ++sz;
        node[np].init();
        node[np].len = node[p].len+1;
        for(; p&&!node[p].next[k]; p=node[p].fa)
            node[p].next[k] = np;
        if(p == 0) {
            node[np].fa = 1;
        } else {
            int q = node[p].next[k];
            if(node[q].len == node[p].len + 1) {
                node[np].fa = q;
            } else {
                int nq = ++sz;
                node[nq] = node[q];
                node[nq].len = node[p].len + 1;
                node[np].fa = node[q].fa = nq;
                for(; p&&node[p].next[k]==q; p=node[p].fa)
                    node[p].next[k] = nq;
            }
        }
    }
    void solve(char *s) {
        int len = strlen(s+1);
        int p = 1, tmpans = 0;
        for(int i=1; i<=len; i++) {
            int k = s[i]-'a'+1;
            while(p && !node[p].next[k]) {
                p = node[p].fa;
                tmpans = node[p].len;
            }
            if(p == 0) {
                p = 1;
                tmpans = 0;
            } else {
                p = node[p].next[k];
                tmpans++;
            }
            node[p].maxlen = max(node[p].maxlen, tmpans);
        }
    }
    ll ans;
    void dfs(int u) {
        if(vis[u])	return ;
        vis[u] = true;
        for(auto v : vv[u]) {
            dfs(v);
            node[u].maxlen = min(node[u].len, max(node[u].maxlen, node[v].maxlen));
        }
        if(node[u].maxlen == 0) {
            ans += node[u].len - node[node[u].fa].len;
        } else {
			ans += node[u].len - node[u].maxlen;
        }
    }
    ll calc() {
        ans = 0;
        for(int i=1; i<=sz; i++)	vv[i].clear();
        for(int i=2; i<=sz; i++) {
            vv[node[i].fa].push_back(i);
        }
        mes(vis, 0);
        dfs(1);
        return ans;
    }
} sam;
char s[maxn], t[maxn];

int main() {
    cas = 1;
    scanf("%d", &T);
    while(T--) {
        scanf("%d", &n);
        scanf("%s", s+1);
        int slen = strlen(s+1);
        sam.init();
        for(int i=1; i<=slen; i++) {
            sam.insert(s[i]-'a'+1);
        }
        for(int i=1; i<=n; i++) {
            scanf("%s", t+1);
            sam.solve(t);
        }
        ll ans = sam.calc();
        printf("Case %d: %lld\n", cas++, ans);
    }
    return 0;
}
posted @ 2019-05-29 23:27  Jiaaaaaaaqi  阅读(169)  评论(0编辑  收藏  举报