CF241B Friends 题解

异或粽子 的超级加强版,但是本题因为 \(m\) 很大,不能套用那一题的解法。


换一种思路:考虑把 \(a_i\) 从高位到低位插入 0-1 Trie 之后,二分第 \(m\) 大,通过第 \(m\) 大求答案。

对于二分的一个值 \(x\),枚举每个位置 \(i\),在 0-1 Trie 上找与 \(a_i\) 异或值大于等于 \(x\) 的值的个数。

类比求最大异或和的过程,考虑搜索到第 \(j\) 位。如果 \(x\) 的第 \(j\) 位为 \(1\),为了最终异或值大于等于 \(x\),可能的数一定在与 \(a_i\) 的第 \(j\) 位相异的子树中,递归即可;反之,如果 \(x\) 的第 \(j\) 位为 \(0\),与 \(a_i\) 的第 \(j\) 位相异的子树中的值一定全部满足条件,递归与 \(a_i\) 的第 \(j\) 位相同的子树即可。

于是可以在 \(O(n \log^2 w)\) 的时间内找到第 \(m\) 大的异或值(\(w\) 为值域,后文同),设这个值为 \(k\)


下一步是求前 \(m\) 大两两异或值的和。

容易想到,类似前文所述,只需要处理被计算的完整的子树与 \(a_i\) 的异或值的和。(即搜索时找到的 \(k\) 的第 \(j\) 位为 \(0\),与 \(a_i\) 的第 \(j\) 位相异的子树)直接对这些子树的根节点打标记,整体遍历一次 0-1 Trie 时容易得到这棵子树内每一位上 \(0\)\(1\) 的数量,答案也就容易统计了。

至多有 \(n \log w\) 个标记,处理每个标记需要枚举 \(\log w\) 位。同时,至多合并 \(O(n \log w)\) 次,单次合并的时间为 \(O(\log w)\)。综上,时间复杂度 \(O(n \log^2 w)\)


将两部分拼起来就得到了最终做法,时间复杂度 \(O(n \log^2 w)\),可以通过。

#include <iostream>
#include <map>
#include <vector>

using namespace std;

typedef long long ll;

const ll mod = 1e9 + 7;

static inline ll qpow(ll a, ll b) {
    ll ret = 1;
    while (b) {
        if (b & 1)
            ret = ret * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ret;
}
const ll _2 = qpow(2, mod - 2);

int n, m;
int a[50005];

int rt;
map<int, vector<int>> mp;
struct trie {
    int son[2];
    int val, cnt[32];
} tr[1600005];
int cnt;
static inline int insert(int x, int dep, int p) {
    if (!p)
        p = ++cnt;
    if (dep == -1) {
        ++tr[p].val;
        return p;
    }
    int tag = (x >> dep) & 1;
    tr[p].son[tag] = insert(x, dep - 1, tr[p].son[tag]);
    tr[p].val = tr[tr[p].son[0]].val + tr[tr[p].son[1]].val;
    return p;
}
static inline ll query(int x, int aim, int dep, int p) {
    if (!p)
        return 0;
    if (dep == -1)
        return tr[p].val;
    int tg = (x >> dep) & 1;
    int tg2 = (aim >> dep) & 1;
    if (tg2)
        return query(x, aim, dep - 1, tr[p].son[tg ^ 1]);
    return tr[tr[p].son[tg ^ 1]].val + query(x, aim, dep - 1, tr[p].son[tg]);
}

ll ans, sum;
static inline void query2(int x, int aim, int dep, int p) {
    if (!p)
        return;
    if (dep == -1)
        return;
    int tg = (x >> dep) & 1;
    int tg2 = (aim >> dep) & 1;
    if (tg2)
        return query2(x, aim, dep - 1, tr[p].son[tg ^ 1]);
    if (tr[p].son[tg ^ 1])
        mp[tr[p].son[tg ^ 1]].push_back(x);
    query2(x, aim, dep - 1, tr[p].son[tg]);
}
static inline void dfs(int dep, int p, int from, int val) {
    if (!p)
        return;
    if (dep == -1) {
        if (from)
            tr[p].cnt[dep + 1] += tr[p].val;
        if (mp.count(p))
            for (auto x : mp[p]) {
                sum += tr[p].val;
                ll tot = 0;
                for (int i = 0; i <= dep + 1; ++i) {
                    int tg = (x >> i) & 1;
                    if (tg)
                        tot += (ll)(tr[p].val - tr[p].cnt[i]) * (1ll << i) % mod;
                    else
                        tot += (ll)tr[p].cnt[i] * (1ll << i) % mod;
                }
                for (int i = dep + 2; i <= 30; ++i) {
                    int tg = (x >> i) & 1;
                    int tg2 = (val >> i) & 1;
                    if (tg ^ tg2)
                        tot += (ll)tr[p].val * (1ll << i) % mod;
                }
                ans = (ans + tot) % mod;
            }
        return;
    }
    if (tr[p].son[0]) {
        dfs(dep - 1, tr[p].son[0], 0, val);
        for (int i = 0; i <= 30; ++i)
            tr[p].cnt[i] += tr[tr[p].son[0]].cnt[i];
    }
    if (tr[p].son[1]) {
        dfs(dep - 1, tr[p].son[1], 1, val | (1 << dep));
        for (int i = 0; i <= 30; ++i)
            tr[p].cnt[i] += tr[tr[p].son[1]].cnt[i];
    }
    if (from)
        tr[p].cnt[dep + 1] += tr[p].val;
    if (mp.count(p))
        for (auto x : mp[p]) {
            sum += tr[p].val;
            ll tot = 0;
            for (int i = 0; i <= dep + 1; ++i) {
                int tg = (x >> i) & 1;
                if (tg)
                    tot += (ll)(tr[p].val - tr[p].cnt[i]) * (1ll << i) % mod;
                else
                    tot += (ll)tr[p].cnt[i] * (1ll << i) % mod;
            }
            for (int i = dep + 2; i <= 30; ++i) {
                int tg = (x >> i) & 1;
                int tg2 = (val >> i) & 1;
                if (tg ^ tg2)
                    tot += (ll)tr[p].val * (1ll << i) % mod;
            }
            ans = (ans + tot) % mod;
        }
}

static inline bool check(int x) {
    ll tot = 0;
    for (int i = 1; i <= n; ++i)
        tot += query(a[i], x, 30, rt);
    tot >>= 1;
    return tot >= m;
}

static inline void solve() {
    cin >> n >> m;
    if (m == 0) {
        cout << 0 << endl;
        return;
    }
    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
        rt = insert(a[i], 30, rt);
    }
    int l = 1;
    int r = 1ll << 30;
    int ret = -1;
    while (l <= r) {
        int mid = (int)((l + r) >> 1);
        if (check(mid)) {
            ret = mid;
            l = mid + 1;
        } else {
            r = mid - 1;
        }
    }
    for (int i = 1; i <= n; ++i)
        query2(a[i], ret, 30, rt);
    dfs(30, rt, 0, 0);
    sum >>= 1;
    ans = ans * _2 % mod;
    ans = (ans + (m - sum) * ret % mod) % mod;
    cout << ans << endl;
}

signed main() {
#ifndef ONLINE_JUDGE
    freopen("1.in", "r", stdin);
#endif
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    solve();
    return 0;
}
posted @ 2024-08-24 22:47  bluewindde  阅读(4)  评论(0编辑  收藏  举报