【LOJ #6198】谢特

LOJ #6198

Description

给出一个长度为 \(n\) 的,仅包含小写字母的字符串 \(s\)

定义这个字符串以第 \(i\) 个字符开头的后缀为后缀 \(i\)(编号从 \(1\) 开始),每个后缀 \(i\) 都有一个权值 \(w_i\),同时定义两个后缀 \(i, j (i \neq j)\) 的贡献为它们的最长公共前缀长度加上它们权值的异或和,也就是 \(\text{LCP}(i, j) + (w_i \ \text{xor} \ w_j)\)

而你的任务就是,求出这个字符串的所有后缀两两之间贡献的最大值。

数据范围:\(1 \leq n \leq 10^5\)\(0 \leq w_i < n\)
时空限制:\(1000 \ \mathrm{ms} / 512 \ \mathrm{MiB}\)

Solution

算法一:SA + 可持久化 0/1 trie

对字符串 \(s\) 做一遍 SA,再将 \(\text{height}\) 数组求出。

为了方便求解,我们求出一个新的后缀权值序列 \(\{ w'_i \}\),满足 \(w'_i = w_{\text{SA}_i}\)

那么,后缀 \(\text{SA}_i, \text{SA}_j(i < j)\) 的贡献即为:

\[\min\limits_{i < k \leq j} \left\{ \text{height}_k \right\} + \left(w'_i \ \text{xor} \ w'_j\right) \]

那么现在就是要求出所有满足 \(1 \leq i < j \leq n\) 的数对 \((i, j)\) 所计算出的上式最大值。

考虑分治。定义分治函数 \(\text{solve}(l, r)\),表示计算 \(l \leq i < j \leq r\) 时的答案。

取分治中心 \(\text{mid}\) 为区间 \((l, r]\)\(\text{height}\) 值最小的一点,至于计算 \(\text{mid}\) 可以使用 ST 表。

那么现在只需要计算出所有满足 \(l \leq i < \text{mid} \leq j \leq r\) 的数对 \((i, j)\) 的贡献最大值。随后调用 \(\text{solve}(l, \text{mid} - 1)\)\(\text{solve}(\text{mid}, r)\) 即可。

对于所有满足 \(l \leq i < \text{mid} \leq j \leq r\) 的数对 \((i, j)\),一定有 \(\min\limits_{i < k \leq j} \left\{ \text{height}_k \right\} = \text{height}_\text{mid}\)
此时贡献表达式的第一个参数已经确定下来,现在就是要求 \(\left(w'_i \ \text{xor} \ w'_j\right)\) 的最大值。

对于分治中心 \(\text{mid}\) 分出来的两个区间 \([l, \text{mid})\)\([\text{mid}, r]\)。我们选择长度较小的一段区间,穷举该区间里的每一个点,求出该点与另一段区间中的每个点的异或最大值来更新答案,使用可持久化 0/1 trie 维护即可。

注意到每枚举一个区间,总区间的长度就至少乘 \(2\),这本质上是一个启发式合并。每个点至多被枚举了 \(\log_2 n\) 次。故总时间复杂度为 \(\mathcal{O}(n \log^2 n)\)

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 100100;

int n, m = 256;
char s[N];

int temp[N], w[N];

int SA[N], rk[N];
int cnt[N], id[N], oldrk[N * 2], px[N];

bool cmp(int x, int y, int w) {
	return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w];
}

int height[N];

int logx[N];
int f[N][17];

int calc(int l, int r) {
	int k = logx[r - l + 1];

	if (height[f[l][k]] <= height[f[r - (1 << k) + 1][k]])
		return f[l][k];
	else
		return f[r - (1 << k) + 1][k];
}

int cT, root[N];
struct trie {
	int ch[2];
	int cnt;
} t[N * 32];

void insert(int &p, int now, int dep, int val) {
	p = ++ cT;
	t[p] = t[now];
	t[p].cnt ++;
	if (dep < 0) return;
	int v = val >> dep & 1;
	insert(t[p].ch[v], t[now].ch[v], dep - 1, val);
}

int ask(int p, int q, int dep, int val) {
	if (dep < 0) return 0;
	int v = val >> dep & 1;
	int cnt = t[t[q].ch[v ^ 1]].cnt - t[t[p].ch[v ^ 1]].cnt;
	if (cnt)
		return ask(t[p].ch[v ^ 1], t[q].ch[v ^ 1], dep - 1, val) + (1 << dep);
	else
		return ask(t[p].ch[v], t[q].ch[v], dep - 1, val);
}

int ans;

void solve(int l, int r) {
	if (l == r) return;

	int mid = calc(l + 1, r);

	if (mid - l <= r - mid + 1) {
		for (int i = l; i < mid; i ++)
			ans = max(ans, height[mid] + ask(root[mid - 1], root[r], 17, w[i]));
	} else {
		for (int i = mid; i <= r; i ++)
			ans = max(ans, height[mid] + ask(root[l - 1], root[mid - 1], 17, w[i]));
	}

	solve(l, mid - 1), solve(mid, r);
}

int main() {
	scanf("%d", &n);
	scanf("%s", s + 1);

	for (int i = 1; i <= n; i ++)
		scanf("%d", &temp[i]);

	for (int i = 1; i <= n; i ++) rk[i] = s[i];
	for (int i = 1; i <= n; i ++) cnt[rk[i]] ++;
	for (int i = 1; i <= m; i ++) cnt[i] += cnt[i - 1];
	for (int i = n; i >= 1; i --) SA[cnt[rk[i]] --] = i; 

	for (int w = 1, p = 0; w < n; w <<= 1, m = p) {
		p = 0;
		for (int i = n - w + 1; i <= n; i ++) id[++ p] = i;
		for (int i = 1; i <= n; i ++)
			if (SA[i] > w) id[++ p] = SA[i] - w;

		for (int i = 0; i <= m; i ++) cnt[i] = 0;
		for (int i = 1; i <= n; i ++) cnt[px[i] = rk[id[i]]] ++;
		for (int i = 1; i <= m; i ++) cnt[i] += cnt[i - 1];
		for (int i = n; i >= 1; i --) SA[cnt[px[i]] --] = id[i];

		p = 0;
		for (int i = 1; i <= n; i ++) oldrk[i] = rk[i];
		for (int i = 1; i <= n; i ++)
			rk[SA[i]] = cmp(SA[i - 1], SA[i], w) ? p : ++ p;
	}

	for (int i = 1, H = 0; i <= n; i ++) {
		if (H) H --;
		while (s[i + H] == s[SA[rk[i] - 1] + H]) H ++;
		height[rk[i]] = H;
	}

	logx[0] = -1;
	for (int i = 1; i <= n; i ++)
		logx[i] = logx[i >> 1] + 1;

	for (int i = 1; i <= n; i ++)
		f[i][0] = i;

	for (int j = 1; j <= 16; j ++)
		for (int i = 1; i + (1 << j) - 1 <= n; i ++) {
			if (height[f[i][j - 1]] <= height[f[i + (1 << (j - 1))][j - 1]])
				f[i][j] = f[i][j - 1];
			else
				f[i][j] = f[i + (1 << (j - 1))][j - 1];
		}

	for (int i = 1; i <= n; i ++)
		w[i] = temp[SA[i]];

	for (int i = 1; i <= n; i ++)
		insert(root[i], root[i - 1], 17, w[i]);

	solve(1, n);

	printf("%d\n", ans);

	return 0;
}

算法二:SAM + 0/1 trie 合并

对字符串 \(s\)反串建出后缀自动机,并求出 parent 树。 其实这个 parent 树就是原串的后缀树

因为对于后缀树上任意两点 \(x, y\),分别从 \(x, y\) 代表的子串集合中选出一个子串。那么这两个子串的 \(\text{LCP}\) 即为 \(\text{lca}(x, y)\) 代表的子串中最长的串。

所以可以考虑在这个 parent 树上走,设当前走到的点为 \(u\)。此时在 \(u\)两个不同的子树内取一个后缀 \(i\) 和一个后缀 \(j\),那么一定有 \(\text{LCP}(i,j) = \text{Longest}(u)\)

此时贡献表达式的第一个参数已经确定下来,那么就是要分别在 \(u\) 的两个不同的子树内取一个后缀,使得这两个后缀的权值异或和最大。

对于一个点 \(u\),不妨设它只有两个子树(多个子树的情况,可以考虑将子树两两合并)。我们选择大小较小的一棵子树,穷举该子树内的每一个后缀,求出该后缀与另一个子树中的每个后缀的异或最大值来更新答案,使用 0/1 trie 合并维护即可。

注意到每枚举一棵子树,总子树的大小就至少乘 \(2\),这本质上是一个启发式合并。每个点至多被枚举了 \(\log_2 n\) 次。故总时间复杂度为 \(\mathcal{O}(n \log^2 n)\)

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>

using namespace std;

const int N = 100100;
const int SIZE = N * 2;

int n;

char s[N];
int w[N];

vector<int> g[SIZE];

int rush;
int root[SIZE];
struct trie {
    int ch[2];
} H[N * 32];

void insert(int p, int num) {
    for (int i = 16; i >= 0; i --) {
        int v = num >> i & 1;

        if (!H[p].ch[v])
            H[p].ch[v] = ++ rush;

        p = H[p].ch[v];
    }
}

int merge(int p, int q) {
    if (!p || !q)
        return p ^ q;

    H[p].ch[0] = merge(H[p].ch[0], H[q].ch[0]);
    H[p].ch[1] = merge(H[p].ch[1], H[q].ch[1]);
    return p;
}

int ask(int p, int num) {
    int ans = 0;

    for (int i = 16; i >= 0; i --) {
        int v = num >> i & 1;

        if (H[p].ch[v ^ 1])
            p = H[p].ch[v ^ 1], ans += 1 << i;
        else
            p = H[p].ch[v];
    }

    return ans;
}

int cT = 1, last = 1;
struct SAM {
    int trans[26];
    int link, len;
} t[SIZE];

void extend(int c, int val) {
    int p = last,
        Np = last = ++ cT;

    g[Np].push_back(val);
    insert(root[Np] = ++ rush, val);

    t[Np].len = t[p].len + 1;

    for (; p && t[p].trans[c] == 0; p = t[p].link)
        t[p].trans[c] = Np;

    if (!p)
        t[Np].link = 1;
    else {
        int q = t[p].trans[c];

        if (t[q].len == t[p].len + 1)
            t[Np].link = q;
        else {
            int Nq = ++ cT;

            t[Nq] = t[q], t[Nq].len = t[p].len + 1;

            t[Np].link = t[q].link = Nq;

            for (; p && t[p].trans[c] == q; p = t[p].link)
                t[p].trans[c] = Nq;
        }
    }
}

int tot, head[SIZE], ver[SIZE], Next[SIZE];

void addedge(int u, int v) {
    ver[++ tot] = v;
    Next[tot] = head[u];
    head[u] = tot;
}

int ans;

void dfs(int u) {
    for (int i = head[u]; i; i = Next[i]) {
        int v = ver[i];

        dfs(v);

        if (g[u].size() < g[v].size())
            swap(g[u], g[v]), swap(root[u], root[v]);

        for (int j = 0; j < (int)g[v].size(); j ++) {
            int num = g[v][j];
            g[u].push_back(num);
            ans = max(ans, t[u].len + ask(root[u], num));
        }

        root[u] = merge(root[u], root[v]);
    }
}

int main() {
    scanf("%d", &n);

    scanf("%s", s + 1);

    for (int i = 1; i <= n; i ++)
        scanf("%d", &w[i]);

    for (int i = n; i >= 1; i --)
        extend(s[i] - 'a', w[i]);

    for (int i = 2; i <= cT; i ++)
        addedge(t[i].link, i);

    dfs(1);

    printf("%d\n", ans);

    return 0;
}
posted @ 2021-04-05 09:57  Calculatelove  阅读(119)  评论(0编辑  收藏  举报