Tokitsukaze and Min-Max XOR

Tokitsukaze and Min-Max XOR

题目描述

Tokitsukaze 有一个长度为 $n$ 的序列 $a_1,a_2,\ldots,a_n$​ 和一个整数 $k$。

她想知道有多少种序列 $b_1,b_2,\ldots,b_m$​,满足:

  • $1 \leq b_i \leq n$
  • $b_{i−1}<b_i​$ $(2 \leq i \leq m)$
  • $\min⁡(a_{b_1} \, , a_{b_2} \, , \ldots, a_{b_m}) \oplus \max⁡(a_{b_1} \, , a_{b_2} \, , \ldots, a_{b_m}) \leq k$

其中 $\oplus$ 为按位异或,具体参见 百度百科:异或

答案可能很大,请输出  $\bmod 10^9+7$ 后的结果。

输入描述:

第一行包含一个整数 $T$ ($1 \leq T \leq 2 \cdot 10^5$),表示 $T$ 组测试数据。

对于每组测试数据:

第一行包含两个整数 $n$, $k$ ($1 \leq n \leq 2 \cdot 10^5$; $0 \leq k \leq 10^9$)。

第二行包含 $n$ 个整数 $a_1,a_2,\ldots,a_n$​ ($0 \leq a_i \leq 10^9$)。

保证 $\sum{n}$ 不超过 $2 \cdot 10^5$。

输出描述:

对于每组测试数据,输出一个整数,表示答案 $\bmod 10^9+7$ 后的结果。

示例1

输入

3
3 2
1 3 2
5 3
1 3 5 2 4
5 0
0 0 0 0 0

输出

6
10
31

说明

第一组测试数据,$k$ 为 $2$:

  1. 选择的序列 $b$ 为 $[1]$,$\min⁡(a_1) \oplus \max⁡(a_1)=1 \oplus 1=0 \leq 2$;
  2. 选择的序列 $b$ 为 $[2]$,$\min⁡(a_2) \oplus \max⁡(a_2)=3 \oplus 3=0 \leq 2$;
  3. 选择的序列 $b$ 为 $[3]$,$\min⁡(a_3) \oplus \max⁡(a_3)=2 \oplus 2=0 \leq 2$;
  4. 选择的序列 $b$ 为 $[1,2]$,$\min⁡(a_1,a_2) \oplus \max⁡(a_1,a_2)=1 \oplus 3=2 \leq 2$;
  5. 选择的序列 $b$ 为 $[2,3]$,$\min⁡(a_2,a_3) \oplus \max⁡(a_2,a_3)=2 \oplus 3=1 \leq 2$;
  6. 选择的序列 $b$ 为 $[1,2,3]$,$\min⁡(a_1,a_2,a_3) \oplus \max⁡(a_1,a_2,a_3)=1 \oplus 3=2 \leq 2$;

所以第一组测试数据的答案为 $6$ 。

 

解题思路

  看了点提示就做出来了,解法和昨天想到的思路差不多一样,不过最后没多少时间写了。

  容易知道 $b_1, \ldots, b_m$ 实际上是 $a$ 的一个子序列,并且由于我们只关注子序列中的最大值和最小值,因此可以先对 $a$ 从小到大排序,再选择子序列。接着对子序列中的最大值进行分类,可以分成 $n$ 类。即从左到右依次枚举 $a_i$ 作为子序列中的最大值,那么最小值就会在 $a_j, \, j \in [0, i]$ 中选。当满足 $a_i \oplus a_j \leq k$,那么以 $a_i$ 为最大值,$a_j$ 为最小值的子序列的数量就是 $2^{\max\{ 0,i-j-1 \}}$,特别的当 $i=j$ 时答案为 $1$。

  暴力的做法就是逐个枚举 $a_j$ 判断是否满足条件,时间复杂度是 $O(n^2)$ 的。由于涉及到异或运算所以尝试能不能用 trie 来维护 $a_j$ 的信息。如果 $a_j$ 满足条件,那么对答案的贡献是 $2^{i-j-1}$,也就是 $\frac{1}{2^{j+1}} \cdot 2^i$,因此在把 $a_j$ 按位插入 trie 中时,同时在对应节点加上 $\frac{1}{2^{j+1}}$。

  枚举到 $a_i$ 时,此时已经往 trie 中插入了 $a_0 \sim a_{i-1}$,枚举 $a_i$ 的每一位,用 $x_i$ 和 $m_i$ 分别表示 $a_i$ 和 $m$ 在二进制下第 $i$ 位上的值。如果 $x_i \oplus 0 < m_i$,说明此时 $0$ 的分支剩余的 $a_j$ 都满足条件,把该分支节点上的关于 $\frac{1}{2^{j+1}}$ 的和累加到答案 $s$。同理如果 $x_i \oplus 1 < m_i$,说明此时 $1$ 的分支剩余的 $a_j$ 都满足条件。然后走到下一个分支节点,如果 $x_i \oplus 0 = m_i$ 则走到 $0$ 的分支节点,否则走到 $1$ 的分支节点。最后以 $a_i$ 为最大值的子序列的数量就是 $1 + s \cdot 2^{i}$。

  AC 代码如下,时间复杂度为$O\left(n (\log{A} + \log{n})\right)$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 2e5 + 10, mod = 1e9 + 7;

int n, m;
int a[N];
int tr[N * 30][2], idx, s[N * 30];

int qmi(int a, int k) {
    int ret = 1;
    while (k) {
        if (k & 1) ret = 1ll * ret * a % mod;
        a = 1ll * a * a % mod;
        k >>= 1;
    }
    return ret;
}

void add(int x, int c) {
    int p = 0;
    for (int i = 29; i >= 0; i--) {
        int t = x >> i & 1;
        if (!tr[p][t]) tr[p][t] = ++idx;
        p = tr[p][t];
        s[p] = (s[p] + c) % mod;
    }
}

int query(int x, int c) {
    int p = 0, ret = 0;
    for (int i = 29; i >= 0; i--) {
        int t = x >> i & 1;
        if (tr[p][0] && t < (m >> i & 1)) ret = (ret + 1ll * s[tr[p][0]] * c) % mod;
        if (tr[p][1] && (t ^ 1) < (m >> i & 1)) ret = (ret + 1ll * s[tr[p][1]] * c) % mod;
        if (t == (m >> i & 1)) {
            if (!tr[p][0]) return ret;
            else p = tr[p][0];
        }
        else {
            if (!tr[p][1]) return ret;
            else p = tr[p][1];
        }
    }
    ret = (ret + 1ll * s[p] * c) % mod;
    return ret;
}

void solve(){
    scanf("%d %d", &n, &m);
    for (int i = 0; i < n; i++) {
        scanf("%d", a + i);
    }
    sort(a, a + n);
    idx = 0;
    for (int i = 0; i <= n * 30; i++) {
        tr[i][0] = tr[i][1] = s[i] = 0;
    }
    int ret = 0;
    for (int i = 0; i < n; i++) {
        ret = (ret + 1 + query(a[i], qmi(2, i))) % mod;
        add(a[i], qmi(qmi(2, i + 1), mod - 2));
    }
    printf("%d\n", ret);
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        solve();
    }
    
    return 0;
}

 

参考资料

  【题解】2024牛客寒假算法基础集训营2:https://ac.nowcoder.com/discuss/1251379/

posted @ 2024-02-06 18:40  onlyblues  阅读(26)  评论(0编辑  收藏  举报
Web Analytics