[Codechef REBXOR]Nikitosh and xor (Trie,异或)

题目传送门

分析:首次考虑暴力枚举 \(l_{1},r_{1},l_{2},r_{2}\),配合前缀和时间复杂度 \(O(N^{4})\),需要想办法优化。对于这种两段区间不重合的,我们考虑枚举两段区间之间的断点,设 \(max\_{l}[x]\)表示由区间 \([1,x]\)所能得到的区间异或最大值, \(max\_{r}[x]\)表示由区间 \([x,n]\)所能得到的区间异或最大值,那么答案即为 \(\max(max\_l[i]+max\_r[i+1])(i \in [1,n))\)。现在要想办法计算 \(max\_l\)\(max\_r\),考虑更新 \(max\_{l}[x]\),不难得出 \(max\_{l}[x] = \max(max\_l[x-1], \max(a_{i} \oplus a_{i+1} \oplus... \oplus a_{x})(i \in [1,x]))\)\(\max(a_{i} \oplus a_{i+1} \oplus... \oplus a_{x})(i \in [1,x])\)通过 \(Trie\)树和异或前缀和即可求出,不会的话可以看下这道题。然后就可以 \(O(N)\)求解了。

#include<cstdio>
#include<algorithm>
using namespace std;
const int N = 4e5 + 5;

struct Trie{
    int root, id;
    bool bit[32];
	
    struct Node{
        int val, siz, ch[2];
        Node(){  ch[0] = ch[1] = -1,  val = siz = 0; }
    }node[N * 32];

    void get(int x){
        for(int i = 0; i < 32; ++i, x >>= 1)  bit[i] = x & 1;
    }

    void init(){
        for(int i = 0; i <= id; ++i){
            node[i].ch[0] = node[i].ch[1] = -1;
            node[i].val = node[i].siz = 0;
        }
        id = root = 0;
    }

    void insert(int x){
        get(x);       
        int u = root;
        for(int i = 31; i >= 0; --i){
            if(node[u].ch[bit[i]] == -1)  node[u].ch[bit[i]] = ++id;
            u = node[u].ch[bit[i]];
            ++node[u].siz;
        }
        node[u].val = x;
    }

    int find(int x){    // 返回与x异或最大的数
        get(x);
        int u = root;
        for(int i = 31; i >= 0; --i){
            int s1 = node[u].ch[!bit[i]], s2 = node[u].ch[bit[i]];
            if(s1 != -1 && node[s1].siz > 0)  u = s1;
            else if(s2 != -1 && node[s2].siz > 0)  u = s2;
            else  return x;	// 注意根据需要调整返回值 
        }  
        return node[u].val;
    }
}trie;

int n, ans;
int a[N], max_l[N], max_r[N], p[N];

void work(){
    int Xor = 0;
    trie.insert(0);
    for(int i = 1; i <= n; ++i){
        Xor ^= a[i];
        p[i] = max(p[i - 1], trie.find(Xor) ^ Xor);        
        trie.insert(Xor);
    }
}

int main(){
    scanf("%d", &n);
    for(int i = 1; i <= n; ++i)  scanf("%d", &a[i]);
    work();
    for(int i = 1; i <= n; ++i)  max_l[i] = p[i];
    for(int l = 1, r = n; l < r; ++l, --r)  swap(a[l], a[r]);
    trie.init();
    work();
    for(int i = n; i; --i)  max_r[i] = p[n - i + 1];
    for(int i = 1; i < n; ++i)  ans = max(ans, max_l[i] + max_r[i + 1]);
    printf("%d", ans);
    return 0;  
}
posted @ 2020-12-04 15:16  のNice  阅读(143)  评论(0编辑  收藏  举报