loj3326.「SNOI2020」字符串(sa + 并查集)
loj3326.「SNOI2020」字符串
\(Description\)
给定两个长度为 \(n\) 的小写字符串 \(a, \ b\),求出他们所有长为 \(k\) 的子串,分别组成集合 \(A, \ B\),每次可以修改 \(A\) 中一个元素的后缀,费用为后缀的长度,求将 \(A\) 修改成 \(B\) 的最小费用之和。
\(Data \ Constraint\)
\(1 \leq k, \ n \leq 1.5 \times 10^5\)。
考点
\(sa\),并查集,广义 \(sam\)。
\(Solution\)
本题有多种解法,本人写的是 \(sa\) + 并查集做法(比较慢唔)。
答案可以转化为 \(k (n - k + 1) - \sum{\text{匹配元素的}lcp}\),
有一个比较显然的贪心是每次我们从 \(A, \ B\) 中选出一对 \(lcp\) 最大的元素配对。
证明也很简单:考虑我们当前选了一对 \(lcp\) 最大的元素,显然我们无法找到另外一对元素使得这两对元素交叉匹配的 \(lcp\) 之和更大。
\(Method1\)
将 \(a, \ b\) 拼接起来建 \(sa\),从大到小枚举 \(height\),并查集维护块及块中未配对后缀的个数,每次合并两个块,若两块中未配对的后缀来自不同字符串则计算对答案的贡献,然后合并为配对的后缀。
\(Method2\)
直接建出广义 \(sam\),简单树上 \(dp\) 计算即可。
\(Code(sa)\)
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 600000
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
#define Mec(a, b) memcpy(a, b, sizeof b)
int sa[N + 1], rk[N + 1], oldrk[N + 1], buc[N + 1], px[N + 1], id[N + 1], ht[N + 1], c[N + 1];
char a[N + 1], b[N + 1];
#define ll long long
int fa[N + 1];
ll sz[N + 1];
struct Arr { int x, y; } d[N + 1];
int n, m, m1 = 26;
void Sort() {
fill(buc, buc + N, 0);
fo(i, 1, n) ++ buc[ px[i] = rk[id[i]] ];
fo(i, 1, m1) buc[i] += buc[i - 1];
fd(i, n, 1) sa[ buc[px[i]] -- ] = id[i];
}
bool Cmp(int x, int y, int w) { return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w]; }
void Get_sa() {
fo(i, 1, n) c[i] = a[i] - 'a' + 1;
fo(i, 1, n) c[i + n] = b[i] - 'a' + 1;
n = (n << 1);
fo(i, 1, n) rk[ id[i] = i ] = c[i];
Sort();
for (int w = 1, p = 0; w <= n; w <<= 1) {
p = 0;
fo(i, n - w + 1, n) id[ ++ p ] = i;
fo(i, 1, n) if (sa[i] > w) id[ ++ p ] = sa[i] - w;
Sort();
Mec(oldrk, rk), p = 0;
fo(i, 1, n)
rk[sa[i]] = Cmp(sa[i], sa[i - 1], w) ? p : ++ p;
m1 = p;
}
int k = 0;
fo(i, 1, n) {
if (k) k --;
while (c[i + k] == c[sa[rk[i] - 1] + k]) ++ k;
ht[rk[i]] = k;
}
}
bool cmp(Arr a, Arr b) { return a.x < b.x; }
int Getf(int u) { return u == fa[u] ? u : fa[u] = Getf(fa[u]); }
bool Pd(int x) { return (x <= (n >> 1) && (n >> 1) - x + 1 >= m) || (x > (n >> 1) && n - x + 1 >= m); }
int Abs(int x) { return x < 0 ? -x : x; }
int main() {
scanf("%d %d\n", &n, &m);
scanf("%s\n%s\n", a + 1, b + 1);
Get_sa();
fo(i, 1, n) fa[i] = i, sz[i] = Pd(sa[i]) ? (sa[i] <= (n >> 1) ? 1 : -1) : 0;
int tot = 0;
fo(i, 2, n)
d[ ++ tot ] = (Arr) { ht[i], i };
sort(d + 1, d + 1 + tot, cmp);
ll ans = 0;
fd(i, (n >> 1), 1) {
while (d[tot].x == i) {
int u = Getf(d[tot].y), v = Getf(d[tot].y - 1);
if (sz[u] * sz[v] < 0)
ans += min(Abs(sz[u]), Abs(sz[v])) * min(i, m);
fa[v] = u;
sz[u] += sz[v];
-- tot;
}
}
printf("%lld\n", 1ll * m * ((n >> 1) - m + 1) - ans);
return 0;
}