[LOJ#6198]谢特[后缀数组+trie+并查集]

题意

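给定一个长度为 \(n\) 的字符串 \(s\),以位置 \(i\) 开头的后缀带有权值 \(w_i\)。求 \(\max_{i\ne j}\left(\mathrm{LCP}(i,j)+(w_i \oplus w_j)\right)\),其中 \(\mathrm{LCP}(i,j)\) 表示以 \(i\)、\(j\) 开头的两个后缀的最长公共前缀长度,\(\oplus\) 表示按位异或。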

\(n\le 10^5\)

分析

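  • 建立 \(SA\) 之后把所有的 \(height\) 从大到小加入,用并查集维护连通块(类似 Kruskal 求 \(MST\))。加入 \(height_k\)(连接排名为 \(k-1\) 与 \(k\) 的两个后缀)时,两侧的连通块都是排名上的连续区间,且块内的 \(height\) 都不小于 \(height_k\),所以跨越这两块的任意一对后缀的 \(LCP\) 恰好等于 \(height_k\)。于是只需在两块的 \(trie\) 之间查询最大异或值,加上 \(height_k\) 更新答案,再启发式合并 \(trie\) 即可。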

  • 或者也可以建 \(SAM\),在 \(parent\) 树上启发式合并 \(trie\)。

  • 总时间复杂度为 \(O(n\log n\log W)\),其中 \(W\) 为权值上界;把 \(\log W\) 视为常数时即 \(O(n\log n)\)。

    代码采用后缀数组的方式。

代码

#include<bits/stdc++.h>
using namespace std;
using LL = long long;
// Walks the adjacency list of u, introducing loop variables i (edge index)
// and v (edge target). Expects head[] and e[] (with .to / .lst) in scope;
// unused in the code shown here.
#define go(u) for(int i = head[u], v = e[i].to; i; i = e[i].lst, v = e[i].to)
// Inclusive ascending loop: idx takes every value lo, lo+1, ..., hi.
#define rep(idx, lo, hi) for(int idx = lo; idx <= hi; ++idx)
#define pb push_back
// Zero-fills an array; the argument must be an array (not a pointer) so
// that sizeof covers the whole object.
#define re(arr) memset(arr, 0, sizeof arr)
// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Returns 0 if EOF is reached before any digit (the
// original looped forever there).
inline int gi() {
    int x = 0, f = 1;
    // int, not char: it must hold EOF, and isdigit() is only defined for
    // values representable as unsigned char or EOF.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// chmax: raise `target` to `candidate` when target < candidate.
template <typename T> inline void Max(T &target, T candidate) {
	if (target < candidate) { target = candidate; }
}
// chmin: lower `target` to `candidate` when target > candidate.
template <typename T> inline void Min(T &target, T candidate) {
	if (target > candidate) { target = candidate; }
}
constexpr int N = 100000 + 7;  // array capacity: n <= 1e5 plus 1-based slack
int n;                         // length of the input string
int ans;                       // best LCP + xor found so far (printed at the end)
char s[N];                     // input string stored in s[1..n]; s[0] and s[n+1] are '\0'
// Suffix array of the global 1-based string s[1..n], built by prefix doubling
// with radix sort. After pre():
//   sa[r] = start position of the suffix with rank r (1-based);
//   x[i]  = rank of the suffix starting at position i (inverse of sa);
//   h[r]  = LCP of the suffixes sa[r-1] and sa[r], for r >= 2 (h[1] unset).
// y and c are scratch arrays for the radix sort.
namespace SA {
	int sa[N], x[N], y[N], c[N], h[N];

	// Builds sa, x and h for s[1..n].
	// m: largest initial key; every character code of s[1..n] must lie in
	//    [1, m]. The key range shrinks to the number of distinct ranks later.
	void pre(int m) {
		// Round 0: rank by single characters (counting sort).
		for (int i = 1; i <= m; ++i) c[i] = 0;
		for (int i = 1; i <= n; ++i) c[x[i] = static_cast<unsigned char>(s[i])]++;
		for (int i = 1; i <= m; ++i) c[i] += c[i - 1];
		for (int i = n; i; --i) sa[c[x[i]]--] = i;

		for (int k = 1; k <= n; k <<= 1) {
			// y = positions ordered by the second key (rank of i+k). Suffixes
			// whose second half is empty (i+k > n) sort first.
			int p = 0;
			for (int i = n; i >= n - k + 1; --i) y[++p] = i;
			for (int i = 1; i <= n; ++i)
				if (sa[i] > k) y[++p] = sa[i] - k;
			// Stable counting sort of y by the first key x[].
			for (int i = 1; i <= m; ++i) c[i] = 0;
			for (int i = 1; i <= n; ++i) c[x[y[i]]]++;
			for (int i = 1; i <= m; ++i) c[i] += c[i - 1];
			for (int i = n; i; --i) sa[c[x[y[i]]]--] = y[i];
			// Re-rank by the pair (old x[i], old x[i+k]); y now holds the old ranks.
			std::swap(x, y);
			p = 1;
			x[sa[1]] = 1;
			for (int i = 2; i <= n; ++i)
				x[sa[i]] = y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k] ? p : ++p;
			if (p >= n) break;  // all ranks distinct: sa is final
			m = p;
		}

		for (int i = 1; i <= n; ++i) x[sa[i]] = i;

		// Kasai: walking positions i = 1..n, h[x[i]] >= h[x[i-1]] - 1, so the
		// match length j only needs to drop by one per step.
		for (int i = 1, j = 0; i <= n; ++i) {
			if (j) --j;
			// The skip condition is on the RANK of suffix i: the rank-1 suffix
			// has no predecessor. (Testing sa[i] == 1 here skipped the wrong
			// suffix and left some h[] entries uncomputed.)
			if (x[i] == 1) {
				j = 0;
				continue;
			}
			const int pred = sa[x[i] - 1];
			// Terminates because s[n+1] == '\0' and the two suffixes differ.
			while (s[i + j] == s[pred + j]) ++j;
			h[x[i]] = j;
		}
	}
}
// One suffix-array height edge: it joins ranks p-1 and p, whose suffixes
// share an LCP of length h. operator< orders edges by DESCENDING h, so that
// sorting puts the longest LCPs first.
// NOTE: with `using namespace std` under C++17, the bare name `data` is
// ambiguous with std::data outside this struct; qualify it as ::data there.
struct data {
	int p;  // rank of the later suffix of the adjacent pair
	int h;  // LCP of the suffixes ranked p-1 and p

	bool operator<(const data &other) const { return other.h < h; }
} t[N];
int ndc;            // trie nodes allocated so far; node 0 means "no node"
int par[N];         // DSU parent, indexed by suffix rank
int rt[N];          // trie root per DSU component (meaningful at component roots)
int w[N];           // w[i]: weight read for the suffix starting at position i
int ch[N * 20][2];  // trie children; ins() allocates at most 18 nodes per value
void ins(int v, int &rt) {
	if(!rt) rt = ++ndc;
	int u = rt;
	for(int i = 16; ~i; --i) {
		int c = v >> i & 1;
		if(!ch[u][c]) ch[u][c] = ++ndc;
		u = ch[u][c];
	}
}
int merge(int a, int b) {
	if(!a || !b) return a + b;
	rep(i, 0, 1) if(ch[a][i] || ch[b][i]) ch[a][i] = merge(ch[a][i], ch[b][i]);
	return a;
}
int getans(int dep, int a, int b) {
	int res = 0;
	rep(i, 0, 1) if(ch[a][i] && ch[b][i ^ 1]) Max(res, getans(dep - 1, ch[a][i], ch[b][i ^ 1]) + (1 << dep));
	if(res) return res;
	rep(i, 0, 1) if(ch[a][i] && ch[b][i]) Max(res, getans(dep - 1, ch[a][i], ch[b][i]));
	return res;
}
// Returns the DSU representative of a. Afterwards every node on the visited
// path points directly at the root. This is done iteratively, so stack use
// does not grow with the path length.
int getpar(int a) {
	int root = a;
	while (par[root] != root) root = par[root];
	while (par[a] != root) {
		const int next = par[a];
		par[a] = root;
		a = next;
	}
	return root;
}
void Union(int a, int b) {
	a = getpar(a), b = getpar(b);
	if(a == b) return;
	par[b] = a;
	rt[a] = merge(rt[a], rt[b]);
}
// Reads n, the string s and the n weights; prints
// max over i != j of LCP(i, j) + (w_i xor w_j).
int main() {
	using namespace SA;
	n = gi();
	scanf("%s", s + 1);
	rep(i, 1, n) w[i] = gi();
	pre(128);  // assumes every character code of s lies in [1, 128]
	// Edge i-1 joins adjacent ranks i-1 and i; its weight h[i] is their LCP.
	// Fields are assigned directly: `(data){...}` is a GNU compound literal,
	// and with `using namespace std` under C++17 the bare name `data` is
	// ambiguous with std::data (compile error).
	rep(i, 2, n) {
		t[i - 1].p = i;
		t[i - 1].h = h[i];
	}
	sort(t + 1, t + n);  // descending by LCP
	// Every rank starts as its own component holding a one-value trie.
	rep(i, 1, n) ins(w[sa[i]], rt[i]), par[i] = i;
	rep(i, 1, n - 1) {
		// Each component is a contiguous rank range whose inner edges all have
		// height >= t[i].h. So every pair split across the two components has
		// LCP exactly t[i].h.
		int f1 = getpar(t[i].p - 1);
		int f2 = getpar(t[i].p);
		int tmp = getans(16, rt[f1], rt[f2]);
		Max(ans, tmp + t[i].h);
		Union(f1, f2);
	}
	printf("%d\n", ans);
	return 0;
}
posted @ 2018-12-26 08:47  fwat  阅读(323)  评论(0)  编辑  收藏  举报