[The 3rd Ucup. Stage 10 West Lake] Generated String

题意

维护一个字符串集合,支持动态插入、动态删除,以及查询同时具有前缀 \(s_1\) 与后缀 \(s_2\) 的串的个数。所有字符串(包括询问中的 \(s_1, s_2\))都用如下方式给出:先给定一个全局模板串 \(S\),每一个字符串都是 \(S\) 的若干个下标区间对应的子串拼接而成的,即给出若干个区间 \([l_1,r_1],[l_2,r_2],\dots,[l_k,r_k]\),描述字符串 \(S[l_1,r_1]+S[l_2,r_2]+\dots+S[l_k,r_k]\)。

做法

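小清新字符串,确信。如果字符串的给定方式不特殊的话,那么只需对所有串建一棵 trie,再对所有串的反串建一棵 trie:具有前缀 \(s_1\) 的串恰好落在第一棵 trie 中 \(s_1\) 对应节点的子树内,具有后缀 \(s_2\) 的串恰好落在第二棵 trie 中 \(s_2\) 的反串对应节点的子树内。于是每个串对应平面上的一个点(两棵树上的 dfs 序),询问就是一个矩形内的点数,动态二维数点一下就是答案了。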

现在字符串是模板串的若干个区间拼接成的,那么我们该如何建出前后缀关系构成的树呢?考虑不是像 trie 一样一个一个插入串,而是分治建树一下子把树都建好。

比如说我们要建一个大小为串的总数级别的压缩 trie,我们可以先求出所有串的 lcp,扣掉所有串长为 lcp 的前缀之后,按照第一个字符分类,每一类递归建树置成当前节点的儿子。这样每次至少会划开一次集合,一共只会划开不超过串的总数次。

可是现在,直接求出所有串的 lcp 的复杂度是不太靠谱的:求一次 lcp 需要遍历多个字符串片段,而随着参与比较的串增多,全局 lcp 会不断变短,之前已经遍历过的一些片段就无法计入势能了,怎么办呢?

考虑取出第一个片段最长的一个字符串,然后其它所有字符串跟这个字符串的第一个片段取 lcp,这样的话如果全局 lcp 比这个片段还长,直接砍掉这个片段长度的前缀,每个串至少会砍掉一个片段,片段减少量就是我们此次耗费的复杂度。否则我们就可以按照与这一个片段的 lcp 长度分类,每一类建一个点,从小到大连成一条链,接下来再按第一个字符分类给链下面建子树。

可以发现,上述过程只需要支持模板串上任意两个后缀的快速 lcp 查询就行了,后缀数组甚至二分+哈希都可以接受,因为复杂度瓶颈还是在动态二维数点。

代码经过 clang-format 不压行后仅有短短 7.9K。

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <map>
#include <vector>
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<pii> vpii;
// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before the number are skipped (so a leading '-'
// is ignored; every value this program reads is a length, count or position).
// Returns 0 if end of input is reached before a digit is found, instead of
// looping forever.
int read() {
    int c = getchar();  // int, not char: EOF must stay distinguishable
    while (c != EOF && (c < '0' || c > '9'))
        c = getchar();
    if (c == EOF)
        return 0;
    int x = 0;
    do
        x = x * 10 + (c - '0'), c = getchar();
    while (c >= '0' && c <= '9');
    return x;
}
// Array bound for the template string and for per-operation arrays.
const int N = 100003;
// Array bound for trie nodes (bel / sn).
const int T = 600003;
// n: template length, q: number of operations.
int n, q;
// Template string, 1-indexed as s[1..n]; s[n + 1] is used as a sentinel.
char s[N];
// Suffix array of the global 1-indexed template s[1..n] (prefix doubling with
// counting sort) plus a sparse table over the LCP (height) array.
//   sa[i]    : start position of the suffix with rank i (1-indexed)
//   rk[x]    : rank of the suffix starting at x; a permutation of 1..n once
//              getSA() returns
//   ht[0][i] : lcp(suffix sa[i - 1], suffix sa[i]) for 2 <= i <= n
//   ht[t][i] : min(ht[0][i .. i + 2^t - 1])
//   buc      : counting-sort buckets (buc[1..m] reset after each use)
//   od       : ranks from the previous doubling round
//   id       : suffixes ordered by second key for the current round
//   w        : current half-length; p: fill index / number of distinct ranks
namespace SA {
int buc[N], rk[N], sa[N], od[N], id[N], ht[17][N], w, p;
// True if suffixes x and y have equal keys this round: same rank on the first
// w characters and on the next w characters.
// NOTE: od[x + w] is only read once od[x] == od[y]. For x != y, that means
// neither suffix is shorter than w, so x + w <= n + 1 stays in bounds.
// od[n + 1] is never assigned, so (being a global) it reads as 0.
bool eq(int x, int y) {
    return od[x] == od[y] && od[x + w] == od[y + w];
}
// Builds sa[] and rk[] for s[1..n].
// m: upper bound for the initial keys (character codes); main passes 123,
//    just above 'z'. Assumes 1 <= s[i] <= m for every i.
void getSA(int m) {
    // Round 0: counting sort of the suffixes by their first character.
    for (int i = 1; i <= n; ++i)
        ++buc[rk[i] = s[i]];
    for (int i = 1; i <= m; ++i)
        buc[i] += buc[i - 1];
    for (int i = n; i; --i)
        sa[buc[rk[i]]--] = i;
    for (int i = 1; i <= m; ++i)
        buc[i] = 0;
    w = 1;
    p = 0;
    while (true) {
        // Order by second key: suffixes whose second half is empty
        // (i > n - w) come first, then sa[i] - w in the order of sa.
        for (int i = n; i > n - w; --i)
            id[++p] = i;
        for (int i = 1; i <= n; ++i)
            if (sa[i] > w)
                id[++p] = sa[i] - w;
        // Stable counting sort of id[] by first key; od keeps the old ranks.
        for (int i = 1; i <= n; ++i)
            ++buc[od[i] = rk[i]];
        for (int i = 1; i <= m; ++i)
            buc[i] += buc[i - 1];
        for (int i = n; i; --i)
            sa[buc[rk[id[i]]]--] = id[i];
        for (int i = 1; i <= m; ++i)
            buc[i] = 0;
        // Re-rank by the (first key, second key) pair.
        rk[sa[1]] = p = 1;
        for (int i = 2; i <= n; ++i) {
            if (!eq(sa[i], sa[i - 1]))
                ++p;
            rk[sa[i]] = p;
        }
        // All ranks distinct: the order is final.
        if (p == n)
            break;
        // Double the window; next round's keys are ranks in [1, p].
        w <<= 1, m = p, p = 0;
    }
}
// Fills ht[0] with Kasai's algorithm, then builds the sparse table ht[1..16].
// Writes s[n + 1] = '!' as a sentinel so the character comparison stops at
// the end of the string (assumes '!' does not occur in s[1..n]).
void getLCP() {
    s[n + 1] = '!';
    for (int i = 1, k = 0; i <= n; ++i) {
        // Walking suffixes in text order, the lcp with the predecessor in sa
        // order drops by at most 1 per step, so k is reused.
        if (k)
            --k;
        if (rk[i] == 1)
            continue;
        while (s[i + k] == s[sa[rk[i] - 1] + k])
            ++k;
        ht[0][rk[i]] = k;
    }
    for (int t = 1; t < 17; ++t)
        for (int i = 2; i + (1 << t) - 1 <= n; ++i)
            ht[t][i] = min(ht[t - 1][i], ht[t - 1][i + (1 << (t - 1))]);
}
// Length of the longest common prefix of the suffixes starting at x and y
// (1-indexed). For x != y, this is the minimum of ht[0] over ranks
// (min(rk[x], rk[y]), max(rk[x], rk[y])], answered with two overlapping
// sparse-table windows.
// NOTE(review): a position of n + 1 reads rk[n + 1], which is never assigned
// (0). The returned value is then not a real lcp and must be capped by the
// caller.
int lcp(int x, int y) {
    if (x == y)
        return n - x + 1;
    x = rk[x], y = rk[y];
    if (x > y)
        swap(x, y);
    int k = __lg(y - x);
    return min(ht[k][x + 1], ht[k][y - (1 << k) + 1]);
}
} // namespace SA
// Per-operation state, indexed by operation number i (1..q).
// v1[i]: interval list used by the trie currently being built; pieces are
//        stored last-to-first, so back() is the first unconsumed piece.
// v2[i]: the same string for the other direction (swapped into v1 for pass 2).
vpii v1[N];
vpii v2[N];
// Numbering of the current pass: id1 = preorder id of an inserted string,
// (lef1, rig1] = id range covered by a query's subtree, num1 = ids handed out.
int id1[N];
int lef1[N];
int rig1[N];
int num1;
// The same numbering for the other pass (filled by swapping after pass 1).
int id2[N];
int lef2[N];
int rig2[N];
int num2;
int del[N];  // argument of a '-' operation
int res[N];  // answer accumulator of a '?' operation
char op[N];  // operation character: '+', '-' or '?'
int cnt;     // number of trie nodes allocated so far (node 0 is the root)
vi bel[T];   // bel[u]: strings attached to trie node u
vi sn[T];    // sn[u]: children of trie node u
void build(vi, int);
// Partitions `vec` by the first remaining character of each string and calls
// build() on every character class ('a'..'z') under parent node `fa`.
// Every string in `vec` must be non-empty and start with a lowercase letter.
void split(vi vec, int fa) {
    if (!vec.empty()) {
        vi bucket[26];
        for (int idx : vec) {
            const int ch = s[v1[idx].back().fi] - 'a';
            bucket[ch].emplace_back(idx);
        }
        for (vi &group : bucket)
            build(group, fa);
    }
}
// Builds the compressed-trie subtree under node `fa` for the strings in `vec`.
//
// Each string x is stored in v1[x] as a stack of template intervals [l, r];
// back() is the first not-yet-consumed interval. Matched characters are
// consumed in place (intervals are shortened or popped). A string whose list
// becomes empty is attached to the node where it ends via bel[].
//
// Strategy: take the string `ps` whose first interval is longest. Compute
// every other string's lcp with that interval (capped at its length), group
// the strings by that lcp, and create one node per distinct lcp value, chained
// in increasing order. Each group's non-empty survivors are then split by
// their next character.
void build(vi vec, int fa) {
    for (int x : vec)
        assert(!v1[x].empty());
    if (vec.empty())
        return;
    if (vec.size() == 1u) {
        // A lone string becomes a leaf; nothing left to branch on.
        int p = ++cnt;
        sn[fa].emplace_back(p);
        bel[p].emplace_back(vec.back());
        return;
    }
    // Pivot: the string with the longest first interval.
    int mx = 0, ps = 0;
    for (int x : vec) {
        int len = v1[x].back().se - v1[x].back().fi + 1;
        if (len > mx)
            mx = len, ps = x;
    }
    // Group the other strings by their lcp with the pivot's first interval,
    // consuming the matched characters from each of them.
    map<int, vi> mp;
    for (int x : vec) {
        if (x == ps)
            continue;
        int len = 0;
        int cur = v1[ps].back().fi;  // current position inside the pivot interval
        while (!v1[x].empty()) {
            auto &[l, r] = v1[x].back();
            int tmp = min(SA::lcp(l, cur), v1[ps].back().se - cur + 1);
            if (tmp > r - l) {
                // The whole interval of x matched; continue with its next one.
                len += r - l + 1;
                cur += r - l + 1;
                v1[x].pop_back();
            } else {
                len += tmp;
                l += tmp;
                break;
            }
        }
        mp[len].emplace_back(x);
    }
    // The pivot belongs to the deepest group; consume that many characters.
    int mxlen = prev(mp.end())->fi;
    while (!v1[ps].empty()) {
        auto &[l, r] = v1[ps].back();
        if (mxlen > r - l) {
            mxlen -= r - l + 1;
            v1[ps].pop_back();
        } else {
            l += mxlen;
            break;
        }
    }
    prev(mp.end())->se.emplace_back(ps);
    // One chain node per distinct lcp length, in increasing order; each node
    // becomes the parent of the next. Iterate by reference: copying each
    // (key, vector) pair would duplicate every group.
    for (const auto &group : mp) {
        int p = ++cnt;
        sn[fa].emplace_back(p);
        vi tmp;
        for (int x : group.se) {
            if (v1[x].empty())
                bel[p].emplace_back(x);  // string ends exactly at this node
            else
                tmp.emplace_back(x);
        }
        fa = p;
        split(tmp, fa);
    }
}
// Preorder-numbers the inserted strings in the subtree of trie node u.
// Each '+' string attached to a node receives the next id (num1). Each '?'
// string attached to u records the range (lef1, rig1] of ids in u's subtree,
// including the strings attached to u itself.
void dfs(int u) {
    const int before = num1;
    for (int x : bel[u]) {
        if (op[x] == '?')
            lef1[x] = before;
        else if (op[x] == '+')
            id1[x] = ++num1;
    }
    for (int child : sn[u])
        dfs(child);
    for (int x : bel[u])
        if (op[x] == '?')
            rig1[x] = num1;
}
// Event for the offline 2-D counting in solve():
//   id == 0 : point update, add v at (x, y);
//   id != 0 : query corner for operation id, add v * prefix-count(x, y) to res[id].
// Ordering: by x; at equal x, updates (id 0) come before queries.
struct node {
    int v, x, y, id;
    friend bool operator<(const node a, const node b) {
        return a.x != b.x ? a.x < b.x : a.id < b.id;
    }
} o[N << 2];
// Fenwick tree over y positions 1..num2.
int tr[N];
// Fenwick point update: adds v at position x (expects 1 <= x).
void upd(int x, int v) {
    for (; x <= num2; x += x & -x)
        tr[x] += v;
}
// Fenwick prefix sum: total of all updates at positions 1..x.
int qry(int x) {
    int sum = 0;
    for (; x != 0; x &= x - 1)
        sum += tr[x];
    return sum;
}
// CDQ divide and conquer over operation indices [l, r].
// After both halves are solved recursively, every '+'/'-' in [l, mid] is
// applied to every '?' in [mid + 1, r], so a query only sees modifications
// that precede it. An inserted string is the point (id1, id2). A query counts
// points in (lef1, rig1] x (lef2, rig2], expanded into four signed prefix
// rectangles. These are answered by sweeping x with a Fenwick tree on y.
// Results accumulate into res[].
void solve(int l, int r) {
    // `>=` rather than `==`: an empty range (e.g. solve(1, 0) when q == 0)
    // must stop instead of recursing forever.
    if (l >= r)
        return;
    int mid = (l + r) >> 1;
    solve(l, mid);
    solve(mid + 1, r);
    int ed = 0;
    // Left half: point updates (id == 0). A deletion removes the point of the
    // insertion it refers to.
    for (int i = l; i <= mid; ++i) {
        if (op[i] == '+')
            o[++ed] = (node){1, id1[i], id2[i], 0};
        if (op[i] == '-')
            o[++ed] = (node){-1, id1[del[i]], id2[del[i]], 0};
    }
    // Right half: the four signed prefix-rectangle corners of each query.
    for (int i = r; i > mid; --i)
        if (op[i] == '?') {
            o[++ed] = (node){1, rig1[i], rig2[i], i};
            o[++ed] = (node){1, lef1[i], lef2[i], i};
            o[++ed] = (node){-1, lef1[i], rig2[i], i};
            o[++ed] = (node){-1, rig1[i], lef2[i], i};
        }
    // Sweep by x; at equal x, updates precede queries (id 0 sorts first).
    sort(o + 1, o + ed + 1);
    for (int i = 1; i <= ed; ++i) {
        if (o[i].id)
            res[o[i].id] += qry(o[i].y) * o[i].v;
        else
            upd(o[i].y, o[i].v);
    }
    // Roll back this level's Fenwick updates.
    for (int i = 1; i <= ed; ++i)
        if (!o[i].id)
            upd(o[i].y, -o[i].v);
}
// Reads the template and all operations. Builds a prefix trie (pass 1, on s)
// and a suffix trie (pass 2, on the reversed s), then answers every '?'
// offline with CDQ divide and conquer (solve()).
//
// Input as parsed below: n q, then the template (n non-space characters),
// then q operations:
//   + k  l1 r1 ... lk rk                 insert a concatenation of s[l..r]
//   - j                                  delete; presumably j is the index of
//                                        an earlier '+' (solve() reads
//                                        id1[del[i]]) - TODO confirm
//   ? k1 <k1 pairs>  k2 <k2 pairs>       prefix s1, then suffix s2
// NOTE(review): cc is a char, so EOF is not detected; well-formed input is
// assumed.
int main() {
    n = read();
    q = read();
    char cc = getchar();
    while (isspace(cc))
        cc = getchar();
    for (int i = 1; i <= n; ++i)
        s[i] = cc, cc = getchar();
    // Suffix array + LCP of the forward template for the prefix-trie pass.
    SA::getSA(123);
    SA::getLCP();
    vi init;
    for (int i = 1; i <= q; ++i) {
        cc = getchar();
        while (isspace(cc))
            cc = getchar();
        op[i] = cc;
        if (cc == '+') {
            int len = read();
            v1[i].resize(len);
            v2[i].resize(len);
            for (int t = 0; t < len; ++t) {
                int l = read(), r = read();
                // v1: forward coordinates, stored last-to-first so back()
                // is the first piece.
                v1[i][len - 1 - t].fi = l;
                v1[i][len - 1 - t].se = r;
                // v2: mirrored into reversed-template coordinates, stored in
                // input order, so back() is the last piece, i.e. the first
                // piece of the reversed string.
                v2[i][t].fi = n - r + 1;
                v2[i][t].se = n - l + 1;
            }
            init.emplace_back(i);
        }
        if (cc == '-')
            del[i] = read();
        if (cc == '?') {
            // Prefix s1: same layout as v1 above (back() = first piece).
            int len1 = read();
            v1[i].resize(len1);
            for (int t = len1 - 1; ~t; --t) {
                v1[i][t].fi = read();
                v1[i][t].se = read();
            }
            // Suffix s2: same layout as v2 above. The first read() is l
            // (mirrored to .se), the second is r (mirrored to .fi).
            int len2 = read();
            v2[i].resize(len2);
            for (int t = 0; t < len2; ++t) {
                v2[i][t].se = n - read() + 1;
                v2[i][t].fi = n - read() + 1;
            }
            init.emplace_back(i);
        }
    }
    // Pass 1: prefix trie; dfs fills id1 / (lef1, rig1] / num1.
    split(init, 0);
    dfs(0);
    for (int i = 0; i <= cnt; ++i)
        bel[i].clear(), sn[i].clear();
    cnt = 0;
    // Move pass-1 numbering into the *2 arrays and the mirrored intervals into
    // v1, so pass 2 reuses the same build/dfs code.
    for (int i = 1; i <= q; ++i) {
        swap(id1[i], id2[i]);
        swap(lef1[i], lef2[i]);
        swap(rig1[i], rig2[i]);
        v1[i].swap(v2[i]);
    }
    swap(num1, num2);
    // Pass 2: suffix trie over the reversed template.
    // NOTE(review): pass 1 starts with split() and pass 2 with build(); both
    // appear to yield a valid tree.
    reverse(s + 1, s + n + 1);
    SA::getSA(123);
    SA::getLCP();
    build(init, 0);
    dfs(0);
    // Answer all queries offline, then print them in input order.
    solve(1, q);
    for (int i = 1; i <= q; ++i)
        if (op[i] == '?')
            printf("%d\n", res[i]);
    return 0;
}
posted @ 2024-10-11 21:27  yyyyxh  阅读(94)  评论(0编辑  收藏  举报