异或与区间加题解
异或与区间加题解
简要题意
给定 \(n,m,K,a_{1...n}\),和 \(m\) 个三元组 \((x_i,y_i,z_i)\),定义 \(calc(l,r)=a_l \bigoplus a_{l+1}\bigoplus ...\bigoplus a_r\)。对于每个三元组 \((x,y,z)\) ,对所有满足 \(x\le l\le r\le y\ ,\ calc(l,r)=K\) 的区间 \((l,r)\) 内的每个数 \(b_i\) 加上 \(z\),其中 \(b_{1..n}\) 初始全为 0。输出对 \(2^{30}\) 取模。
\(0\le K,a_i<2^{30},1\le x\le y\le n,0\le z\le 10000\)
10 10 3//n m K
2 0 3 0 1 0 0 2 1 2//a[i]
1 10 1//x y z
3 10 9
10 10 5
4 10 10
9 10 8
7 7 8
3 5 10
7 8 9
7 9 7
7 8 7
1 4 54 53 52 72 99 126 114 39
题解
先来一个暴力的方法。首先容易想到对 \(a\) 求一遍前缀和,将 \(calc(l,r)=K\) 转化为 \(sum_r\bigoplus sum_{l-1}=K\)。将每个三元组按关键字排序(先x后y),然后从前往后扫描每一个区间。然后开一个树状数组,令 \(c_{x..y}\) 加上 \(z\)。\(c_{pos}\) 表示:对于每个右端点位于 \(pos\) 的区间 \((l,r)\),应对的 \(b_{l...r}\) 需要加上 \(c_{pos}\)。但是这样可能会把 \(l<x\) 的区间也进行操作,所以我们应该从前往后扫描每一个位置。在扫描到位置 \(l\) 的时候,如果发现存在三元组满足 \(x=l\),那么我们令 \(c_{x..y}\) 加上 \(z\)。处理完 \(c\) 以后,找到满足 \(sum_r\bigoplus sum_{l-1}=K\) 的 \(r\)(可以利用map找),然后再将 \(b_{l...r}\) 加上 \(c_{r}\),这一个区间加可以用差分处理。
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int mod=1<<30;
int n,m,K,a[150010],sum[150010];
LL c[150010],cc[150010];
unordered_map<int,vector<int>>mp;
struct SYZ
{int x,y,z;}syz[150010];
inline int read()
{
int x=0,w=0;char ch=0;
while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return w?-x:x;
}
bool cmp(SYZ n1,SYZ n2)
{
if(n1.x^n2.x)return n1.x<n2.x;
return n1.y<n2.y;
}
void change(int x,int y)
{for(;x<=n;x+=x&-x)c[x]+=y;}
int ask(int x,int y=0)
{for(;x;x-=x&-x)y+=c[x];return y;}
int main()
{
n=read();m=read();K=read();
mp[0].push_back(0);
for(int i=1;i<=n;i++)
mp[sum[i]=sum[i-1]^(a[i]=read())].push_back(i);
for(int i=1;i<=m;i++){
int x=read(),y=read(),z=read();
syz[i]=(SYZ){x,y,z};
}
sort(syz+1,syz+1+m,cmp);
for(auto&x:mp)//x.first是键,x.second是值
reverse(x.second.begin(),x.second.end());
for(int i=1,j=1;i<=n;i++){//i是左端点
while(j<=m&&syz[j].x==i)
change(1,syz[j].z),change(syz[j].y+1,-syz[j].z),j++;
for(int x:mp[sum[i-1]^K]){//x是右端点
if(x<i)break;
int temp=ask(x);
cc[i]+=temp;
cc[x+1]-=temp;
}
}
for(int i=1;i<=n;i++)
printf("%lld%c",(cc[i]+=cc[i-1])%=mod," \n"[i==n]);
}
这里的 \(cc\) 是 \(b\) 的差分数组。
此方法的瓶颈在于:对于一个 \(l\) ,满足条件的 \(r\) 可能会非常多。
我们可以在当 \(r\) 的数量小于 \(\sqrt{n}\) 时用上述方法,当 \(r\) 数量过多时需要换一种方法。值得注意的是,这样不同的 \(sum_r\) 不会超过 \(\sqrt{n}\) 个。
我们不妨对每一个这样的 \(sum_r\) 单独处理,我们先暴力找到所有的 \(l\) 和 \(r\) ,利用前缀和可以计算出区间 \(xx,yy\) 内有多少个 \(l\) 或 \(r\)。
设 \(prez_i\) 表示:满足 \(x\le i\le y\) 的所有三元组的 \(z\) 的和,利用差分可以快速求出。
我们要分别扫描所有的 \(l\),\(r\) ,扫描 \(l\) 时,让 \(cc_l\) 加上一些东西;扫描 \(r\) 时,让 \(cc_{r+1}\) 减去一些东西。
我们不妨先考虑一个弱化版本,即:\(y=n\)。如果此方法可行的话,我们可以试图将 \((x,y,z)\) 拆分成 \((x,n,z)\) 和 \((y+1,n,-z)\)。注意,直接拆分会错误地统计上这样的区间:\(x\le l\le y<r\)。我们需要额外的操作减去这样的贡献。
显然,对于 \((x,y,z)\) 只需要 \(l\ge x\) 即可。我们可以扫描每一个 \(l\) ,计算 \(x\le l\) 的 \(z\) 的和,以及 \(r\) 的数量。前者即是 \(prez_l\),后者用差分统计即可。有:\(cc_l+=cnt(r)*prez_l\)。同样地,对于 \(r\) 我们沿用类似的方法,但稍微麻烦点。对于 \(cc_{r+1}\) 我们要减去的是:每一个 \(x\le l\) 对应的 \(z\)。这个可以一遍扫描一遍统计,初始令 \(temp=0\) ,从左往右扫描时,如果 \(pos\) 是左端点,则令 \(temp+=prez_{pos}\)。当扫描到一个右端点 \(r\) 时,令 \(cc_{r+1}-=temp\)。意思是遇到一个左端点 \(l\),那么它后面的右端点统统加上它左边的三元组的 \(z\) (即 \(prez_l\))。
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int mod=1<<30;
int n,m,K,B,a[150010],sum[150010];
int sX[150010],sY[150010];
LL c[150010],cc[150010],prez[150010];
unordered_map<int,vector<int>>mp;
struct SYZ
{int x,y,z;}syz[150010];
inline int read()
{
int x=0,w=0;char ch=0;
while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return w?-x:x;
}
bool cmp(SYZ n1,SYZ n2)
{
if(n1.x^n2.x)return n1.x<n2.x;
return n1.y<n2.y;
}
void change(int x,int y)
{for(;x<=n;x+=x&-x)c[x]+=y;}
int ask(int x,int y=0)
{for(;x;x-=x&-x)y+=c[x];return y;}
void solve(int Y)//Y=sum[r]
{
int X=K^Y;//X=sum[l-1]
for(int i=1;i<=n;i++){
sX[i]=sX[i-1]+(sum[i-1]==X);
sY[i]=sY[i-1]+(sum[i]==Y);
}
for(int i=1;i<=n;i++)
if(sum[i-1]==X)
cc[i]+=prez[i]*(sY[n]-sY[i-1]);
LL temp=0;
for(int i=1;i<=n;i++){
if(sum[i-1]==X)temp+=prez[i];
if(sum[i]==Y)cc[i+1]-=temp;
}
}
int main()
{
n=read();m=read();K=read();B=sqrt(n);
mp[0].push_back(0);
for(int i=1;i<=n;i++)
mp[sum[i]=sum[i-1]^(a[i]=read())].push_back(i);
for(int i=1;i<=m;i++){
int x=read(),y=read(),z=read();
syz[i]=(SYZ){x,y,z};
prez[x]+=z;prez[y+1]-=z;
}
for(int i=1;i<=n;i++)
prez[i]+=prez[i-1];
sort(syz+1,syz+1+m,cmp);
for(auto&x:mp)//x.first是键,x.second是值
reverse(x.second.begin(),x.second.end());
for(int i=1,j=1;i<=n;i++){//i是左端点
while(j<=m&&syz[j].x==i)
change(1,syz[j].z),change(syz[j].y+1,-syz[j].z),j++;
if(mp[sum[i-1]^K].size()<B)
for(int x:mp[sum[i-1]^K]){//x是右端点
if(x<i)break;
int temp=ask(x);
cc[i]+=temp;
cc[x+1]-=temp;
}
}
for(auto&x:mp)
if(x.second.size()>=B)
solve(x.first);
for(int i=1;i<=n;i++)
printf("%lld%c",(cc[i]+=cc[i-1])%=mod," \n"[i==n]);
}
现在我们考虑怎么样拆分一个三元组,以及如何处理错误统计的 \(l,r\) 。
对于 \((x,y,z)\) ,在处理 \(cc_l\) 时,我们不想让 \(y<r\) 的那些区间统计上 \(z\)。我么需要新开一个数组,统计上需要减去的这些 \(z\)。(此时树状数组的 \(c\) 数组已经没用了我们不如再次利用 \(c\))我们令 \(c_x+=z*cnt(r)\),这些 \(r\) 要满足 \(r>y\)。同时令 \(c_{y+1}-=z*cnt(r)\)。这个意思是:在扫描到 \(l\ge x\) 的 \(l\) 时,统计的答案要减去 \(c_x\),因为多出来的 \(r\) 不应该统计上去。统计到 \(y\) 后面的 \(l\) 时,不用减去这些了,因为本来就没有统计上(不明白为什么没统计上的话,可以看一下\(prez\))。
在处理 \(cc_{r+1}\) 时,我们应当减去 \(x\le l\le y<r\) 对应的 \(z\)。我们在扫描三元组 \((x,y,z)\) 时,令 \(c[y+1]+=z*cnt(l)\),这些 \(l\) 满足 \(x\le l \le y\)。意思是,扫描到 \(r\ge y+1\) 的 \(r\) 时,统计的答案要少减去 \(c[y+1]\)。因为,前面对应的那些三元组,贡献要减少 \(c[y+1]\) ,因为 \(r\) 越界了,那些 \(l\) 不会和这个 \(r\) 产生贡献。
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int mod=1<<30;
int n,m,K,B,a[150010],sum[150010];
int sX[150010],sY[150010];
LL c[150010],cc[150010],prez[150010];
unordered_map<int,vector<int>>mp;
struct SYZ
{int x,y,z;}syz[150010];
inline int read()
{
int x=0,w=0;char ch=0;
while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return w?-x:x;
}
bool cmp(SYZ n1,SYZ n2)
{
if(n1.x^n2.x)return n1.x<n2.x;
return n1.y<n2.y;
}
void change(int x,int y)
{for(;x<=n;x+=x&-x)c[x]+=y;}
int ask(int x,int y=0)
{for(;x;x-=x&-x)y+=c[x];return y;}
void solve(int Y)//Y=sum[r]
{
int X=K^Y;//X=sum[l-1]
for(int i=1;i<=n;i++){
sX[i]=sX[i-1]+(sum[i-1]==X);
sY[i]=sY[i-1]+(sum[i]==Y);
}
memset(c,0,sizeof c);
for(int i=1;i<=m;i++){
int x=syz[i].x,y=syz[i].y,z=syz[i].z;
c[x]+=1ll*z*(sY[n]-sY[y]);
c[y+1]-=1ll*z*(sY[n]-sY[y]);
}
for(int i=1;i<=n;i++){
c[i]+=c[i-1];
if(sum[i-1]==X)
cc[i]+=prez[i]*(sY[n]-sY[i-1])-c[i];
}
memset(c,0,sizeof c);
for(int i=1;i<=m;i++){
int x=syz[i].x,y=syz[i].y,z=syz[i].z;
c[y+1]+=1ll*z*(sX[y]-sX[x-1]);
}
LL temp=0;
for(int i=1;i<=n;i++){
c[i]+=c[i-1];
if(sum[i-1]==X)temp+=prez[i];
if(sum[i]==Y)
cc[i+1]-=temp-c[i];
}
}
int main()
{
n=read();m=read();K=read();B=sqrt(n);
mp[0].push_back(0);
for(int i=1;i<=n;i++)
mp[sum[i]=sum[i-1]^(a[i]=read())].push_back(i);
for(int i=1;i<=m;i++){
int x=read(),y=read(),z=read();
syz[i]=(SYZ){x,y,z};
prez[x]+=z;prez[y+1]-=z;
}
for(int i=1;i<=n;i++)
prez[i]+=prez[i-1];
sort(syz+1,syz+1+m,cmp);
for(auto&x:mp)//x.first是键,x.second是值
reverse(x.second.begin(),x.second.end());
for(int i=1,j=1;i<=n;i++){//i是左端点
while(j<=m&&syz[j].x==i)
change(1,syz[j].z),change(syz[j].y+1,-syz[j].z),j++;
if(mp[sum[i-1]^K].size()<B)
for(int x:mp[sum[i-1]^K]){//x是右端点
if(x<i)break;
int temp=ask(x);
cc[i]+=temp;
cc[x+1]-=temp;
}
}
for(auto&x:mp)
if(x.second.size()>=B)
solve(x.first);
for(int i=1;i<=n;i++)
printf("%lld%c",(cc[i]+=cc[i-1])%=mod," \n"[i==n]);
}