【Coel.学习笔记】莫队(下)- 树上莫队和二次离线莫队

树上莫队和二次离线莫队都比较难,所以只讲几个模板(明明是你太懒了)

树上莫队

我们之前处理的问题都是在数列上的,如果换成树,怎么办呢?下面这题给出了一个常用的方法。

SP10707 COT2 - Count on a tree II

洛谷传送门
给定一棵点带权的树,静态询问每两个节点之间(包括端点)路径上的不同权值数。

解析:这道题看起来和上一篇的普通莫队题很像,只不过从序列变成了树。

对于树上问题,有一种通法把它变成序列问题:构造一个欧拉序列。在深度优先遍历的时候,对于每个点到达时和离开时都记录下结果,这时每个节点都会出现两次。我们开设两个数组 \(fir_u\)\(lst_u\),分别记录下 \(u\) 第一次和第二次出现的位置。对于端点 \(x,y\),如果 \(fir_x<fir_y\),判断 \(x,y\) 的最近公共祖先,若为 \(x\),则路径对应到欧拉序列上的区间为 \([fir_x,fir_y]\) 中只出现一次的数字,若不为 \(x\) 则路径对应到欧拉序列上的区间为 \([lst_x,fir_y]\) 中只出现一次的数字,以及 \(x,y\) 的最近公共祖先。经过这样麻烦至极的操作,我们把树上问题变成了序列问题。

但是这么做直接套莫队还是很麻烦,我们要对维护的数组微调一下。维护一个数组 \(vis_u\) 表示数字 \(u\) 在当前区间内出现次数,\(0\) 表示两次 \(1\) 表示一次。给莫队插入时,我们利用异或操作,给 \(vis_u\) 异或上 \(1\)。若异或后值为 \(0\),则相当于普通莫队的删除,反之为插入。显然插入一个数和删除一个数在这个定义下是等价的,所以插入和删除直接利用同一个函数即可。现在插入和删除均为 \(O(1)\),可以直接套莫队解决了。另外本题权值的值域很大,需要做离散化。

虽然看起来很模板,但使用到的算法相当多(离散化,欧拉序列,倍增求 LCA,莫队),还是比较难写的。当本题要求强制在线的时候,有一种名为“树分块”的扩展解法,在此不做详述,感兴趣的读者可以自行查阅资料。

#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>

#define get(x) (x) / len

using namespace std;

const int maxn = 1e5 + 10;

int n, m, len, a[maxn];
int head[maxn], nxt[maxn], to[maxn], tot;
int dep[maxn], anc[maxn][20];
int euler[maxn], fir[maxn], lst[maxn], idx;
int cnt[maxn], ans[maxn];
bool vis[maxn];
vector<int> rec;

struct node {
    int l, r, id, p;
    node(int id = 0, int l = 0, int r = 0, int p = 0)
        : l(l), r(r), id(id), p(p) {}
    inline bool operator<(const node &x) const {
        int i = get(l), j = get(x.l);
        if (i != j) return i < j;
        return r < x.r;
    }
} q[maxn];

void add(int u, int v) { nxt[tot] = head[u], to[tot] = v, head[u] = tot++; }

inline void init_hash() {
    sort(rec.begin(), rec.end());
    rec.erase(unique(rec.begin(), rec.end()), rec.end());
    for (int i = 1; i <= n; i++)
        a[i] = lower_bound(rec.begin(), rec.end(), a[i]) - rec.begin();
}

void dfs_euler(int u, int fa) {
    euler[++idx] = u;
    fir[u] = idx;
    for (int i = head[u]; ~i; i = nxt[i]) {
        int v = to[i];
        if (v != fa) dfs_euler(v, u);
    }
    euler[++idx] = u;
    lst[u] = idx;
}

void bfs_lca() {
    queue<int> Q;
    memset(dep, 0x3f, sizeof dep);
    dep[0] = 0, dep[1] = 1;
    Q.push(1);
    while (!Q.empty()) {
        int u = Q.front();
        Q.pop();
        for (int i = head[u]; ~i; i = nxt[i]) {
            int v = to[i];
            if (dep[v] > dep[u] + 1) {
                dep[v] = dep[u] + 1;
                anc[v][0] = u;
                for (int k = 1; k <= 15; k++)
                    anc[v][k] = anc[anc[v][k - 1]][k - 1];
                Q.push(v);
            }
        }
    }
}

int lca(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    for (int i = 15; i >= 0; i--)
        if (dep[anc[u][i]] >= dep[v]) u = anc[u][i];
    if (u == v) return u;
    for (int i = 15; i >= 0; i--)
        if (anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
    return anc[u][0];
}

void modify(int x, int &res) {
    vis[x] ^= 1;
    if (vis[x] == 0) {
        cnt[a[x]]--;
        if (!cnt[a[x]]) res--;
    } else {
        if (!cnt[a[x]]) res++;
        cnt[a[x]]++;
    }
}

int main(void) {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i], rec.push_back(a[i]);
    memset(head, -1, sizeof(head));
    init_hash();
    for (int i = 1, u, v; i <= n - 1; i++) {
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    dfs_euler(1, -1);
    bfs_lca();
    for (int i = 0; i < m; i++) {
        int u, v, l;
        cin >> u >> v;
        if (fir[u] > fir[v]) swap(u, v);
        l = lca(u, v);
        if (u == l)
            q[i] = node(i, fir[u], fir[v], 0);
        else
            q[i] = node(i, lst[u], fir[v], l);
    }
    len = sqrt(idx);
    sort(q, q + m);
    for (int i = 0, L = 1, R = 0, res = 0; i < m; i++) {
        int id = q[i].id, l = q[i].l, r = q[i].r, p = q[i].p;
        while (R < r) modify(euler[++R], res);
        while (R > r) modify(euler[R--], res);
        while (L < l) modify(euler[L++], res);
        while (L > l) modify(euler[--L], res);
        if (p) modify(p, res);
        ans[id] = res;
        if (p) modify(p, res);
    }
    for (int i = 0; i < m; i++) cout << ans[i] << '\n';
    return 0;
}

二次离线莫队

我们在使用莫队算法时,每次都会对区间端点做移动。在某些题中,询问的转化比较困难,所以我们要再做一遍离线操作进而优化。由于一共离线了两次(莫队的询问离线和优化操作的离线),所以叫做二次离线莫队。

【模板】莫队二次离线(第十四分块(前体))

洛谷传送门
给定一个序列 \(a\),静态查询 \(a_i \oplus a_j\) 的二进制表示下有 \(k\)\(1\) 的二元组 \((i,j)\) 的个数(\((i,j)\)\((j,i)\) 记为同一个二元组) 。\(\oplus\) 是指按位异或。

解析:本题为第十四分块的前体,第十四分块本体为[Ynoi2018] GOSICK,感兴趣可以做做。

假如我们已经维护好了 \([L,R]\) 区间中的配对,现在要求 \([L,R+1]\) 区间的配对,也就相当于求解 \([L,R]\)\(R+1\) 的配对个数。使用前缀和的思想,用 \(S_i\) 表示区间 \([1,i]\)\(R+1\) 配对个数,则答案等于\(S_R-S_{L-1}\)

定义两个函数:\(f(i)\) 表示 \(w_{[1,i]}\) 中与 \(w_{i+1}\) 配对个数,\(g(x)\) 表示前 \(i\) 个数中有多少个数与 \(x\) 配对。显然 \(f(i)=g(w_i+1)\),所以维护好 \(g(x)\) 就可以得到 \(f(i)\),进而求出 \(S_R\)

由于 \(g(x)\) 的取值受到 \(k\) 的限制,而 \(k\) 又很小,所以可以暴力求出。我们先求出 \([0,2^{14}-1]\) 中二进制下恰好为 \(k\)\(1\) 的数存放在数组 \(y_i\) 中,则在求解 \(g(x)\) 时要寻找与 \(w_i\) 配对的数,等价于 \(w_i \oplus x =y_i\),即要寻找所有的 \(x=y_i \oplus w_i\),每次让 \(g(y_i \oplus w_i)\) 自增,就可以得到所有的 \(g(x)\) 取值。顺带一提,求解 \(y_i\) 数组的时候可以用内建函数 __builtin_popcount,底层为查表;如果使用 C++17,还可以使用 numeric 库的 std::popcount,底层为二分。

\(S_R\) 可以在预处理 \(g(x)\)\(O(1)\) 求出,但 \(S_{L-1}\) 呢?对于移动区间 \([R+1,r]\),假如在线求解则要把 \([1,L-1]\) 的所有配对遍历一遍,很麻烦。我们可以先找出所有遍历的行为按照前缀和的方法递推求出。再次利用 \(g(x)\) 函数对问题求解,也可以快速求出。那么,莫队算法中单次指针移动的时间复杂度仍为 \(O(1)\),总的时间复杂度为 \(O(2^kn+n\sqrt n)\)

事实上,上面仅仅讨论了移动指针的一种情况,而莫队算法中共有四种移动(左端点左右移动,右端点左右移动),还要做十分复杂的分类讨论。具体操作请看代码。

#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <vector>

#define get(x) (x) / len

using namespace std;

typedef long long ll;

const int maxn = 1e5 + 10;

int n, m, k, len;
int w[maxn], f[maxn], g[maxn];
ll ans[maxn];

struct node {
    int id, l, r;
    ll res;
    inline bool operator<(const node& x) const {
        int i = get(l), j = get(x.l);
        if (i != j) return i < j;
        return r < x.r;
    }
} q[maxn];

struct cord {
    int id, l, r, t;
};

vector<cord> ra[maxn];
vector<int> nums;

int main(void) {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m >> k;
    len = sqrt(n);
    for (int i = 1; i <= n; i++) cin >> w[i];
    for (int i = 0; i < 1 << 14; i++)
        if (__builtin_popcount(i) == k) nums.push_back(i);
    for (int i = 1; i <= n; i++) {
        for (auto y : nums) g[w[i] ^ y]++;
        f[i] = g[w[i + 1]];
    }
    for (int i = 0; i < m; i++) {
        cin >> q[i].l >> q[i].r;
        q[i].id = i;
    }
    sort(q, q + m);
    for (int i = 0, L = 1, R = 0; i < m; i++) {
        int l = q[i].l, r = q[i].r;
        if (R < r) ra[L - 1].push_back({i, R + 1, r, -1});
        while (R < r) q[i].res += f[R++];  // R+1 to r
        if (R > r) ra[L - 1].push_back({i, r + 1, R, 1});
        while (R > r) q[i].res -= f[--R];  // r+1 to R
        if (L < l) ra[R].push_back({i, L, l - 1, -1});
        while (L < l) q[i].res += f[L - 1] + !k, L++;  // L to l - 1
        if (L > l) ra[R].push_back({i, l, L - 1, 1});
        while (L > l) q[i].res -= f[L - 2] + !k, L--;  // l - 1 to L
    }

    memset(g, 0, sizeof(g));  //二次利用 g 函数
    for (int i = 1; i <= n; i++) {
        for (auto y : nums) g[w[i] ^ y]++;
        for (vector<cord>::iterator it = ra[i].begin(); it != ra[i].end(); it++)
            for (int x = it->l; x <= it->r; x++)
                q[it->id].res += g[w[x]] * it->t;
    }
    for (int i = 1; i < m; i++) q[i].res += q[i - 1].res;
    for (int i = 0; i < m; i++) ans[q[i].id] = q[i].res;
    for (int i = 0; i < m; i++) cout << ans[i] << '\n';
    return 0;
}

posted @ 2022-07-31 16:20  秋泉こあい  阅读(27)  评论(0编辑  收藏  举报