Count Pairs With XOR in a Range

Count Pairs With XOR in a Range

Given a (0-indexed) integer array nums and two integers low and high , return the number of nice pairs.

A nice pair is a pair (i, j) where 0 <= i < j < nums.length and low <= (nums[i] XOR nums[j]) <= high.

Example 1:

Input: nums = [1,4,2,7], low = 2, high = 6
Output: 6
Explanation: All nice pairs (i, j) are as follows:
    - (0, 1): nums[0] XOR nums[1] = 5 
    - (0, 2): nums[0] XOR nums[2] = 3
    - (0, 3): nums[0] XOR nums[3] = 6
    - (1, 2): nums[1] XOR nums[2] = 6
    - (1, 3): nums[1] XOR nums[3] = 3
    - (2, 3): nums[2] XOR nums[3] = 5

Example 2:

Input: nums = [9,8,4,2,1], low = 5, high = 14
Output: 8
Explanation: All nice pairs (i, j) are as follows:
​​​​​    - (0, 2): nums[0] XOR nums[2] = 13
    - (0, 3): nums[0] XOR nums[3] = 11
    - (0, 4): nums[0] XOR nums[4] = 8
    - (1, 2): nums[1] XOR nums[2] = 12
    - (1, 3): nums[1] XOR nums[3] = 10
    - (1, 4): nums[1] XOR nums[4] = 9
    - (2, 3): nums[2] XOR nums[3] = 6
    - (2, 4): nums[2] XOR nums[4] = 5

 Constraints:

  • $1 \leq \text{nums.length} \leq 2 \times {10}^4$
  • $1 \leq \text{nums}[i] \leq 2 \times {10}^4$
  • $1 \leq \text{low} \leq high \leq 2 \times {10}^4$

 

解题思路

  一开始没想到怎么做,然后看了眼标签发现是trie就自己写出来了,很经典的trie题,主要是没想到。

  因为问的是异或后在$[\text{low}, ~\text{high}]$范围内的数,因此可以先求出异或结果不超过$\text{high}$的个数$f(\text{high})$,再求出异或结果不超过$\text{low-1}$的个数$f(\text{low-1})$,你那么$[\text{low}, ~\text{high}]$范围内的数的个数就是$f(\text{high}) - f(\text{low-1})$。

  每个数的最大数值不超过$2 \times {10}^4$,意味着转换成二进制后最多有$\left\lceil \log{2 \times {10}^4} \right\rceil = 15$位。因为比较的时候是从最高位开始比较,因此在trie中插入某个数的二进制串时应该从最高位开始往最低位依次插入。

  当枚举到$a_i$,此时第$0 \sim i-1$个数都已插入到trie中,现在问前面有多少个数与$a_i$异或后的结果不超过$s$,即问$f(s)$是多少。依次从高位往低位枚举,当枚举到第$k$位时,如果$s$的第$k$位为$1$,$a_i$的第$k$位为$t$,那么很显然如果异或后的结果$x$的第$k$位为$0$,那么那么$x$剩下的位可以任意取值都不会超过$s$,此时只需看看在trie中有多少数的第$k$位是$t$(因为$t \oplus t = 0$),然后再向下走到第$k$位为$!t$的节点(因为$!t \oplus t = 1$),对应的异或结果的第$k$位为$1$。如果$s$的第$k$位为$0$,那么异或后的结果$x$的第$k$位为只能取$0$,此时只能向下走到第$k$位为$t$的节点,对应的异或结果的第$k$位为$0$。可以发现前当枚举到第$k$位时,得到的异或结果的前$k$位与$s$的前$k$位相同(不会超过$s$)。

  可以发现在枚举的过程中需要知道有多少个数的第$k$位为某个值。这个只需要开个数组来记录每个节点会被多少个数用到,在插入的时候每走到一个节点则该节点的计数加$1$。

  还需要注意的是在比较的过程中可能会出现下一个要走的节点不存在的情况,这时直接返回已累加的答案就好了。

  AC代码如下,时间复杂度为$O(15 \cdot n)$:

 1 const int N = 3e5 + 10;
 2 
 3 int tr[N][2], idx;
 4 int cnt[N];
 5 
 6 class Solution {
 7 public:
 8     void add(int x) {
 9         int p = 0;
10         for (int i = 14; i >= 0; i--) {
11             int t = x >> i & 1;
12             if (!tr[p][t]) tr[p][t] = ++idx;
13             p = tr[p][t];
14             cnt[p]++;   // 每走过一个节点就加1
15         }
16     }
17     
18     int query(int x, int s) {
19         int p = 0, ret = 0;
20         for (int i = 14; i >= 0; i--) {
21             int t = x >> i & 1;
22             if (s >> i & 1) {
23                 ret += cnt[tr[p][t]];   // 把第i位为t的数的个数加上,异或后的结果的第i位为0
24                 p = tr[p][!t];  // 走到第i位为!t的节点,异或结果的第i位为1
25             }
26             else {
27                 p = tr[p][t];   // 只能保证异或后的结果的第i位为0,因此走到第i位为t的节点
28             }
29             if (p == 0) return ret; // 无法往下走
30         }
31         return ret + cnt[p];
32     }
33     
34     int countPairs(vector<int>& nums, int low, int high) {
35         int n = nums.size();
36         idx = 0;
37         memset(tr, 0, sizeof(tr));
38         memset(cnt, 0, sizeof(cnt));
39         int ret = 0;
40         for (int i = 0; i < n; i++) {
41             ret += query(nums[i], high) - query(nums[i], low - 1);
42             add(nums[i]);
43         }
44         return ret;
45     }
46 };
posted @ 2023-01-12 16:45  onlyblues  阅读(32)  评论(0编辑  收藏  举报
Web Analytics