9.22 xor三元组计数

题意

给定一个长为\(N\)的序列\(a\),请求出所有满足下列条件的三元组\(<x,y,z>\)

  • \(1\leq x < y < z \leq N\)
  • \(a_x \oplus a_y < a_y \oplus a_z\)

这里的\(\oplus\)运算指按位异或


解法

异或\(+\)序列问题\(\to\) 01Trie

很自然的想到枚举\(y\),对\(1\to y-1\)\(y+1\to N\)分别维护一颗Trie,计算符合条件的\(<x,z>\)二元组个数

我们可以发现,如果按位考虑的话,比较\(x,\)\(z\)\(y\)的异或值大小,实际上起到决定作用的是\(x,z\)的二进制位从高到低第一个不同的位

暴力Trie上DFS的复杂度显然是不对的,可能达到指数级别

由于动态维护Trie每次只增删一条链,考虑统计这条链所带来的影响

\(f[i][0/1]\)为从高位到低位考虑,前\(i-1\)位相同,第\(i\)位不同的\(<x,z>\)对数。其中,第二维为\(0\)代表\(x\)的第\(i\)位为\(0\)\(z\)的第\(i\)位为\(1\)(所以我们的Trie是由高位到低位建立的)

这样我们只需在Trie上动态维护\(f\)数组,统计答案时直接用\(f\)数组更新即可


代码

#include <cstdio>
#include <cctype>
#include <cstring>

using namespace std;

const int MAX_N = 1e5 + 10;
const int lg = 30;

int read();

int a[MAX_N];

long long ans;
long long f[lg + 1][2];

struct Trie {
	
	int root, cnt;
	
	struct node {
		int sum;
		int ch[2];	
	} t[MAX_N * lg];
	
	void clear() {
		for (int i = 1; i <= cnt; ++i)  
			t[i].ch[0] = t[i].ch[1] = t[i].sum = 0;
		cnt = root = 1;
	}
	
	void ins(int x) {
		int p = root;
		for (int i = lg; i >= 0; --i) {
			int c = x >> i & 1;
			if (!t[p].ch[c])
				t[p].ch[c] = ++cnt;
			p = t[p].ch[c];	
			t[p].sum++;
		}
	}
	
	void era(int x) {
		int p = root;
		for (int i = lg; i >= 0; --i) {
			int c = x >> i & 1;
			p = t[p].ch[c];
			t[p].sum--;
		}
	}
		
} tr_a, tr_b;

// f[x][0] : pre 0 suf 1
// f[x][1] : pre 1 suf 0
void add(int x) {
	int p = 1;
	for (int i = lg; i >= 0; --i) {
		int c = x >> i & 1;
		f[i][c] += tr_b.t[tr_b.t[p].ch[c ^ 1]].sum;
		p = tr_b.t[p].ch[c];
	}
}

void del(int x) {
	int p = 1;
	for (int i = lg; i >= 0; --i) {
		int c = x >> i & 1;
		f[i][c ^ 1] -= tr_a.t[tr_a.t[p].ch[c ^ 1]].sum;
		p = tr_a.t[p].ch[c];
	}
}

int main() {
	
//	freopen("xyz.in", "r", stdin);
//	freopen("xyz.out", "w", stdout);
	
	int T = read();
	
	while (T--) {
		
		int N = read();
		for (int i = 1; i <= N; ++i)  a[i] = read();
		
		ans = 0; 
		tr_a.clear(), tr_b.clear();
		memset(f, 0, sizeof f);
		
		for (int i = 2; i <= N; ++i)  tr_b.ins(a[i]);
		
		for (int i = 2; i < N; ++i) {
			del(a[i]), tr_b.era(a[i]);
			add(a[i - 1]), tr_a.ins(a[i - 1]);
			for (int j = lg; j >= 0; --j) 
				ans += f[j][a[i] >> j & 1];
		}
		
		printf("%lld\n", ans);
	}
	
	return 0;
}

int read() {
	int x = 0, c = getchar();
	while (!isdigit(c))  c = getchar();
	while (isdigit(c))   x = x * 10 + c - 48, c = getchar();
	return x;
}
posted @ 2019-10-02 19:47  四季夏目天下第一  阅读(144)  评论(1编辑  收藏  举报