BZOJ 3238 差异
BZOJ 3238 差异
看这个式子其实就是求任意两个后缀的 $ LCP $ 长度和。前面的 $ len(T_i)+len(T_j) $ 求和其实就是 $ n(n-1)(n+1)/2 $ ,这个是很好推的。。
任意两个后缀的 $ LCP $ 长度和很容易想到构造 height 数组,然后问题就变成了所有区间的最小值的和。
这是个套路题,可以单调栈,但是其实分治也很好写!
设我们要求的区间是 $ [l,r] $ 我们可以找出其中最小值所在的位置,这个可以ST表快速求,然后从这个位置进行分治。
这样的分治每进行一次,总有效的元素数量会减少1,因此复杂度是 $ O(nlogn) $ 的。
开始有个地方漏了 1ll
。。。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
#define MAXN 500006
#define C 130
namespace wtf {
char ch[MAXN];
int sa[MAXN], tp[MAXN], rnk[MAXN], buc[MAXN], len;
int P[MAXN][20], ht[MAXN];
int que(int l, int r) {
int L = (31 - __builtin_clz(r - l + 1));
return (ht[P[l][L]] < ht[P[r - (1 << L) + 1][L]] ? P[l][L] : P[r - (1 << L) + 1][L]);
}
void init() {
len = strlen(ch + 1);
int m = C;
for (int i = 1; i <= len; ++i) ++buc[rnk[i] = ch[i]];
for (int i = 1; i <= m; ++i) buc[i] += buc[i - 1];
for (int i = len; i >= 1; --i) sa[buc[rnk[i]]--] = i;
for (int k = 1, p = 0; p < len; k <<= 1) {
p = 0;
for (int i = 0; i <= m; ++i) buc[i] = 0;
for (int i = len - k + 1; i <= len; ++i) tp[++p] = i;
for (int i = 1; i <= len; ++i) if (sa[i] > k) tp[++p] = sa[i] - k;
for (int i = 1; i <= len; ++i) ++buc[rnk[i]];
for (int i = 1; i <= m; ++i) buc[i] += buc[i - 1];
for (int i = len; i >= 1; --i) sa[buc[rnk[tp[i]]]--] = tp[i];
p = 1;
swap(rnk, tp);
rnk[sa[1]] = 1;
for (int i = 2; i <= len; ++i)
rnk[sa[i]] = (tp[sa[i]] == tp[sa[i - 1]] && tp[sa[i] + k] == tp[sa[i - 1] + k]) ? p : ++p;
m = p;
}
for (int i = 1; i <= len; ++i) rnk[sa[i]] = i;
for (int i = 1, k = 0; i <= len; ++i) {
if (k) --k;
while (ch[i + k] == ch[sa[rnk[i] - 1] + k]) ++k;
ht[rnk[i]] = k;
P[i][0] = i;
}
ht[0] = 0x3f3f3f3f;
for (int i = 1; i < 20; ++i)
for (int j = 1; j <= len - (1 << i) + 1; ++j)
P[j][i] = (ht[P[j][i - 1]] < ht[P[j + (1 << i - 1)][i - 1]] ? P[j][i - 1] : P[j + (1 << i - 1)][i - 1]);
}
long long res = 0;
void div(int l, int r) {
if( l > r ) return;
int ps = que( l , r );
res += 1ll * ht[ps] * ( ps - l + 1 ) * ( r - ps + 1 );
div( l , ps - 1 ) , div( ps + 1 , r );
}
void main() {
// freopen("1.in","r",stdin);
scanf("%s", ch + 1);
init();
div( 1 , len );
cout << 1ll * ( len - 1 ) * ( len + 1 ) * len / 2 - 2 * res << endl;
}
}
int main() {
wtf::main();
}