loj3326.「SNOI2020」字符串(sa + 并查集)

loj3326.「SNOI2020」字符串

\(Description\)

给定两个长度为 \(n\) 的小写字符串 \(a, \ b\),求出他们所有长为 \(k\) 的子串,分别组成集合 \(A, \ B\),每次可以修改 \(A\) 中一个元素的后缀,费用为后缀的长度,求将 \(A\) 修改成 \(B\) 的最小费用之和。

\(Data \ Constraint\)

\(1 \leq k, \ n \leq 1.5 \times 10^5\)

考点

\(sa\),并查集,广义 \(sam\)

\(Solution\)

本题有多种解法,本人写的是 \(sa\) + 并查集做法(比较慢唔)。

答案可以转化为 \(k (n - k + 1) - \sum{\text{匹配元素的}lcp}\)

有一个比较显然的贪心是每次我们从 \(A, \ B\) 中选出一对 \(lcp\) 最大的元素配对。

证明也很简单:考虑我们当前选了一对 \(lcp\) 最大的元素,显然我们无法找到另外一对元素使得这两对元素交叉匹配的 \(lcp\) 之和更大。

\(Method1\)

\(a, \ b\) 拼接起来建 \(sa\),从大到小枚举 \(height\),并查集维护块及块中未配对后缀的个数,每次合并两个块,若两块中未配对的后缀来自不同字符串则计算对答案的贡献,然后合并为配对的后缀。

\(Method2\)

直接建出广义 \(sam\),简单树上 \(dp\) 计算即可。

\(Code(sa)\)

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

#define N 600000

#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
#define Mec(a, b) memcpy(a, b, sizeof b)

int sa[N + 1], rk[N + 1], oldrk[N + 1], buc[N + 1], px[N + 1], id[N + 1], ht[N + 1], c[N + 1];

char a[N + 1], b[N + 1];

#define ll long long

int fa[N + 1];

ll sz[N + 1];

struct Arr { int x, y; } d[N + 1];

int n, m, m1 = 26;

void Sort() {
    fill(buc, buc + N, 0);
    fo(i, 1, n) ++ buc[ px[i] = rk[id[i]] ];
    fo(i, 1, m1) buc[i] += buc[i - 1];
    fd(i, n, 1) sa[ buc[px[i]] -- ] = id[i];
}

bool Cmp(int x, int y, int w) { return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w]; }

void Get_sa() {
    fo(i, 1, n) c[i] = a[i] - 'a' + 1;
    fo(i, 1, n) c[i + n] = b[i] - 'a' + 1;
    n = (n << 1);
    fo(i, 1, n) rk[ id[i] = i ] = c[i];
    Sort();
    for (int w = 1, p = 0; w <= n; w <<= 1) {
        p = 0;
        fo(i, n - w + 1, n) id[ ++ p ] = i;
        fo(i, 1, n) if (sa[i] > w) id[ ++ p ] = sa[i] - w;
        Sort();
        Mec(oldrk, rk), p = 0;
        fo(i, 1, n)
            rk[sa[i]] = Cmp(sa[i], sa[i - 1], w) ? p : ++ p;
        m1 = p;
    }
    int k = 0;
    fo(i, 1, n) {
        if (k) k --;
        while (c[i + k] == c[sa[rk[i] - 1] + k]) ++ k;
        ht[rk[i]] = k;
    }
}

bool cmp(Arr a, Arr b) { return a.x < b.x; }

int Getf(int u) { return u == fa[u] ? u : fa[u] = Getf(fa[u]); }

bool Pd(int x) { return (x <= (n >> 1) && (n >> 1) - x + 1 >= m) || (x > (n >> 1) && n - x + 1 >= m); } 

int Abs(int x) { return x < 0 ? -x : x; }

int main() {
    scanf("%d %d\n", &n, &m);
    scanf("%s\n%s\n", a + 1, b + 1);

    Get_sa();

    fo(i, 1, n) fa[i] = i, sz[i] = Pd(sa[i]) ? (sa[i] <= (n >> 1) ? 1 : -1) : 0;
    int tot = 0;
    fo(i, 2, n)
        d[ ++ tot ] = (Arr) { ht[i], i };
    sort(d + 1, d + 1 + tot, cmp);
    ll ans = 0;
    fd(i, (n >> 1), 1) {
        while (d[tot].x == i) {
            int u = Getf(d[tot].y), v = Getf(d[tot].y - 1);
            if (sz[u] * sz[v] < 0)
                ans += min(Abs(sz[u]), Abs(sz[v])) * min(i, m);
            fa[v] = u;
            sz[u] += sz[v];
            -- tot;
        }
    }
    printf("%lld\n", 1ll * m * ((n >> 1) - m + 1) - ans);

    return 0;
}

posted @ 2020-12-24 20:30  buzzhou  阅读(102)  评论(0编辑  收藏  举报