线段树教做人系列(2)HDU 4867 XOR
题意:给你一个数组a,长度为。有两种操作。一种是改变数组的某个元素的值,一种是满足某种条件的数组b有多少种。条件是:b[i] <= a[i],并且b[1]^b[2]...^b[n] = k的数组有多少种。数组a的元素都小于1000.
思路:因为数很小,我们把数变成二进制数,然后拆分二进制数。比如1101可以拆成10xx,x可为0可为1,有点像数位dp的试填法。我们对每个a[i]存若干个前缀,记录前缀的长度,以及是这个前缀的二进制数有多少个。然后我们合并相邻的区间,直接暴力二重循环,然后合并。其实这个过程更像dp的过程,感觉只是用了线段树的划分成区间,然后合并的思想。
思路参考这两篇博客:https://blog.csdn.net/jtjy568805874/article/details/56488626, https://blog.csdn.net/qian99/article/details/38171951
代码:
#include <bits/stdc++.h> #define pii pair<int, int> #define lowbit(x) (x & (-x)) #define mk make_pair #define ls(x) (x << 1) #define rs(x) ((x << 1) | 1) #define LL long long using namespace std; const LL mod = 1000000007; const int maxn = 20010; int Log[2010], a[maxn]; struct node { pii x; LL tot; bool operator < (const node& rhs) const { return x < rhs.x; } }; vector<node> tr[maxn * 4], tmp; int get(int x, int y) { tr[x].clear(); tr[x].push_back((node){mk(y, 10), 1}); for (int i = y; i; i -= lowbit(i)) { tr[x].push_back((node){mk(i - lowbit(i), 10 - Log[lowbit(i)]), lowbit(i)}); } } void pushup(int x, int l, int r) { tmp.clear(); tr[x].clear(); for (int i = 0; i < tr[l].size(); i++) { for (int j = 0; j < tr[r].size(); j++) { node t1 = tr[l][i], t2 = tr[r][j]; int now = t1.x.first ^ t2.x.first, len = min(t1.x.second, t2.x.second); now = ((now >> (10 - len)) << (10 - len)); tmp.push_back((node){mk(now, len), (t1.tot * t2.tot) % mod}); } } sort(tmp.begin(), tmp.end()); for (int i = 0, j; i < tmp.size(); i = j) { node tmp1 = (node){tmp[i].x, 0}; for (j = i; j < tmp.size() && tmp[i].x == tmp[j].x; j++) { tmp1.tot = (tmp1.tot + tmp[j].tot) % mod; } tr[x].push_back(tmp1); } } void build(int x, int l, int r) { if(l == r) { get(x, a[l]); return; } int mid = (l + r) >> 1; build(ls(x), l ,mid); build(rs(x), mid + 1, r); pushup(x, ls(x), rs(x)); } LL inv(int x) { return x == 1 ? 1 : 1ll * inv(mod % x)*(mod - mod / x) % mod; } void update(int x, int l, int r, int pos, int val) { if(l == r) { get(x, val); return; } int mid = (l + r) >> 1; if(pos <= mid) update(ls(x), l, mid, pos, val); else update(rs(x), mid + 1, r, pos, val); pushup(x, ls(x), rs(x)); } LL query(int x) { LL ans = 0; for (int i = 0; i < tr[1].size(); i++) { node t = tr[1][i]; if((t.x.first ^ x) >> (10 - t.x.second)) continue; ans = (ans + (t.tot * inv(1 << (10 - t.x.second))) % mod) % mod; } return ans; } int main() { int n, m; int x, y; char s[10]; for (int i = 1; i <= 10; i++) Log[1 << i] = i; int T; cin >> T; while(T--) { scanf("%d%d", &n, &m); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); build(1, 1, n); while(m--) { scanf("%s",s + 1); if(s[1] == 'Q') { scanf("%d", &x); printf("%lld\n", query(x)); } else { scanf("%d%d", &x, &y); update(1, 1, n, x + 1, y); } } } }