CF241B Friends

知识点:01 Trie

原题面:CFLuogu

「十二省联考 2019」异或粽子 是本题的另一数据范围版本。

简述

给定一长度为 \(n\) 的数列 \(a\),求两两异或值前 \(k\) 大的和,答案对 \(10^9 + 7\) 取模。
\(1\le n\le 5\times 10^4\)\(0\le k\le \frac{n(n-1)}{2}\)\(0\le a_i\le 10^9\)
6S,256MB。

分析

两两异或和是有序对,考虑令 \(k\) 翻倍,再使最终答案除 2,以除去有序对的限制。显然这样可以保证原问题中前 \(k\) 大元素都会被统计到两次,且根据异或的自反性,不合法的满足 \(l = r\) 数对贡献为 0,不会影响答案。

要求前 \(k\) 大异或值的和,考虑先求出第 \(k\) 的的异或值。显然可以二分解决。先对 \(a\) 建立 Trie 并在 Trie 上维护 \(\operatorname{size}\),之后二分答案,再枚举所有 \(a_i\),求得有多少个异或值 \(a_i\oplus a_j\) 不小于二分量,检查二分量是否是第 \(k\) 大即可。此过程可以在 Trie 上检查二分量各二进制与 \(a_i\) 各二进制的情况来解决,代码如下所示:

namespace Trie {
  LL QueryCnt(int now_, int lth_, LL val_, LL lim_) {
    if (!now_ && lth_ != 30) return 0; //无贡献的情况
    if (lth_ < 0) return siz[now_]; //深入到底,此时有 val[now_] = lim_
    int ch = 1ll * val_ >> lth_ & 1ll;
    if (lim_ >> lth_ & 1) { //lim_ 这一位是 1,若异或值不小于大于 lim_,则异或值此位必须为 1。
      return QueryCnt(tr[now_][ch ^ 1], lth_ - 1, val_, lim_);
    } else { //lim_ 这一位是 0,这一位是 1 的异或值一定比 lim_ 大,统计贡献 size,并统计异或值此位为 0 的贡献。
      return siz[tr[now_][ch ^ 1]] + 
             QueryCnt(tr[now_][ch], lth_ - 1, val_, lim_);
    }
  }
}
bool Check(LL mid_) {
  LL ret = 0;
  //QueryCnt 能返回所有异或值 a_i xor ? 不小于 mid 的数的个数
  for (int i = 1; i <= n; ++ i) ret += Trie::QueryCnt(0, 30, a[i], mid_);
  return k <= ret;
}

之后考虑求得所有不小于第 \(k\) 大值的异或值的和。仍考虑先枚举再枚举所有 \(a_i\),在 Trie 上检查二分量各二进制与 \(a_i\) 各二进制的情况,得到所有异或值。
观察上面代码可知,当第 \(k\) 大值的某一位是 0 时,说明这一位是 1 的状态为根的子树中所有 \(a_j\)\(a_i\) 的异或值大于第 \(k\) 大值,它们对答案有贡献。问题变为如何快速统计 Trie 的某棵子树中所有数与给定值的异或值的和。
这里有一个小 trick,先对 \(a\) 进行排序,再按照顺序把它们插入 Trie 中,这能使得 Trie 一棵子树中的数是原数列中一段连续的子段。其原理显然,可以类比基数排序。
数列与给定值的异或值的和很好做,考虑拆位,对每个数进行二进制分解,用前缀和分别维护每一位上 0/1 的个数。查询时取出两区间中各位 0/1 的个数,考察与每位不同的数的对数,统计贡献即可。

考虑复杂度,第一部分二分答案 + Trie 上贪心复杂度为 \(O(n\log^2 w)\),第二部分统计贡献 Trie 上贪心 + 拆位统计贡献复杂度也为 \(O(n\log^2 w)\)。其中 \(w = \max\{a_i\} \approx 30\)

代码

//知识点:01 Trie,贪心
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 5e4 + 10;
const LL mod = 1e9 + 7;
const LL inv2 = 500000004;
//=============================================================
LL n, k, kth, ans, a[kN], sum[kN][32];
//=============================================================
inline LL read() {
  LL f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir, int sec) {
  if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
  if (sec < fir) fir = sec;
}
LL Calc(int l_, int r_, int val_) {
  if (!l_ && !r_) return 0;
  LL ret = 0;
  for (int i = 30; i >= 0; -- i) {
    LL cnt[2] = {0};
    cnt[1] = sum[r_][i] - sum[l_ - 1][i], cnt[0] = r_ - l_ + 1 - cnt[1];
    ret = (ret + (1ll << i) % mod * cnt[(val_ >> i & 1) ^ 1] % mod) % mod;
  }
  return ret;
}
namespace Trie {
  const int kMaxNode = kN << 7;
  int node_num, tr[kMaxNode][2];
  int l[kMaxNode], r[kMaxNode];
  LL siz[kMaxNode];
  void Insert(int val_, int id_) {
    int now_ = 0;
    for (int i = 30; i >= 0; -- i) {
      int ch = 1ll * val_ >> i & 1ll;
      if (!tr[now_][ch]) {
        tr[now_][ch] = ++ node_num;
        l[node_num] = id_;
      }
      now_ = tr[now_][ch];
      r[now_] = id_;
      siz[now_] ++;
    }
  }
  LL QueryCnt(int now_, int lth_, LL val_, LL lim_) {
    if (!now_ && lth_ != 30) return 0;
    if (lth_ < 0) return siz[now_];
    int ch = 1ll * val_ >> lth_ & 1ll;
    if (lim_ >> lth_ & 1) {
      return QueryCnt(tr[now_][ch ^ 1], lth_ - 1, val_, lim_);
    } else {
      return siz[tr[now_][ch ^ 1]] + 
             QueryCnt(tr[now_][ch], lth_ - 1, val_, lim_);
    }
  }
  LL QuerySum(int now_, int lth_, LL val_, LL lim_) {
    if (!now_ && lth_ != 30) return 0;
    if (lth_ < 0) return Calc(l[now_], r[now_], val_);
    int ch = 1ll * val_ >> lth_ & 1ll;
    if (lim_ >> lth_ & 1) {
      return QuerySum(tr[now_][ch ^ 1], lth_ - 1, val_, lim_);
    } else {
      return (Calc(l[tr[now_][ch ^ 1]], r[tr[now_][ch ^ 1]], val_) + 
             QuerySum(tr[now_][ch], lth_ - 1, val_, lim_)) % mod;
    }
  }
}
void Init() {
  n = read(), k = read() << 1ll;
  for (int i = 1; i <= n; ++ i) a[i] = read();
  std::sort(a + 1, a + n + 1);
  for (int i = 1; i <= n; ++ i) {
    Trie::Insert(a[i], i);
    for (int j = 30; j >= 0; -- j) {
      sum[i][j] = sum[i - 1][j] + (a[i] >> j & 1);
    }
  }
}
bool Check(LL mid_) {
  LL ret = 0;
  //返回大于 mid 的数的个数
  for (int i = 1; i <= n; ++ i) ret += Trie::QueryCnt(0, 30, a[i], mid_);
  return k <= ret;
}
//=============================================================
int main() {
  Init();
  // Check(1);
  for (LL l = 0, r = 2e9; l <= r; ) {
    int mid = (l + r) >> 1;
    if (Check(mid)) {
      kth = mid;
      l = mid + 1;
    } else {
      r = mid - 1;
    }
  }
  for (int i = 1; i <= n; ++ i) {
    ans = (ans + Trie::QuerySum(0, 30, a[i], kth + 1)) % mod;
    k -= Trie::QueryCnt(0, 30, a[i], kth + 1);
  }
  ans = 1ll * (ans + 1ll * k * kth % mod) % mod * inv2 % mod;
  printf("%lld\n", ans);
  return 0;
}
posted @ 2021-01-27 19:33  Luckyblock  阅读(88)  评论(0编辑  收藏  举报