AtCoder abc 141 F - Xor Sum 3(线性基)

传送门

题意:
给出\(n\)个数\(a_i\),现在要将其分为两堆,使得这两堆数的异或和相加最大。

思路:

  • 考虑线性基贪心求解。
  • 但直接上线性基求出一组的答案是行不通的,原因之后会说。
  • 注意到如果二进制中某一位\(1\)的个数出现了奇数次,那么无论怎么分,都会有一组中这位为\(1\);对于出现偶数次的位,两组中该位都可以有\(1\),或者都没有\(1\)
  • 那么我们只需要贪心地插入二进制\(1\)的个数为偶数的那些位就行了,显然这样能使得最终答案最大。

下面口胡一下为什么不能直接用线性基来搞:
如果贪心地利用线性基直接求出一组答案,假设第\(i\)位二进制出现次数为奇数,那么我们可能就以\(i\)为基底,那么其余偶数位作为基底的"可能性"就降低了,所以我们在插入线性基的时候要避免奇数个数的位,这样能使答案最大。

#include <bits/stdc++.h>
#define fi first
#define se second
#define MP make_pair
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
const int N = 1e5 + 5;

int n;
ll a[N], p[62];
bool chk[62];

void insert(ll x) {
    for(int i = 60; i >= 0; i--) {
        if(chk[i]) continue;
        if(x >> i & 1) {
            if(!p[i]) {
                p[i] = x;
                break;
            }
            x ^= p[i];
        }
    }
}

int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n;
    ll all = 0;
    for(int i = 1; i <= n; i++) cin >> a[i], all ^= a[i];
    for(int i = 0; i <= 60; i++) {
        if(all >> i & 1) chk[i] = 1;
    }
    for(int i = 1; i <= n; i++) insert(a[i]);
    ll ans = 0;
    for(int i = 0; i <= 60; i++) {
        if(chk[i])
        for(int j = 0; j <= 60; j++) {
            if(p[j] >> i & 1) {
                p[j] ^= (1ll << i);
            }
        }
    }
    for(int i = 60; i >= 0; i--) {
        if((p[i] ^ ans) > ans) ans = p[i] ^ ans;
    }
    cout << ans + (ans ^ all);
    return 0;
}

P.S:实现的话可以一开始就将\(a\)数组\(chk\)了的位的值减去,就让这些位不参与运算,写起来能更加简洁。
如下:

Code
#include <bits/stdc++.h>
#define fi first
#define se second
#define MP make_pair
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
const int N = 1e5 + 5;
 
int n;
ll a[N], p[62];
 
void insert(ll x) {
    for(int i = 60; i >= 0; i--) {
        if(x >> i & 1) {
            if(!p[i]) {
                p[i] = x;
                break;
            }
            x ^= p[i];
        }
    }
}
 
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n;
    ll all = 0;
    for(int i = 1; i <= n; i++) cin >> a[i], all ^= a[i];
    for(int i = 0; i <= 60; i++) {
        if(all >> i & 1) {
            for(int j = 1; j <= n; j++) {
                if(a[j] >> i & 1) a[j] -= (1ll << i);
            }
        }
    }
    for(int i = 1; i <= n; i++) insert(a[i]);
    ll ans = 0;
    for(int i = 60; i >= 0; i--) {
        if((p[i] ^ ans) > ans) ans = p[i] ^ ans;
    }
    cout << ans + (ans ^ all);
    return 0;
}
posted @ 2019-09-16 14:04  heyuhhh  阅读(623)  评论(0编辑  收藏  举报