【回文自动机】 CODEFORCES 17E Palisection

题目链接:https://codeforces.com/problemset/problem/17/E

题意:给定长度为 n 的字符串,求所有回文子串(按出现位置区分)中,位置相交(至少共享一个位置)的回文子串对数,答案对 51123987 取模。

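思路:用总对数 T(T-1)/2(T 为回文子串总数)减去不相交的对数。不相交对数 = Σ_i(以 i-1 及之前位置结尾的回文数)×(以 i 开头的回文数):正向建回文自动机得到每个位置结尾的回文数并求前缀和 dp,再反向建一次得到以每个位置开头的回文数。n 可达 2·10^6,需注意内存优化:转移边用邻接表(vector)表示,而不是 26 叉数组。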

代码:

#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>

using namespace std;

typedef long long ll;
typedef pair<int,int> pii;

// Capacity of every array in PTree. A tree fed m letters uses S[0..m] and at most
// m + 2 nodes, so one instance handles strings of up to MAX_N - 2 letters.
const int MAX_N = 2000005;

// Palindromic tree (eertree). Node 0 is the even root (length 0) and node 1 the
// odd root (length -1); every other node is one distinct palindromic substring
// of the letters added so far. Edges are stored as per-node adjacency lists of
// (letter, child) pairs instead of a fixed alphabet-sized table, to save memory.
struct PTree {
    vector<pii> nxt[MAX_N];  // nxt[v]: (letter, child) edges; child = letter + v + letter
    int fail[MAX_N];         // fail[v]: node of the longest proper palindromic suffix of v
    int cnt[MAX_N];          // cnt[v]: times v was the longest palindromic suffix; count() turns it into occurrence totals
    int num[MAX_N];          // num[v]: number of non-empty palindromic suffixes of v (v included); 0 for the roots
    int len[MAX_N];          // len[v]: length of the palindrome at v
    int S[MAX_N];            // S[1..n]: letter codes added so far; S[0] = -1 is a sentinel
    int last, n, p;          // last: longest palindromic suffix node; n: letters added; p: nodes allocated

    // Allocates node p with length l, no edges and zeroed counters; returns its index.
    // fail is reset as well so the odd root's fail link (read by add()) is always
    // defined, not only when the object happens to live in zero-initialised storage.
    int newNode(int l) {
        nxt[p].clear();
        cnt[p] = num[p] = 0;
        len[p] = l;
        fail[p] = 0;
        return p++;
    }

    // Resets to the empty tree; may be called again to reuse the same storage.
    void init() {
        p = last = n = 0;
        newNode(0), newNode(-1);
        S[n] = -1;       // never equals a letter code, so getFail stops at the string start
        fail[0] = 1;     // even root falls back to the odd root; fail[1] == 0 via newNode
    }

    // Walks fail links from x to the first palindrome that can be extended on both
    // sides by the letter just stored at S[n]. The odd root (len -1) always matches.
    int getFail(int x) {
        while (S[n - len[x] - 1] != S[n]) x = fail[x];
        return x;
    }

    // Returns the child of v along letter code c, or -1 if there is no such edge.
    int child(int v, int c) const {
        for (size_t i = 0; i < nxt[v].size(); ++i)
            if (nxt[v][i].first == c) return nxt[v][i].second;
        return -1;
    }

    // Appends character c (expected 'a'..'z'; 'a' is subtracted) and returns the
    // number of palindromic substrings that end at the new position.
    int add(int c) {
        c -= 'a';
        S[++n] = c;
        int cur = getFail(last);
        int node = child(cur, c);
        if (node < 0) {
            node = newNode(len[cur] + 2);
            // Resolve the fail target BEFORE adding cur's new edge: when cur is the
            // odd root, getFail(fail[cur]) may return cur itself, and the new edge
            // would then make the node its own fail link.
            int f = child(getFail(fail[cur]), c);
            fail[node] = f < 0 ? 0 : f;
            nxt[cur].push_back(make_pair(c, node));
            num[node] = num[fail[node]] + 1;
        }
        last = node;
        ++cnt[last];
        return num[last];
    }

    // Propagates cnt along fail links so cnt[v] becomes the total number of
    // occurrences of v. Relies on fail[v] < v for every non-root node, which holds
    // because a fail target always exists before the node that points to it.
    void count() {
        for (int i = p - 1; i >= 0; --i) cnt[fail[i]] += cnt[i];
    }
};

const ll MOD = 51123987ll;

PTree A;
char s[MAX_N];
ll dp[MAX_N];

int main() {
    int n; scanf("%d%s", &n, s + 1);
    A.init();
    for (int i = 1; s[i]; ++i) {
        dp[i] = (dp[i - 1] + A.add(s[i])) % MOD;
    }
    ll ans = 0;
    A.init();
    for (int i = n; i > 0; --i) {
        ans = (ans + dp[i - 1] * A.add(s[i])) % MOD;
    }
    ans = (dp[n] * (dp[n] - 1) / 2 % MOD - ans + MOD) % MOD;
    printf("%I64d\n", ans);
    return 0;
}
View Code

 

posted @ 2015-08-15 21:32  mithrilhan  阅读(310)  评论(0编辑  收藏  举报