YBTOJ [树状数组] 二进制
哇咔咔,此乃真好题!
Solution
首先不带 \(+x\) 的做法,相信大家都会,维护一下全局二进制每一位 \(1\) 的个数,把 \(y\) 二进制拆分一下,就知道答案了。
这个 \(+x\) 真滴很恶心啊!
考虑这样一个事实,非常滴实用:
对于一个 \(x\) \(and\) \(2^k\) 这个东西的答案是有周期性的!也就是说从 \(0\) 开始会有 \(2^k\) 个 \(0\), 然后是 \(2^k\) 个 \(2^k\) 因为这第 \(k\) 位变成了 \(1\) ,从二进制加法的角度就非常好理解。然后呢这个东西就等价于 \(x\) \(and\) \(2^k\) 有值当且仅当 \(x \mod{2^{k+1}} \in \left [2^k,2^{k+1} - 1 \right ]\) 对吧这代表了循环节。
所以啊!我们的答案就变成了每个二进制位上的询问,分开考虑。
问题回到 \(+x\) 上,我们对那个式子进行推导:
所以我们可以开 \(20\) 个树状数组,第 \(i\) 个树状数组下标位 \(j\) 的位置代表全局的 \(a[i] \mod {2^{i+1}}\) 有多少值为 \(j-1\) 的(因为树状数组下标从 \(1\) 开始嘛)。所以树状数组相当于对前面的值域做了个前缀和,然后根据我们推导的式子就变成了前缀和相减的形式。
然后经过取模我们可以避免负数的问题,但是会出现 \(r > l-1\) 的问题,但是没关系啊!我们是 $$循环节$$ 所以就变成左区间到 \(r\) 和 \(l\)到右区间的答案,都可以快速统计。
对于修改操作,我们扣掉原来的,加进去新的即可,简单易懂。
修改复杂度 \(O(\log^2{n})\)
查询复杂度 \(O(\log^2{n})\)
总复杂度 \(O(n\log^2{n} + m\log^2{n})\)
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define Rep(i,a,b) for(int i=(a);i<(b);++i)
#define rrep(i,a,b) for(int i=(a);i>=(b);--i)
using namespace std;
template <typename T>
inline void read(T &x){
x=0;char ch=getchar();bool f=0;
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=-x;
}
template <typename T,typename ...Args>
inline void read(T &tmp,Args &...tmps){read(tmp);read(tmps...);}
const int N = (1<<20) + 5;
const int M = 1e5 + 5;
int n,m,a[M];
struct BIT{
int c[N],len;
inline int lowbit(int x){return x & (-x);}
inline void upd(int x,int y){
++x;
for(;x<=len;x+=lowbit(x))c[x] += y;
}
inline int query(int x){
++x;
int res = 0;
for(;x;x-=lowbit(x))res += c[x];
return res;
}
}T[22];
#define f(x) ((1<<(x+1)) - 1)
#define p(x) (1<<(x))
signed main(){
read(n,m);
rep(i,1,n){
read(a[i]);
rep(j,0,19){
T[j].len = (1<<(j+1));
T[j].upd(a[i]%p(j+1),1);
}
}
while(m--){
int op,x,y;
read(op,x,y);
if(op == 1){
rep(j,0,19){
T[j].upd(a[x]%p(j+1),-1);
}
a[x] = y;
rep(j,0,19)T[j].upd(a[x]%p(j+1),1);
}
else{
long long ans = 0;
rep(j,0,19){
if(y&(1<<j)){
int l = (1<<j),r = f(j);
l = (l - 1 - x + p(20)) % p(j+1);
r = (r - x + p(20)) % p(j+1);
if(l <= r)ans += 1ll * (1<<j) * (T[j].query(r) - T[j].query(l));
else ans += 1ll * (1<<j) * (T[j].query(T[j].len-1) + T[j].query(r) - T[j].query(l));
}
}
printf("%lld\n",ans);
}
}
}