并不对劲的loj2251:p3688[ZJOI2017]树状数组
题目大意
有人用错误的树状数组维护长度为\(n\)的01串\(a_1,...,a_n\)的区间异或和。
错误的树状数组:
void add(int x){for(;x;x-=lowbit(x))tr[x]^=1;return;}
int ask(int x){int k=0;for(;x<=n;x+=lowbit(x))k^=tr[x];return k;}
int query(int l,int r){return ask(l-1)^ask(r);}
有\(m\)次操作,操作分两种:
1.给定\(l,r\),在\(a_l,a_{l+1},...,a_r\)中随机选一个修改;
2.给定\(l,r\),问求\(a_l,a_{l+1},...,a_r\)的异或和,答案正确的概率是多少;
\(n,m\leq 10^5;\)
题解
发现这树状数组的ask求的是后缀异或和,所以query求的是:
\(l=1\)时,\(ask(l-1)=0\),query求的是\(a_r,...,a_n\)的异或和;\(l\neq 1\)时,\(ask(l-1)\bigoplus ask(r)=a_{l-1},...,a_{r-1}的异或和\)。
\(l=1\)时,答案正确的概率=\((a_r\bigoplus...\bigoplus a_n=a_1\bigoplus ...\bigoplus a_r)\)的概率=\(((a_r\bigoplus...\bigoplus a_n)\bigoplus (a_1\bigoplus ...\bigoplus a_r)=0)\)的概率。由\((a_r\bigoplus...\bigoplus a_n)\bigoplus (a_1\bigoplus ...\bigoplus a_r)=a_r\bigoplus(a_1\bigoplus...\bigoplus a_n)\),可知答案正确的概率=\((a_r\bigoplus...\bigoplus(a_1\bigoplus...\bigoplus a_n))=0\)的概率。
\(l\neq 1\)时,答案正确的概率=\((a_{l-1}=a_r)\)的概率。
可以用第一维是左端点、第二维是右端点、存左右端点相等的概率的树套树维护。
代码
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
#define rep(i,x,y) for(register int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(register int i=(x);i>=(y);--i)
#define view(u,k) for(int k=fir[u];~k;k=nxt[k])
#define maxn 100007
#define maxnd 20000007
#define LL long long
#define ls (u<<1)
#define rs (u<<1|1)
#define mi ((l+r)>>1)
#define lc son[u][0]
#define rc son[u][1]
using namespace std;
int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')f=-1,ch=getchar();
while(isdigit(ch))x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return x*f;
}
void write(int x)
{
if(x==0){putchar('0'),putchar('\n');return;}
int f=0;char ch[20];
if(x<0)putchar('-'),x=-x;
while(x)ch[++f]=x%10+'0',x/=10;
while(f)putchar(ch[f--]);
putchar('\n');
return;
}
int n,m,son[maxnd][2],mk[maxnd],cntnd,rt[maxn<<2],res,cnt;
const int mod=998244353;
int mo(int x){if(x<0)return x+mod;return x>=mod?x-mod:x;}
int merge(int x,int y){return mo((LL)x*mo(1-y)%mod+(LL)y*mo(1-x)%mod);}
int mul(int x,int y){int res=1;while(y){if(y&1)res=(LL)res*x%mod;x=(LL)x*x%mod,y>>=1;}return res;}
void add2(int & u,int l,int r,int l2,int r2,int k)
{
if(!u)u=++cntnd;
if(l2<=l&&r<=r2){mk[u]=merge(mk[u],k);return;}
if(l2<=mi)add2(lc,l,mi,l2,r2,k);
if(r2>mi)add2(rc,mi+1,r,l2,r2,k);
return;
}
void add1(int u,int l,int r,int l1,int r1,int l2,int r2,int k)
{
if(l1<=l&&r<=r1){add2(rt[u],1,n,l2,r2,k);return;}
if(l1<=mi)add1(ls,l,mi,l1,r1,l2,r2,k);
if(r1>mi)add1(rs,mi+1,r,l1,r1,l2,r2,k);
return;
}
void ask2(int u,int l,int r,int x2)
{
if(!u)return;
res=merge(res,mk[u]);
if(l==r)return;
if(x2<=mi)ask2(lc,l,mi,x2);
else ask2(rc,mi+1,r,x2);
}
void ask1(int u,int l,int r,int x1,int x2)
{
ask2(rt[u],1,n,x2);
if(l==r)return;
if(x1<=mi)ask1(ls,l,mi,x1,x2);
else ask1(rs,mi+1,r,x1,x2);
return;
}
int main()
{
n=read(),m=read();
while(m--)
{
int f=read(),l=read(),r=read();
if(f==1)
{
if(l<r)add1(1,1,n,l+1,r+1,l,r,mo(2*mul(r-l+1,mod-2)));
if(r+1<=n)add1(1,1,n,l+1,r+1,r+1,n,mul(r-l+1,mod-2));
add1(1,1,n,1,l,l,r,mul(r-l+1,mod-2));cnt++;
}
else
{
res=0;ask1(1,1,n,l,r);
if(l==1)write((cnt&1)?res:mo(1-res));
else write(mo(1-res));
}
}
return 0;
}