【LOJ #6198】谢特
Description
给出一个长度为 \(n\) 的,仅包含小写字母的字符串 \(s\)。
定义这个字符串以第 \(i\) 个字符开头的后缀为后缀 \(i\)(编号从 \(1\) 开始),每个后缀 \(i\) 都有一个权值 \(w_i\),同时定义两个后缀 \(i, j (i \neq j)\) 的贡献为它们的最长公共前缀长度加上它们权值的异或和,也就是 \(\text{LCP}(i, j) + (w_i \ \text{xor} \ w_j)\)。
而你的任务就是,求出这个字符串的所有后缀两两之间贡献的最大值。
数据范围:\(1 \leq n \leq 10^5\),\(0 \leq w_i < n\)。
时空限制:\(1000 \ \mathrm{ms} / 512 \ \mathrm{MiB}\)。
Solution
算法一:SA + 可持久化 0/1 trie
对字符串 \(s\) 做一遍 SA,再将 \(\text{height}\) 数组求出。
为了方便求解,我们求出一个新的后缀权值序列 \(\{ w'_i \}\),满足 \(w'_i = w_{\text{SA}_i}\)。
那么,后缀 \(\text{SA}_i, \text{SA}_j(i < j)\) 的贡献即为:
那么现在就是要求出所有满足 \(1 \leq i < j \leq n\) 的数对 \((i, j)\) 所计算出的上式最大值。
考虑分治。定义分治函数 \(\text{solve}(l, r)\),表示计算 \(l \leq i < j \leq r\) 时的答案。
取分治中心 \(\text{mid}\) 为区间 \((l, r]\) 中 \(\text{height}\) 值最小的一点,至于计算 \(\text{mid}\) 可以使用 ST 表。
那么现在只需要计算出所有满足 \(l \leq i < \text{mid} \leq j \leq r\) 的数对 \((i, j)\) 的贡献最大值。随后调用 \(\text{solve}(l, \text{mid} - 1)\) 和 \(\text{solve}(\text{mid}, r)\) 即可。
对于所有满足 \(l \leq i < \text{mid} \leq j \leq r\) 的数对 \((i, j)\),一定有 \(\min\limits_{i < k \leq j} \left\{ \text{height}_k \right\} = \text{height}_\text{mid}\)。
此时贡献表达式的第一个参数已经确定下来,现在就是要求 \(\left(w'_i \ \text{xor} \ w'_j\right)\) 的最大值。
对于分治中心 \(\text{mid}\) 分出来的两个区间 \([l, \text{mid})\) 与 \([\text{mid}, r]\)。我们选择长度较小的一段区间,穷举该区间里的每一个点,求出该点与另一段区间中的每个点的异或最大值来更新答案,使用可持久化 0/1 trie 维护即可。
注意到每枚举一个区间,总区间的长度就至少乘 \(2\),这本质上是一个启发式合并。每个点至多被枚举了 \(\log_2 n\) 次。故总时间复杂度为 \(\mathcal{O}(n \log^2 n)\)。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 100100;
int n, m = 256;
char s[N];
int temp[N], w[N];
int SA[N], rk[N];
int cnt[N], id[N], oldrk[N * 2], px[N];
bool cmp(int x, int y, int w) {
return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w];
}
int height[N];
int logx[N];
int f[N][17];
int calc(int l, int r) {
int k = logx[r - l + 1];
if (height[f[l][k]] <= height[f[r - (1 << k) + 1][k]])
return f[l][k];
else
return f[r - (1 << k) + 1][k];
}
int cT, root[N];
struct trie {
int ch[2];
int cnt;
} t[N * 32];
void insert(int &p, int now, int dep, int val) {
p = ++ cT;
t[p] = t[now];
t[p].cnt ++;
if (dep < 0) return;
int v = val >> dep & 1;
insert(t[p].ch[v], t[now].ch[v], dep - 1, val);
}
int ask(int p, int q, int dep, int val) {
if (dep < 0) return 0;
int v = val >> dep & 1;
int cnt = t[t[q].ch[v ^ 1]].cnt - t[t[p].ch[v ^ 1]].cnt;
if (cnt)
return ask(t[p].ch[v ^ 1], t[q].ch[v ^ 1], dep - 1, val) + (1 << dep);
else
return ask(t[p].ch[v], t[q].ch[v], dep - 1, val);
}
int ans;
void solve(int l, int r) {
if (l == r) return;
int mid = calc(l + 1, r);
if (mid - l <= r - mid + 1) {
for (int i = l; i < mid; i ++)
ans = max(ans, height[mid] + ask(root[mid - 1], root[r], 17, w[i]));
} else {
for (int i = mid; i <= r; i ++)
ans = max(ans, height[mid] + ask(root[l - 1], root[mid - 1], 17, w[i]));
}
solve(l, mid - 1), solve(mid, r);
}
int main() {
scanf("%d", &n);
scanf("%s", s + 1);
for (int i = 1; i <= n; i ++)
scanf("%d", &temp[i]);
for (int i = 1; i <= n; i ++) rk[i] = s[i];
for (int i = 1; i <= n; i ++) cnt[rk[i]] ++;
for (int i = 1; i <= m; i ++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i --) SA[cnt[rk[i]] --] = i;
for (int w = 1, p = 0; w < n; w <<= 1, m = p) {
p = 0;
for (int i = n - w + 1; i <= n; i ++) id[++ p] = i;
for (int i = 1; i <= n; i ++)
if (SA[i] > w) id[++ p] = SA[i] - w;
for (int i = 0; i <= m; i ++) cnt[i] = 0;
for (int i = 1; i <= n; i ++) cnt[px[i] = rk[id[i]]] ++;
for (int i = 1; i <= m; i ++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i --) SA[cnt[px[i]] --] = id[i];
p = 0;
for (int i = 1; i <= n; i ++) oldrk[i] = rk[i];
for (int i = 1; i <= n; i ++)
rk[SA[i]] = cmp(SA[i - 1], SA[i], w) ? p : ++ p;
}
for (int i = 1, H = 0; i <= n; i ++) {
if (H) H --;
while (s[i + H] == s[SA[rk[i] - 1] + H]) H ++;
height[rk[i]] = H;
}
logx[0] = -1;
for (int i = 1; i <= n; i ++)
logx[i] = logx[i >> 1] + 1;
for (int i = 1; i <= n; i ++)
f[i][0] = i;
for (int j = 1; j <= 16; j ++)
for (int i = 1; i + (1 << j) - 1 <= n; i ++) {
if (height[f[i][j - 1]] <= height[f[i + (1 << (j - 1))][j - 1]])
f[i][j] = f[i][j - 1];
else
f[i][j] = f[i + (1 << (j - 1))][j - 1];
}
for (int i = 1; i <= n; i ++)
w[i] = temp[SA[i]];
for (int i = 1; i <= n; i ++)
insert(root[i], root[i - 1], 17, w[i]);
solve(1, n);
printf("%d\n", ans);
return 0;
}
算法二:SAM + 0/1 trie 合并
对字符串 \(s\) 的反串建出后缀自动机,并求出 parent 树。 其实这个 parent 树就是原串的后缀树。
因为对于后缀树上任意两点 \(x, y\),分别从 \(x, y\) 代表的子串集合中选出一个子串。那么这两个子串的 \(\text{LCP}\) 即为 \(\text{lca}(x, y)\) 代表的子串中最长的串。
所以可以考虑在这个 parent 树上走,设当前走到的点为 \(u\)。此时在 \(u\) 的两个不同的子树内取一个后缀 \(i\) 和一个后缀 \(j\),那么一定有 \(\text{LCP}(i,j) = \text{Longest}(u)\)。
此时贡献表达式的第一个参数已经确定下来,那么就是要分别在 \(u\) 的两个不同的子树内取一个后缀,使得这两个后缀的权值异或和最大。
对于一个点 \(u\),不妨设它只有两个子树(多个子树的情况,可以考虑将子树两两合并)。我们选择大小较小的一棵子树,穷举该子树内的每一个后缀,求出该后缀与另一个子树中的每个后缀的异或最大值来更新答案,使用 0/1 trie 合并维护即可。
注意到每枚举一棵子树,总子树的大小就至少乘 \(2\),这本质上是一个启发式合并。每个点至多被枚举了 \(\log_2 n\) 次。故总时间复杂度为 \(\mathcal{O}(n \log^2 n)\)。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 100100;
const int SIZE = N * 2;
int n;
char s[N];
int w[N];
vector<int> g[SIZE];
int rush;
int root[SIZE];
struct trie {
int ch[2];
} H[N * 32];
void insert(int p, int num) {
for (int i = 16; i >= 0; i --) {
int v = num >> i & 1;
if (!H[p].ch[v])
H[p].ch[v] = ++ rush;
p = H[p].ch[v];
}
}
int merge(int p, int q) {
if (!p || !q)
return p ^ q;
H[p].ch[0] = merge(H[p].ch[0], H[q].ch[0]);
H[p].ch[1] = merge(H[p].ch[1], H[q].ch[1]);
return p;
}
int ask(int p, int num) {
int ans = 0;
for (int i = 16; i >= 0; i --) {
int v = num >> i & 1;
if (H[p].ch[v ^ 1])
p = H[p].ch[v ^ 1], ans += 1 << i;
else
p = H[p].ch[v];
}
return ans;
}
int cT = 1, last = 1;
struct SAM {
int trans[26];
int link, len;
} t[SIZE];
void extend(int c, int val) {
int p = last,
Np = last = ++ cT;
g[Np].push_back(val);
insert(root[Np] = ++ rush, val);
t[Np].len = t[p].len + 1;
for (; p && t[p].trans[c] == 0; p = t[p].link)
t[p].trans[c] = Np;
if (!p)
t[Np].link = 1;
else {
int q = t[p].trans[c];
if (t[q].len == t[p].len + 1)
t[Np].link = q;
else {
int Nq = ++ cT;
t[Nq] = t[q], t[Nq].len = t[p].len + 1;
t[Np].link = t[q].link = Nq;
for (; p && t[p].trans[c] == q; p = t[p].link)
t[p].trans[c] = Nq;
}
}
}
int tot, head[SIZE], ver[SIZE], Next[SIZE];
void addedge(int u, int v) {
ver[++ tot] = v;
Next[tot] = head[u];
head[u] = tot;
}
int ans;
void dfs(int u) {
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
dfs(v);
if (g[u].size() < g[v].size())
swap(g[u], g[v]), swap(root[u], root[v]);
for (int j = 0; j < (int)g[v].size(); j ++) {
int num = g[v][j];
g[u].push_back(num);
ans = max(ans, t[u].len + ask(root[u], num));
}
root[u] = merge(root[u], root[v]);
}
}
int main() {
scanf("%d", &n);
scanf("%s", s + 1);
for (int i = 1; i <= n; i ++)
scanf("%d", &w[i]);
for (int i = n; i >= 1; i --)
extend(s[i] - 'a', w[i]);
for (int i = 2; i <= cT; i ++)
addedge(t[i].link, i);
dfs(1);
printf("%d\n", ans);
return 0;
}