bzoj3473

广义后缀自动机

具体我也不是很清楚

像这样有很多个串要统计方案的题我们建一个广义后缀自动机,就是每次对一个串建完后把last设为root,然后就是每个串在自动机上跑一遍,记录每个节点的访问次数,为了避免重复,我们记录当前这个节点这个字符串走没走过,出现次数也是要向上推的。最后按照套路把贡献向上推,再走一遍加上每个点的贡献就是答案。

 
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5 + 5;
int n, k;
int a[N], c[N];
string s[N];
char tc[N];
ll f[N];
namespace SAM
{
    struct node {
        int val, par, cnt, last;
        int ch[26];
    } t[N];
    int last = 1, root = 1, sz = 1;
    int nw(int x)
    {
        t[++sz].val = x;
        return sz;
    }
    void extend(int c)
    {
        int p = last, np = nw(t[p].val + 1);
        while(p && !t[p].ch[c]) t[p].ch[c] = np, p = t[p].par;
        if(!p) t[np].par = root;
        else
        {
            int q = t[p].ch[c];
            if(t[q].val == t[p].val + 1) t[np].par = q;
            else
            {
                int nq = nw(t[p].val + 1);
                memcpy(t[nq].ch, t[q].ch, sizeof(t[q].ch));
                t[nq].par = t[q].par;
                t[q].par = t[np].par = nq;
                while(p && t[p].ch[c] == q) t[p].ch[c] = nq, p = t[p].par;
            }
        }
        last = np;
    }
} using namespace SAM;
int main()
{
    scanf("%d%d", &n, &k);
    for(int i = 1; i <= n; ++i)
    {
        scanf("%s", tc);
        s[i] = string(tc);
        int len = strlen(tc);
        last = root;
        for(int j = 0; j < len; ++j) extend(tc[j] - 'a');
    }
    for(int i = 1; i <= n; ++i)
    {
        int u = root, ans = 0;
        for(int j = 0; j < s[i].size(); ++j) 
        {
            u = t[u].ch[s[i][j] - 'a'];
            int p = u;
            while(p) 
            {
                if(t[p].last != i) 
                {
                    ++t[p].cnt;
                    t[p].last = i;
                }
                else break;
                p = t[p].par;
            }
        }
    }
    for(int i = 1; i <= sz; ++i) ++c[t[i].val];
    for(int i = 1; i <= sz; ++i) c[i] += c[i - 1];
    for(int i = 1; i <= sz; ++i) a[c[t[i].val]--] = i;
    t[1].cnt = 0;
    for(int i = 1; i <= sz; ++i)
    {
        int u = a[i];
        f[u] += f[t[u].par] + (t[u].cnt >= k ? t[u].val - t[t[u].par].val : 0);
    }
    for(int i = 1; i <= n; ++i)
    {
        int u = root;
        ll ans = 0;
        for(int j = 0; j < s[i].size(); ++j)
        {
            u = t[u].ch[s[i][j] - 'a'];
            ans += f[u];
        }
        printf("%lld ", ans);
    }
    return 0;
}
View Code

 

posted @ 2017-11-19 19:27  19992147  阅读(172)  评论(0编辑  收藏  举报