luogu4883 mzf的考验
题目描述:
题解:
当然splay。
区间翻转是基本操作。
区间异或?按套路记录区间内每一位$1$的个数,异或的时候按位取反即可。
区间查询同理。
因为要按位维护,所以复杂度多了个log。
不开O2只有30,开O2能过。
代码:
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; typedef long long ll; const int N = 100050; template<typename T> inline void read(T&x) { T f = 1,c = 0;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){c=c*10+ch-'0';ch=getchar();} x = f*c; } int n,m,a[N]; struct Splay { int fa[N],ch[N][2],v[N],siz[N],tag[N],mp[N][22],rt; bool res[N]; inline void rever(int u){swap(ch[u][0],ch[u][1]);res[u]^=1;} inline void add(int u,int k) { if(!u)return ; tag[u]^=k;v[u]^=k; for(int i=0;i<20;i++)if(k&(1<<i)) mp[u][i]=siz[u]-mp[u][i]; } inline void update(int u) { siz[u] = siz[ch[u][0]]+siz[ch[u][1]]+1; for(int i=0;i<20;i++) mp[u][i]=mp[ch[u][0]][i]+mp[ch[u][1]][i]+((v[u]>>i)&1); } inline void pushdown(int u) { if(tag[u]) { add(ch[u][0],tag[u]); add(ch[u][1],tag[u]); tag[u]=0; } if(res[u]) { rever(ch[u][0]); rever(ch[u][1]); res[u]=0; } } void rotate(int x) { int y = fa[x],z = fa[y],k = (ch[y][1]==x); ch[z][ch[z][1]==y] = x,fa[x] = z; ch[y][k] = ch[x][!k],fa[ch[x][!k]] = y; ch[x][!k] = y,fa[y] = x; update(y),update(x); } void down(int x) { if(fa[x])down(fa[x]); pushdown(x); } void splay(int x,int goal) { down(x); while(fa[x]!=goal) { int y = fa[x],z = fa[y]; if(z!=goal) (ch[y][1]==x)^(ch[z][1]==y)?rotate(x):rotate(y); rotate(x); } if(!goal)rt = x; } int build(int l,int r,int f) { if(l>r)return 0; int mid = (l+r)>>1; ch[mid][0] = build(l,mid-1,mid); ch[mid][1] = build(mid+1,r,mid); fa[mid] = f,v[mid] = a[mid-1]; update(mid); return mid; } int get_kth(int u,int k) { pushdown(u); int tmp = siz[ch[u][0]]; if(k<=tmp)return get_kth(ch[u][0],k); else if(k==tmp+1)return u; else return get_kth(ch[u][1],k-1-tmp); } void rvs(int l,int r) { int lp = get_kth(rt,l),rp = get_kth(rt,r+2); splay(lp,0),splay(rp,lp); rever(ch[rp][0]); } void ins(int l,int r,int d) { int lp = get_kth(rt,l),rp = get_kth(rt,r+2); splay(lp,0),splay(rp,lp); add(ch[rp][0],d); } ll query(int l,int r) { int lp = get_kth(rt,l),rp = get_kth(rt,r+2); splay(lp,0),splay(rp,lp); ll ret = 0; for(int i=0;i<20;i++) ret+=(1ll<<i)*mp[ch[rp][0]][i]; return ret; } }tr; int main() { read(n),read(m); for(int i=1;i<=n;i++)read(a[i]); tr.rt=tr.build(1,n+2,0); for(int op,l,r,d,i=1;i<=m;i++) { read(op),read(l),read(r); if(op==1)tr.rvs(l,r); else if(op==2){read(d);tr.ins(l,r,d);} else printf("%lld\n",tr.query(l,r)); } return 0; }