浅谈莫队2

莫队

二次离线莫队

AcWing 2535. 二次离线莫队

本题有若干个询问,每个询问都要求出某个区间中异或和在二进制表示中有 \(k\)\(1\) 的数对个数。

我们规定,如果某两个数的异或和在二进制表示中有 \(k\)\(1\),我们就称这两个数是配对的,因此每个询问就变成了求某个区间中有多少数对是配对的。

本题需要用到二次离线莫队来做,而二次离线莫队就是一共需要离线两次来做,在做莫队时,我们每次都会对一段区间去查询一个数,然后我们都会去对两个端点进行移动,然后在新的维护区间中去求我们的询问。

而对于二次离线莫队,就是当我们每次更新完维护区间之后,对于区间的询问很难算,所以我们需要在每次更新完维护区间之后,再把当前询问单独拎出来再重新离线求当前询问的值。二次离线算法思维难度一般不高,但是代码实现中的细节非常多,且每道题都需要重新思考,可以说是非常恶心。

要想使用莫队算法来求,那么每次我们都需要从上一个询问区间 \([l, r]\) 的信息快速得到当前询问区间 \([L, R]\) 的信息。以右端点 \(r\) 为例,当 \(r\) 右移后,我们需要将 \(w_{r+1}\) 加入到维护区间中,那么我们就需要考虑将 \(w_{r+1}\) 加入到维护区间中后,维护区间的信息该怎么去维护。当我们将 \(w_{r+1}\) 加入后,需要求一下它对于配对的数量有什么样的影响,显然配对的数量只会增加,至于增加的数量,就是 \([l, r]\) 中和 \(w_{r+1}\) 配对的数的个数,这一步可以用前缀和来求。

设 $S_i$ 表示 $w_1 \sim w_i$ 中有多少个数和 $w_{r+1}$ 配对,此时对于 $[l, r]$ 中和 $w_{r+1}$ 配对的数的个数就是 $S_r-S_{l-1}$。接下来就是要求 $S_r$ 和 $S_{l-1}$,对于 $S_r$,其实是问 $w_1 \sim w_r$ 中有多少个数和 $w_{r+1}$ 配对,可以发现要询问的数就是区间的后一个数,而 $S_{l-1}$ 则没有这么好的性质,$w_{l-1}$ 和 $w_{r+1}$ 之间的距离是非常随机的,毫无规律可循,因此 $S_r$ 和 $S_{l-1}$ 其实是两类询问,分两种情况来考虑。

首先对于 \(S_r\),由于这一部分是非常有规律的,所以可以提前预处理,设 \(f_i\) 表示 \(w_1 \sim w_i\) 中与 \(w_{i+1}\) 配对的数个数,而 \(S_r\) 显然就是 \(f_i\),因此我们需要快速的预处理出 \(f_i\),可以用一个 \(g_x\) 表示前 \(i\) 个数中有多少个数与 \(x\) 配对。当我们把 \(g_x\) 预处理出来,则 \(f_i = g_{w_{i+1}}\)

因此我们现在就是需要求出 \(i\) 阶段的 \(g_x\),假设当前 \(g_x\) 表示前 \(i-1\) 个数中有多少个数与 \(x\) 配对,这里我们可以先预处理出 \(0 \sim 2^{14}-1\) 中所有有 \(k\)\(1\) 的数 \(y_i\),我们想从前 \(i-1\) 个数的 \(g_x\) 变成前 \(i\) 个数的 \(g_x\),相当于是加入了一个新的数 \(w_i\),此时我们只要找出所有和 \(w_i\) 配对的 \(x\),令 \(g_x+1\),最终就能得到前 \(i\) 个数的 \(g_x\),而我们要找的所有 \(x\) 则必须满足 \(w_i~xor~x=y_i\),这个条件等价于 \(x=w_i~xor~y_i\),因此我们可以枚举所有不同的 \(y_i\),通过 \(y_i~xor~w_i\) 计算出所有的 \(x\),因为 \(y_i\) 不同,所以得出的 \(x\) 也不同。

综上所述,对于前 \(i-1\) 个数的 \(g_x\),我们枚举所有 \(y_i\),令 \(g_{y_i~xor~w_i}+1\),最终就能得到前 \(i\) 个数的 \(g_x\),然后再用 \(g_x\) 计算 \(f_i\) 即可。这样我们就能用 \(g_x\) 作为辅助递推得出所有的 \(f_i\)。这一部分预处理一共只需要做一次,由于 \(k\) 比较小,\(y_i\) 最多只有三千多个,因此预处理的计算量最多只有三千多万。

接下来我们需要想办法求出 $S_{l-1}$,$S_r$ 通过我们刚才的分析,我们可以在做莫队的过程中在线求出来,但是 $S_{l-1}$ 并不能马上求出来,因此我们只能先将所有要求 $S_{l-1}$ 的问题先找出来,然后我们再离线把所有要求的 $S_{l-1}$ 求出来,最后我们才能求出 $S_r - S_{l-1}$。

要求 \(S_{l-1}\),其实就是求 \(w_1 \sim w_{l-1}\) 中有多少个数是和 \(w_{r+1}\) 配对的,以此类推,在从 \(r\) 移动到 \(R\) 的过程中,我们会将 \(r+1,r+2,…,R\) 都加入到维护区间中,因此对于 \(\forall x \in [r+1,R]\),都要求出 \(w_1 \sim w_{l-1}\) 中和 \(w_x\) 配对的数的个数,可以发现这些问题都是问某个固定前缀中,某个区间的每个数和这个前缀中有多少个数配对。我们可以将这些询问全部找出来,然后按照从前往后的顺序计算所有询问,我们先算一下所有前缀是 \(1\) 的询问,再算一下所有前缀是 \(2\) 的询问,以此类推,直到我们算完所有前缀是 \(l-1\) 的询问后,我们就将所有的询问都处理完了。

由于我们是从前往后做所有询问,因此每一次前缀中只会增加一个数,因此这里我们同样可以用一个 \(g_x\) 数组来作为辅助,表示的内容和上面相同,同样表示前 \(i\) 个数中与 \(x\) 配对的数的个数,因此对于 \(\forall x \in [r+1,R]\),我们想求的 \(w_1 \sim w_{l-1}\) 中和 \(x\) 配对的数的个数恰好就是 \(g_x\),直接按照上面更新 \(g_x\) 的思路依次往后求即可。而这一部分的询问数量应该取决于两个指针移动的次数,这个在基础莫队中就已经证明过是 \(O(\sqrt{n})\) 级别的,因此我们就能用一个 \(O(n\sqrt{n})\) 的离线做法求出所有的 \(S_{l-1}\),然后就能把这部分在莫队中无法解决的问题统一计算出来。

到此我们就能将前一个询问到当前询问的增量 \(S_r-S_{l-1}\) 求出来,但是这并不是当前询问的答案,如果我们想求某一个询问的结果的话,还需要将前面求出来的所有增量累加成前缀和才是最终答案。

注意,上面我们推导了 \(r\) 向右移动到 \(R\) 这一种情况,实际上两个指针一共有四种情况,而其他三种情况都按照上面同样的形式去分析即可,代码实现时需要根据每种情况的区别做一些更细致的处理,这里就不过多赘述,直接体现在代码中。

#include <bits/stdc++.h>

#define rint register int
#define int long long
#define endl '\n'

using namespace std;

const int N = 1e5 + 5;

int n, m, k, len;
int w[N], g[N], f[N];
int ans[N];
struct node
{
    int id, l, r, t;
    int res;
} q[N];
vector<node> range[N];

int get(int i){return i / len;}

bool cmp(node a, node b)
{
    int l = get(a.l), r = get(b.l);
    if (l != r) return l < r;
    return a.r < b.r;
}

bool count(int i)
{
    int res = 0;
    for (rint j = 0; j < 14; j++)
        if (i >> j & 1)
            res++;
    return res == k;
}

signed main()
{
    cin >> n >> m >> k;
    
    for (rint i = 1; i <= n; i++)
    {
		cin >> w[i];
	}
	
    for (rint i = 1; i <= m; i++)
    {
        int l, r;
        cin >> l >> r;
        q[i] = {i, l, r};
    }

    vector<int> nums;
    for (rint i = 0; i < (1 << 14); i++)
    {
        if (count(i))
        {
			nums.push_back(i);
		}
	}
            
    for (rint i = 1; i <= n; i++)
    {
        for (auto y : nums) g[w[i] ^ y]++;
        f[i] = g[w[i + 1]];
    }

    len = sqrt(n);
    sort(q + 1, q + m + 1, cmp);

    for (rint i = 1, L = 1, R = 0; i <= m; i++)
    {
        int l = q[i].l, r = q[i].r;
        
        if (R < r) range[L - 1].push_back({i, R + 1, r, -1});
        while (R < r) q[i].res += f[R++];
        
        if (R > r) range[L - 1].push_back({i, r + 1, R, 1});
        while (R > r) q[i].res -= f[--R];
        
        if (L < l) range[R].push_back({i, L, l - 1, -1});
        while (L < l) q[i].res += f[L - 1] + !k, L++;
        
        if (L > l) range[R].push_back({i, l, L - 1, 1});
        while (L > l) q[i].res -= f[L - 2] + !k, L--;
    }

    memset(g, 0, sizeof g);
    
    for (rint i = 1; i <= n; i++)
    {
        for (auto y : nums) g[w[i] ^ y]++;
        for (auto &rg : range[i])
        {
            int id = rg.id, l = rg.l, r = rg.r, t = rg.t;
            for (rint x = l; x <= r; x++)
            {
                q[id].res += t * g[w[x]];				
			}
        }
    }

    for (rint i = 2; i <= m; i++)
    {
        q[i].res += q[i - 1].res;		
	}

    for (rint i = 1; i <= m; i++)
    {
        ans[q[i].id] = q[i].res;
	}
	
    for (rint i = 1; i <= m; i++)
    {
		cout << ans[i] << endl;
	}

    return 0;
}

树上莫队

SP10707

先将整棵树的欧拉序求出来,记录每个点第一次出现的位置 $first[i]$ 和最后一次出现的位置 $last[i]$,然后观察树中的路径 $[l,r](first[l]<first[r])$ 可以发现两种情况:

  1. 如果路径是一条从上往下的直链,则其所有点对应欧拉序中 $first[l]$ 到 $first[r]$ 中出现一次的点
  2. 否则其所有点对应欧拉序中 $first[l]$ 到 $last[r]$ 中出现一次的点加上 $lca(l,r)$

理解一下会发现的确这样,然后问题就转化为普通莫队问题了

    #include <bits/stdc++.h>
    
    #define rint register int
    #define int long long
    #define endl '\n'
    #define queue queue__
    
    using namespace std;
    
    const int N = 2e6 + 5;
    const int M = 1e7 + 5;
    
    int n, m, len;
    int h[N], e[M], ne[M], idx;
    int w[N], seq[N], first[N], last[N], top;
    int queue[N], dep[N], fa[N][25], cnt[N];
    int ans[N];
    bool st[N];
    vector<int> nums;
    
    struct node
    {
        int id, l, r, p;
    } q[N];
    
    void add(int a, int b)
    {
        e[++idx] = b, ne[idx] = h[a], h[a] = idx;
    }
    
    int get(int i){return i / len;}
    
    bool cmp(node a, node b)
    {
        int l = get(a.l), r = get(b.l);
        if (l != r) return l < r;
        return a.r < b.r;
    }
    
    void dfs(int x, int father)
    {
        seq[++top] = x;
        first[x] = top;
        for (rint i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            if (y != father)
            {
                dfs(y, x);			
    		}
        }
        seq[++top] = x;
        last[x] = top;
    }
    
    void bfs(int s)
    {
        memset(dep, -1, sizeof dep);
        int hh = 0, tt = 0;
        queue[0] = s, dep[s] = 0;
        while (hh <= tt)
        {
            int x = queue[hh++];
            for (rint i = h[x]; i; i = ne[i])
            {
                int y = e[i];
                if (dep[y] == -1)
                {
                    dep[y] = dep[x] + 1;
                    fa[y][0] = x;
    				queue[++tt] = y;
                    for (rint j = 1; j <= 20; j++)
                    {
                        fa[y][j] = fa[fa[y][j - 1]][j - 1];					
    				}
                }
            }
        }
    }
    
    int lca(int a, int b)
    {
        if (dep[a] < dep[b]) swap(a, b);
        for (rint i = 20; i >= 0; i--)
            if (dep[fa[a][i]] >= dep[b])
                a = fa[a][i];
        if (a == b) return a;
        for (rint i = 20; i >= 0; i--)
            if (fa[a][i] != fa[b][i])
                a = fa[a][i], b = fa[b][i];
        return fa[a][0];
    }
    
    void change(int x, int &res)
    {
        st[x] ^= 1;
        if (st[x] == 0)
        {
            cnt[w[x]]--;
            if (!cnt[w[x]]) res--;
        }
        else
        {
            if (!cnt[w[x]]) res++;
            cnt[w[x]]++;
        }
    }
    
    signed main()
    {
        cin >> n >> m;
        
        for (rint i = 1; i <= n; i++)
        {
    		cin >> w[i];
    		nums.push_back(w[i]);
    	}
    	
        sort(nums.begin(), nums.end());
        nums.erase(unique(nums.begin(), nums.end()), nums.end());
        
        for (rint i = 1; i <= n; i++)
        {
            w[i] = lower_bound(nums.begin(), nums.end(), w[i]) - nums.begin();		
    	}
    
        for (rint i = 1; i < n; i++)
        {
            int a, b;
            cin >> a >> b;
            add(a, b);
    		add(b, a);
        }
        
        bfs(1);
    	dfs(1, 1);
        
    	len = sqrt(top);
    	
        for (rint i = 1; i <= m; i++)
        {
            int x, y;
            cin >> x >> y;
            if (first[x] > first[y]) swap(x, y);
            int p = lca(x, y);
            if (x == p) q[i] = {i, first[x], first[y]};
            else q[i] = {i, last[x], first[y], p};
        }
    
        sort(q + 1, q + m + 1, cmp);
    
        for (rint k = 1, i = 1, j = 0, res = 0; k <= m; k++)
        {
            int id = q[k].id, l = q[k].l, r = q[k].r, p = q[k].p;
            while (i < l) change(seq[i++], res);
            while (i > l) change(seq[--i], res);
            while (j < r) change(seq[++j], res);
            while (j > r) change(seq[j--], res);
            if (p) change(p, res);
            ans[id] = res;
            if (p) change(p, res);
        }
    
        for (rint i = 1; i <= m; i++)
        {
    		cout << ans[i] << endl;
    	}
    
        return 0;
    }  
    
posted @ 2024-01-01 14:09  PassName  阅读(7)  评论(0编辑  收藏  举报