BZOJ4260 - Codechef REBXOR(trie)

题意

给定一个N个元素的数组,求任意两个不重叠的连续区间的异或和之和的最大值。

思路

先预处理出异或前缀和,然后用trie树维护预处理出每个位置到左端点的最大的区间异或和。反方向也处理一遍。最后枚举中间的分割位置即可。

#include <bits/stdc++.h>

#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define FILE freopen(".//data_generator//in.txt","r",stdin),freopen("res.txt","w",stdout)
#define FI freopen(".//data_generator//in.txt","r",stdin)
#define FO freopen("res.txt","w",stdout)
#define pb push_back
#define mp make_pair
#define seteps(N) fixed << setprecision(N) 
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f

const int N = 5e5 + 10;
const int MX = 1e7 + 10;
const double eps = 1e-5;

int tr[MX][2];
int ct;


void insert(int x) {
    int cur = 0;
    for(int i = 30; i >= 0; i--) {
        bool ntp = (x & (1 << i));
        if(!tr[cur][ntp]) {
            tr[cur][ntp] = ++ct;
            cur = ct;
            tr[cur][0] = tr[cur][1] = 0; 
        } else {
            cur = tr[cur][ntp];
        }
    }
}

int findmax(int x) {
    int cur = 0;
    int res = 0;
    for(int i = 30; i >= 0; i--) {
        bool ntp = (x & (1 << i));
        if(tr[cur][!ntp]) {
            cur = tr[cur][!ntp];
            res ^= (1 << i);
        } else {
            cur = tr[cur][ntp];
        }
    }
    return res;
}

int arr[N];
int pre[N];
int last[N];
int mxp[N];
int mxl[N];

int main() {
    IOS;
    int n;
    cin >> n;
    for(int i = 1; i <= n; i++) {
        cin >> arr[i];
    }
    n++;
    for(int i = 1; i <= n; i++) {
        pre[i] = (pre[i - 1] ^ arr[i]);
    }
    for(int i = n; i >= 1; i--) {
        last[i] = (last[i + 1] ^ arr[i]);
    }
    for(int i = 0; i <= n; i++) {
        insert(pre[i]);
        mxp[i] = max(mxp[i - 1], findmax(pre[i]));
    }   
    ct = 0;
    tr[0][0] = tr[0][1] = 0;
    for(int i = n; i >= 0; i--) {
        insert(last[i]);
        mxl[i] = max(mxl[i - 1], findmax(last[i]));
    } 
    int ans = 0;
    for(int i = 1; i <= n - 1; i++) {
        ans = max(ans, mxp[i] + mxl[i + 1]);
    }
    cout << ans<< endl;
}
posted @ 2020-08-13 22:06  limil  阅读(67)  评论(0编辑  收藏  举报