bzoj4566

 后缀自动机+dp

一个串在另一个串上跑。

先对A建出自动机,然后用B在上面跑,记录当前匹配的最大长度,每次经过一个节点记录经过次数,并加上(len-Max(par))*Right,是这个状态对答案的贡献,然后把每个节点的出现次数向par树上的祖先推一遍计算贡献。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 4e5 + 5;
int n, m;
ll ans;
ll sum[N], Right[N], apr[N], f[N];
int a[N], c[N];
char s1[N], s2[N];
namespace SAM
{
    struct node {
        int val, par;
        int ch[26];
    } t[N];
    int last = 1, root = 1, sz = 1;
    int nw(int x)
    {
        t[++sz].val = x;
        return sz;
    }
    void extend(int c)
    {
        int p = last, np = nw(t[p].val + 1);
        while(p && !t[p].ch[c]) t[p].ch[c] = np, p = t[p].par;
        if(!p) t[np].par = root;
        else
        {
            int q = t[p].ch[c];
            if(t[q].val == t[p].val + 1) t[np].par = q;
            else
            {
                int nq = nw(t[p].val + 1);
                memcpy(t[nq].ch, t[q].ch, sizeof(t[q].ch));
                t[nq].par = t[q].par;
                t[q].par = t[np].par = nq;
                while(p && t[p].ch[c] == q) t[p].ch[c] = nq, p = t[p].par;
            }
        }
        Right[np] = 1;
        last = np;
    }
} using namespace SAM;
int main()
{
    scanf("%s%s", s1 + 1, s2 + 1);
    n = strlen(s1 + 1);
    m = strlen(s2 + 1);
    for(int i = 1; i <= n; ++i) extend(s1[i] - 'a');
    for(int i = 1; i <= sz; ++i) ++c[t[i].val];
    for(int i = 1; i <= sz; ++i) c[i] += c[i - 1];
    for(int i = 1; i <= sz; ++i) a[c[t[i].val]--] = i;    
    for(int i = sz; i; --i) Right[t[a[i]].par] += Right[a[i]];
    int u = root, step = 0;
    for(int i = 1; i <= m; ++i)
    {
        int c = s2[i] - 'a';
        if(t[u].ch[c]) u = t[u].ch[c], ++step;
        else
        {
            while(u && !t[u].ch[c]) u = t[u].par;
            if(!u) u = root, step = 0;
            else
            {
                step = t[u].val + 1;
                u = t[u].ch[c];
            }
        }
        ++apr[u];
        if(u != root) ans += (ll)(step - t[t[u].par].val) * Right[u];
    }
    for(int i = sz; i > 1; --i) f[t[a[i]].par] += f[a[i]] + apr[a[i]];
    for(int i = 2; i <= sz; ++i) ans += f[a[i]] * (ll)(t[a[i]].val - t[t[a[i]].par].val) * Right[a[i]];
    printf("%lld\n", ans);
    return 0;
}
View Code

 

posted @ 2017-11-19 19:22  19992147  阅读(228)  评论(0编辑  收藏  举报