SPOJ 694
题意:
给一个字符串,求它有多少个不同的子串
多组数据。
Solution :
模板题,用所有的减去重复的即可。
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <cstring>
using namespace std;
const int N = 1e6;
char str[N + 1];
namespace SA {
int sa[N + 1], rk[N + 1], height[N + 1];
int a[N+1], set[N+1], n;
int fir[N+1], sec[N+1], buc[N+1], tp[N+1];
void sufSort() {
n = strlen(str + 1);
copy(str + 1, str + 1 + n, set + 1);
sort(set + 1, set + n + 1);
int *end = unique(set + 1, set + n + 1);
for (int i = 1; i <= n; ++ i)
a[i] = lower_bound(set + 1, end, str[i]) - set;
fill(buc, buc + 1 + n, 0);
for (int i = 1; i <= n; ++ i) buc[a[i]] ++;
for (int i = 1; i <= n; ++ i) buc[i] += buc[i - 1];
for (int i = 1 ;i <= n; ++ i) rk[i] = buc[a[i] - 1] + 1;
for (int t = 1; t <= n; t <<= 1) {
copy(rk + 1, rk + 1 + n, fir + 1);//那张图
for (int i = 1; i <= n; ++ i) sec[i] = i + t > n ? 0 : rk[i + t];
fill(buc, buc + 1 + n, 0);//基排桶清空
for (int i = 1; i <= n; ++ i) buc[sec[i]] ++;
for (int i = 1; i <= n; ++ i) buc[i] += buc[i - 1];//统计比sec[i]小的数有多少个
//tp[i]为第二关键字为第i大的二元组在排序前的位置
//n - --buc[sec[i]]为排名(第几大), 排序前的位置自然是i.
for (int i = 1; i <= n; ++ i) tp[n - --buc[sec[i]]] = i;
fill(buc, buc + 1 + n, 0);
for (int i = 1; i <= n; ++ i) buc[fir[i]] ++;
for (int i = 1; i <= n; ++ i) buc[i] += buc[i - 1];
//i为第二关键字为j大的二元组的位置,j递增,即保证sec[i]递增, 所以基排就是正确的。
for (int j = 1, i; j <= n; ++ j) i = tp[j], sa[buc[fir[i]]--] = i;
//手模,理解.
bool unique = 1;//若关键字相同,则排名也要相同,而基排使rk不同
for (int j = 1, i, last = 0; j <= n; ++ j) {
i = sa[j];
if (!last) rk[i] = 1;
else if (sec[i] == sec[last] && fir[i] == fir[last])
rk[i] = rk[last], unique = 0;
else
rk[i] = rk[last] + 1;
last = i;
}
if (unique) break;//如果所有rk都不同,就证明排好了.
}
}
void getHeight() {
for (int i = 1, k = 0; i <= n; ++ i) {
if (rk[i] == 1) k = 0;
else {
if (k > 0) k --;
int j = sa[rk[i] - 1];
while (j + k <= n && i + k <= n && str[i + k] == str[j + k]) k ++;
}
height[rk[i]] = k;
}
}
}using namespace SA;
int main() {
int T; scanf("%d", &T);
while (T --) {
scanf("%s", str + 1);
int n = strlen(str + 1);
sufSort();
getHeight();
int ans = n * (n + 1) / 2;
for (int i = 2; i <= n; ++ i) ans -= height[i];
printf("%d\n", ans);
}
}