Loading

【题解】P4482 [BJWC2018]Border 的四种求法

思路

SAM + 树剖。

好仙的题啊,做了一天。

\(\operatorname{lcs}(i, j)\) 表示长度为 \(i, j\) 的前缀的最长公共后缀长度,则题目中的 border 可以等价转化成:求最大且满足:

  1. \(l \leq p \leq r\)

  2. \(\operatorname{lcs}(p, r) \geq p - l\)

其中 \(\operatorname{lcs}(p, r)\) 正是它们对应结点在后缀树上 LCA 的 \(len\) 值。

再加上字符串区间查询,可以想象到有一个高妙的 SAM 做法。

首先考虑一下暴力做法。

假设我们把问题直接暴力挂在从根到 SAM 中结点的路径上。

从其中一个结点一直向上跳枚举 LCA,然后找到与该结点 LCA 为当前点的所有结点,尝试更新答案。

当我们跳到一个结点时,相当于确定了 \(\operatorname{lcs}\) 的长度,所以对于当前点,我们只需要查询最大的 \(p \leq l + len - 1\) 的合法的 \(p\).

这个可以线段树合并直接维护。

然后接下来考虑优化上面那个做法。

首先确定问题转化成:

每次假设已经确定树上的一个结点 \(u\) 和常数 \(x\),其中 \(x = l + len - 1\)。现在要求出所有满足 \(\operatorname{lca}(u, v) = k\) 并且 \(v \leq l + len_k - 1\) 的编号最大的 \(v\),其中 \(k\) 是当前枚举的 LCA.

这里有一个比较仙的想法:把询问划分到 \(O(\log n)\) 条重链上。

换言之对于每条重链上的每一个点,考虑它的子树对它的贡献。

假设从结点 \(u\) 进入当前重链,重链从浅到深依次为 \(p_1, ..., p_m\).

对于比 \(u\) 浅的部分,它们和 \(u\) 的 LCA 为其自身;对于比 \(u\) 深的部分,它们和 \(u\) 的 LCA 为 \(u\).

可以预处理出 \(dis(i)\) 表示点 \(i\) 与当前跳到的结点在树上 LCA 处的 \(len\) 值。

假设现在要考虑结点 \(u\) 处的询问,\(u\)\(p\) 中的下标为 \(k\). 令 \(subt(u)\) 为结点 \(u\) 的所有轻子树,则答案可以划分成两部分贡献:

  1. \(\max_{i = 1}^{k - 1} \max\limits_{v \in subt(p_i)} [l \leq v \leq r] [v \leq l + dis(v) - 1]\)

  2. \(\max_{i = k}^{m} \max\limits_{v \in subt(p_i)} [l \leq v \leq r] [v \leq l + dis(u) - 1]\)

上面的划分实际上等价于划分成:

  1. 重链链首的子树中除点 \(u\) 的重子树外的部分

  2. 结点 \(u\) 的重子树

可以考虑对整个重链上下跑两次扫描线再线段树合并维护。

对于第一个限制可以直接线段树区间查询。

对于第二个限制可以考虑把所有 \(v\) 减去 \(dis(u)\) 再扔进线段树。

\(O(\log n)\) 条重链的低端进入,每条重链暴力遍历的复杂度是 \(O(n \log n)\),所以时间复杂度是 \(O(n \log^2 n)\)

注意到询问只需要挂在重链的最底端即可,所以空间复杂度是 \(O(n \log n)\)

代码

#include <cstdio>
#include <cmath>
#include <cstring>
// #include <vector>
#include <algorithm>
// using namespace std;

namespace IO
{
    //by cyffff
	int len = 0;
	char ibuf[(1 << 20) + 1], *iS, *iT, out[(1 << 26) + 1];
	#define gh() (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, (1 << 20) + 1, stdin), (iS == iT ? EOF : *iS++) : *iS++)
	#define reg register

	inline int read()
    {
		reg char ch = gh();
		reg int x = 0;
		reg char t = 0;
		while (ch < '0' || ch > '9') t |= ch == '-', ch = gh();
		while (ch >= '0' && ch <= '9') x = x * 10 + (ch ^ 48), ch = gh();
		return t ? -x : x;
	}

	inline void putc(char ch) { out[len++] = ch; }

	template<class T>

	inline void write(T x)
    {
		if (x < 0) putc('-'), x = -x;
		if (x > 9) write(x / 10);
		out[len++] = x % 10 + 48;
	}

	inline void flush()
    {
		fwrite(out, 1, len, stdout);
		len = 0;
	}
}
using IO::read;
using IO::write;
using IO::flush;
using IO::putc;

#define reg
typedef long long ll;

const int maxn = 3e5 + 1;
const int maxm = 3e5 + 1;
const int sz = 2e3 + 1;

struct Modify
{
    int t, pos, val;
} q1[sz];

struct Query
{
    int t, l, r, x, idx;
} q2[sz];

struct node
{
    int idx, res, t[4];
    bool flag;

    inline void clear() { t[0] = t[1] = t[2] = t[3] = flag = 0; }
} stk[sz], tmp;

struct Edge
{
    int to, nxt;
} edge[maxn];

int n, m, sqn, sqm, tot, c1, c2, top, ecnt;
int st[sz], ed[sz];
int head[maxn], a[maxn], bel[maxn], pos[maxn];
bool vis[maxn], chg[maxn], used[maxn], seq[maxn];
ll gt[maxn];
ll ans[sz], sum[sz];
// vector<int> idx[maxn];

inline int min(reg const int &a, reg const int &b) { return (a <= b ? a : b); }
inline bool cmp(const Query& x, const Query& y) { return (x.x < y.x); }

inline void modify(reg const int& p, reg bool flag)
{
    pos[p] = p, seq[p] = true, tmp.idx = p;
    reg const int lst = p - 1, nxt = p + 1, tmp_p = pos[lst];
    reg const bool flag1 = (seq[lst] && (st[bel[p]] != p)), flag2 = (seq[nxt] && (ed[bel[p]] != p));
    if ((!flag1) && (!flag2)) tmp.clear(), tmp.res = 1;
    else
    {
        tmp.flag = true;
        if (flag1 && flag2)
        {
            tmp.res = (nxt - pos[lst]) * (pos[nxt] - lst);
            tmp.t[0] = pos[lst], tmp.t[1] = pos[pos[lst]], pos[pos[lst]] = pos[nxt];
            tmp.t[2] = pos[nxt], tmp.t[3] = pos[pos[nxt]], pos[pos[nxt]] = tmp_p;
        }
        else if(flag1)
        {
            tmp.res = nxt - pos[lst];
            tmp.t[0] = p, tmp.t[1] = pos[p], pos[p] = pos[lst];
            tmp.t[2] = pos[lst], tmp.t[3] = pos[pos[lst]], pos[pos[lst]] = p;
        }
        else
        {
            tmp.res = pos[nxt] - lst;
            tmp.t[0] = p, tmp.t[1] = pos[p], pos[p] = pos[nxt];
            tmp.t[2] = pos[nxt], tmp.t[3] = pos[pos[nxt]], pos[pos[nxt]] = p;
        }
    }
    sum[bel[p]] += tmp.res;
    if (flag) stk[++top] = tmp;
}

inline ll query(const int& l, const int& r)
{
    reg ll ans = 0;
    if (bel[l] == bel[r])
    {
        reg int cnt = 0;
        for (reg int i = l; i <= r; i++)
            if (seq[i]) cnt++;
            else ans += gt[cnt], cnt = 0;
        return ans + gt[cnt];
    }
    reg int cnt1 = 0, cnt2 = 0;
    for (reg int i = l; i <= ed[bel[l]]; i++)
        if (seq[i]) cnt1++;
        else ans += gt[cnt1], cnt1 = 0;
    for (reg int i = r; i >= st[bel[r]]; i--)
        if (seq[i]) cnt2++;
        else ans += gt[cnt2], cnt2 = 0;
    reg int res = cnt1;
    for (reg int i = bel[l] + 1; i <= bel[r] - 1; i++)
    {
        if (pos[st[i]] == ed[i]) res += ed[i] - st[i] + 1;
        else
        {
            if (seq[st[i]]) res += pos[st[i]] - st[i] + 1, ans -= gt[pos[st[i]] - st[i] + 1];
            ans += gt[res] + sum[i], res = 0;
            if (seq[ed[i]]) res += ed[i] - pos[ed[i]] + 1, ans -= gt[ed[i] - pos[ed[i]] + 1];
        }
    }
    return ans + gt[res + cnt2];
}

inline void add_edge(const int& u, const int& v)
{
    edge[++ecnt] = (Edge){v, head[u]};
    head[u] = ecnt;
}

inline void solve()
{
    memset(seq, false, sizeof(seq));
    memset(pos, 0, sizeof(pos));
    memset(sum, 0, sizeof(sum));
    for (reg int i = 1; i <= c1; i++) chg[q1[i].pos] = true;
    for (reg int i = 1; i <= n; i++)
        if (!chg[i]) add_edge(a[i], i);
    std::sort(q2 + 1, q2 + c2 + 1, cmp);
    reg int lim = 1;
    for (reg int i = 1; i <= c2; i++)
    {
        while (lim <= q2[i].x)
        {
            for (reg int& j = head[lim]; j; j = edge[j].nxt) modify(edge[j].to, 0);
            lim++;
        }
        for (reg int j = c1; j >= 1; j--)
            if ((q1[j].t < q2[i].t) && (!used[q1[j].pos]))
            {
                used[q1[j].pos] = true;
                if (q1[j].val <= q2[i].x) modify(q1[j].pos, 1);
            }
        for (reg int j = 1; j <= c1; j++)
            if (!used[q1[j].pos])
            {
                used[q1[j].pos] = true;
                if (a[q1[j].pos] <= q2[i].x) modify(q1[j].pos, 1);
            }
        ans[q2[i].idx] = query(q2[i].l, q2[i].r);
        while (top)
        {
            tmp = stk[top--], sum[bel[tmp.idx]] -= tmp.res, seq[tmp.idx] = false;
            if (tmp.flag) pos[tmp.t[2]] = tmp.t[3], pos[tmp.t[0]] = tmp.t[1];
        }
        for (int j = 1; j <= c1; j++) used[q1[j].pos] = false;
    }
    ecnt = 0;
    memset(head, 0, sizeof(head));
    // for (int i = 0; i <= n; i++) idx[i].clear();
    for (int i = 1; i <= c1; i++) chg[q1[i].pos] = false;
}

int main()
{
    n = read(), m = read(), sqn = 516, sqm = 1821;
    tot = (n + sqn - 1) / sqn;
    for (reg int i = 1; i <= tot; i++) st[i] = ed[i - 1] + 1, ed[i] = (i == tot ? n : i * sqn);
    for (reg int i = 1; i <= n; i++) a[i] = read(), bel[i] = (i - 1) / sqn + 1, gt[i] = 1ll * i * (i + 1) / 2;
    for (reg int i = 1, j; i <= m; i = j + 1)
    {
        j = min(m, i + sqm), c1 = c2 = 0;
        for (reg int k = i; k <= j; k++)
            if (read() == 1) q1[++c1] = (Modify){k, read(), read()};
            else q2[++c2] = (Query){k, read(), read(), read(), c2};
        solve();
        for (reg int k = 1; k <= c2; k++) write(ans[k]), putc('\n');
        for (reg int k = 1; k <= c1; k++) a[q1[k].pos] = q1[k].val;
    }
    flush();
    return 0;
}
posted @ 2023-01-17 23:13  kymru  阅读(55)  评论(0编辑  收藏  举报