POJ 3415 (后缀数组)

  被虐残了T_T。开始没思路,膜拜大牛的思路又看不懂。。。推荐一个题解:http://hi.baidu.com/fpkelejggfbfimd/item/5c76cfcba28fba26e90f2ea6

  思路是求单个串的k前缀,不如有串sx, sy。sx的k前缀是a,sy的k前缀是b,sx + sy的k前缀是c。那结果就是:c - a - b;

  关于怎么求单个串的k前缀:

  大概的思想是求出后缀数组的height值,然后按k进行分组。对每一组里,假设有x个连续的height值为d的情况,那么这段连续的子区间贡献出的结构就是C(x, 2)*(d - k + 1);

因为数据是10^5,所有要O(n^2)来搞肯定不行。然后这里就被卡住了,话说用单调栈就不知道怎么搞了。。。

  参考思路:因为i和j的最长公共前缀是height[rank[i]+1]到height[rank[k]]的最小值。所以可以用一个栈对height进行扫描,当扫描到i位置时,用当前的height[i]值和height[stack[top]]进行比较(stack里面放的是i之前的部分height值)。如果是height[i] > height[stack[top]]则height[i]入栈,继续向后扫描,如果等于的话同样继续向后扫描。如果出现height[i] < height[stack[top]],则要累加结果并且更新栈里的元素,分两种情况:

  1、height[i] >= k && height[i] >height[stack[top-1]],这种情况时,区间为i - stack[top] + 1, 相对贡献值为height[stack[top]] - height[i]。累加结果,把栈顶改为i;

  2、height[i] < height[stack[top-1]],这时这段区间同样为i - stack[top] + 1, 不过贡献值改为height[stack[top]] - height[stack[top-1]]。累加结果,把栈顶值改为stack[top-1];

详见代码:

View Code
//#pragma comment(linker,"/STACK:327680000,327680000")
#include <iostream>
#include <cstdio>
#include <cmath>
#include <vector>
#include <cstring>
#include <algorithm>
#include <string>
#include <set>
#include <functional>
#include <numeric>
#include <sstream>
//#include <stack>
#include <map>
#include <queue>

#define CL(arr, val)    memset(arr, val, sizeof(arr))
#define REP(i, n)       for((i) = 0; (i) < (n); ++(i))
#define FOR(i, l, h)    for((i) = (l); (i) <= (h); ++(i))
#define FORD(i, h, l)   for((i) = (h); (i) >= (l); --(i))
#define L(x)    (x) << 1
#define R(x)    (x) << 1 | 1
#define MID(l, r)   (l + r) >> 1
#define Min(x, y)   (x) < (y) ? (x) : (y)
#define Max(x, y)   (x) < (y) ? (y) : (x)
#define E(x)        (1 << (x))
#define iabs(x)     (x) < 0 ? -(x) : (x)
#define OUT(x)  printf("%I64d\n", x)
#define Read()  freopen("data.in", "r", stdin)
#define Write() freopen("data.out", "w", stdout);

typedef long long LL;
const double eps = 1e-8;
const double PI = acos(-1.0);
const int inf = ~0u>>2;


using namespace std;

const int maxn = 200010;

int wa[maxn], wb[maxn], wv[maxn], WS[maxn];
int cmp(int *r, int a, int b, int l) {
    return r[a] == r[b]&&r[a+l] == r[b+l];
}

void da(int* r, int* sa, int n, int m) {
    int i, j, p, *x = wa, *y = wb, *t;
    for(i = 0; i < m; ++i)  WS[i] = 0;
    for(i = 0; i < n; ++i)  WS[x[i]=r[i]]++;
    for(i = 1; i < m; ++i)  WS[i] += WS[i-1];
    for(i = n - 1; i >= 0; --i) sa[--WS[x[i]]] = i;
    for(j = 1, p = 1; p < n; j *= 2, m = p) {
        for(p = 0, i = n - j; i < n; ++i)   y[p++] = i;
        for(i = 0; i < n; ++i)  if(sa[i] >= j)  y[p++] = sa[i] - j;
        for(i = 0; i < n; ++i)  wv[i] = x[y[i]];
        for(i = 0; i < m; ++i)  WS[i] = 0;
        for(i = 0; i < n; ++i)  WS[wv[i]]++;
        for(i = 1; i < m; ++i)  WS[i] += WS[i-1];
        for(i = n - 1; i >= 0; --i) sa[--WS[wv[i]]] = y[i];
        for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; ++i)
            x[sa[i]] = cmp(y, sa[i-1], sa[i], j)?p-1:p++;
    }
    return ;
}

int rank[maxn], height[maxn];
void calheight(int* r, int* sa, int n) {
    int i, j, k = 0;
    for(i = 1; i <= n; ++i) rank[sa[i]] = i;
    for(i = 0; i < n; height[rank[i++]] = k)
    for(k?k--:0, j = sa[rank[i]-1]; r[i+k] == r[j+k]; ++k);
    return ;
}

int r[maxn], sa[maxn], stack[maxn], k;

LL cal(char* st) {
    int n = strlen(st), i, top, hg;
    LL res = 0, m, fac;
    for(i = 0; i < n; ++i)  r[i] = st[i];
    r[n] = 0;
    da(r, sa, n + 1, 129);
    calheight(r, sa, n);

    height[0] = height[n+1] = k - 1;
    top = 0; i = 1;
    stack[0] = 0;

    while(i <= n + 1) {
        hg = height[stack[top]];
        if(height[i] < k && top == 0)    i++;
        else if(height[i] == hg) i++;
        else if(height[i] > hg) stack[++top] = i++;
        else {
            m = i - stack[top] + 1;
            if(height[i] >= k && height[i] > height[stack[top-1]]) {
                fac = hg - height[i];
                height[stack[top]] = height[i];
            } else {
                fac = hg - height[stack[top-1]];
                top--;
            }
            res += (LL(m)*LL(m-1)/2*LL(fac));
        }
    }
    return res;
}

char sx[maxn], sy[maxn];

int main() {
    //Read();
    int i;
    while(scanf("%d", &k), k) {
        CL(sx, 0); CL(sy, 0);
        scanf("%s", sx);
        scanf("%s", sy);
        LL a = cal(sx);
        LL b = cal(sy);
        int n = strlen(sx), m = strlen(sy);
        sx[n++] = ' ';
        for(i = 0; i < m; ++i) {
            sx[n++] = sy[i];
        }
        //puts(sx);
        LL c = cal(sx);
        //printf("%lld %lld %lld\n", a, b, c);
        printf("%lld\n", c - a - b);
    }
    return 0;
}
posted @ 2012-11-20 09:32  AC_Von  阅读(1620)  评论(0编辑  收藏  举报