异或与区间加题解

异或与区间加题解

简要题意

给定 \(n,m,K,a_{1...n}\),和 \(m\) 个三元组 \((x_i,y_i,z_i)\),定义 \(calc(l,r)=a_l \bigoplus a_{l+1}\bigoplus ...\bigoplus a_r\)。对于每个三元组 \((x,y,z)\) ,对所有满足 \(x\le l\le r\le y\ ,\ calc(l,r)=K\) 的区间 \((l,r)\) 内的每个数 \(b_i\) 加上 \(z\),其中 \(b_{1..n}\)​​ 初始全为 0。输出对 \(2^{30}\) 取模。

\(0\le K,a_i<2^{30},1\le x\le y\le n,0\le z\le 10000\)

10 10 3//n m K
2 0 3 0 1 0 0 2 1 2//a[i]
1 10 1//x y z
3 10 9
10 10 5
4 10 10
9 10 8
7 7 8
3 5 10
7 8 9
7 9 7
7 8 7
1 4 54 53 52 72 99 126 114 39

题解

先来一个暴力的方法。首先容易想到对 \(a\) 求一遍前缀和,将 \(calc(l,r)=K\) 转化为 \(sum_r\bigoplus sum_{l-1}=K\)。将每个三元组按关键字排序(先x后y),然后从前往后扫描每一个区间。然后开一个树状数组,令 \(c_{x..y}\) 加上 \(z\)\(c_{pos}\) 表示:对于每个右端点位于 \(pos\) 的区间 \((l,r)\),应对的 \(b_{l...r}\) 需要加上 \(c_{pos}\)。但是这样可能会把 \(l<x\) 的区间也进行操作,所以我们应该从前往后扫描每一个位置。在扫描到位置 \(l\) 的时候,如果发现存在三元组满足 \(x=l\),那么我们令 \(c_{x..y}\) 加上 \(z\)。处理完 \(c\) 以后,找到满足 \(sum_r\bigoplus sum_{l-1}=K\)\(r\)(可以利用map找),然后再将 \(b_{l...r}\) 加上 \(c_{r}\),这一个区间加可以用差分处理。

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int mod=1<<30;
int n,m,K,a[150010],sum[150010];
LL c[150010],cc[150010];
unordered_map<int,vector<int>>mp;
struct SYZ
{int x,y,z;}syz[150010];
inline int read()
{
	int x=0,w=0;char ch=0;
	while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
	while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return w?-x:x;
}
bool cmp(SYZ n1,SYZ n2)
{
	if(n1.x^n2.x)return n1.x<n2.x;
	return n1.y<n2.y;
}
void change(int x,int y)
{for(;x<=n;x+=x&-x)c[x]+=y;}
int ask(int x,int y=0)
{for(;x;x-=x&-x)y+=c[x];return y;}
int main()
{
	n=read();m=read();K=read();
	mp[0].push_back(0);
	for(int i=1;i<=n;i++)
		mp[sum[i]=sum[i-1]^(a[i]=read())].push_back(i);
	for(int i=1;i<=m;i++){
		int x=read(),y=read(),z=read();
		syz[i]=(SYZ){x,y,z};
	}
	sort(syz+1,syz+1+m,cmp);
	for(auto&x:mp)//x.first是键,x.second是值
		reverse(x.second.begin(),x.second.end());
	for(int i=1,j=1;i<=n;i++){//i是左端点
		while(j<=m&&syz[j].x==i)
			change(1,syz[j].z),change(syz[j].y+1,-syz[j].z),j++;
		for(int x:mp[sum[i-1]^K]){//x是右端点
			if(x<i)break;
			int temp=ask(x);
			cc[i]+=temp;
			cc[x+1]-=temp;
		}
	}
	for(int i=1;i<=n;i++)
		printf("%lld%c",(cc[i]+=cc[i-1])%=mod," \n"[i==n]);
}

这里的 \(cc\)\(b\) 的差分数组。

此方法的瓶颈在于:对于一个 \(l\) ,满足条件的 \(r\) 可能会非常多。

我们可以在当 \(r\) 的数量小于 \(\sqrt{n}\) 时用上述方法,当 \(r\) 数量过多时需要换一种方法。值得注意的是,这样不同的 \(sum_r\) 不会超过 \(\sqrt{n}\) 个。

我们不妨对每一个这样的 \(sum_r\) 单独处理,我们先暴力找到所有的 \(l\)\(r\) ,利用前缀和可以计算出区间 \(xx,yy\) 内有多少个 \(l\)\(r\)

\(prez_i\) 表示:满足 \(x\le i\le y\) 的所有三元组的 \(z\) 的和,利用差分可以快速求出。

我们要分别扫描所有的 \(l\),\(r\) ,扫描 \(l\) 时,让 \(cc_l\) 加上一些东西;扫描 \(r\) 时,让 \(cc_{r+1}\) 减去一些东西。

我们不妨先考虑一个弱化版本,即:\(y=n\)。如果此方法可行的话,我们可以试图将 \((x,y,z)\) 拆分成 \((x,n,z)\)\((y+1,n,-z)\)。注意,直接拆分会错误地统计上这样的区间:\(x\le l\le y<r\)。我们需要额外的操作减去这样的贡献。

显然,对于 \((x,y,z)\) 只需要 \(l\ge x\) 即可。我们可以扫描每一个 \(l\) ,计算 \(x\le l\)\(z\) 的和,以及 \(r\) 的数量。前者即是 \(prez_l\),后者用差分统计即可。有:\(cc_l+=cnt(r)*prez_l\)。同样地,对于 \(r\) 我们沿用类似的方法,但稍微麻烦点。对于 \(cc_{r+1}\) 我们要减去的是:每一个 \(x\le l\) 对应的 \(z\)。这个可以一遍扫描一遍统计,初始令 \(temp=0\) ,从左往右扫描时,如果 \(pos\) 是左端点,则令 \(temp+=prez_{pos}\)。当扫描到一个右端点 \(r\) 时,令 \(cc_{r+1}-=temp\)。意思是遇到一个左端点 \(l\),那么它后面的右端点统统加上它左边的三元组的 \(z\) (即 \(prez_l\))。

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int mod=1<<30;
int n,m,K,B,a[150010],sum[150010];
int sX[150010],sY[150010];
LL c[150010],cc[150010],prez[150010];
unordered_map<int,vector<int>>mp;
struct SYZ
{int x,y,z;}syz[150010];
inline int read()
{
	int x=0,w=0;char ch=0;
	while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
	while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return w?-x:x;
}
bool cmp(SYZ n1,SYZ n2)
{
	if(n1.x^n2.x)return n1.x<n2.x;
	return n1.y<n2.y;
}
void change(int x,int y)
{for(;x<=n;x+=x&-x)c[x]+=y;}
int ask(int x,int y=0)
{for(;x;x-=x&-x)y+=c[x];return y;}
void solve(int Y)//Y=sum[r]
{
	int X=K^Y;//X=sum[l-1]
	for(int i=1;i<=n;i++){
		sX[i]=sX[i-1]+(sum[i-1]==X);
		sY[i]=sY[i-1]+(sum[i]==Y);
	}
	for(int i=1;i<=n;i++)
	if(sum[i-1]==X)
		cc[i]+=prez[i]*(sY[n]-sY[i-1]);
	LL temp=0;
	for(int i=1;i<=n;i++){
		if(sum[i-1]==X)temp+=prez[i];
		if(sum[i]==Y)cc[i+1]-=temp;
	}
}
int main()
{
	n=read();m=read();K=read();B=sqrt(n);
	mp[0].push_back(0);
	for(int i=1;i<=n;i++)
		mp[sum[i]=sum[i-1]^(a[i]=read())].push_back(i);
	for(int i=1;i<=m;i++){
		int x=read(),y=read(),z=read();
		syz[i]=(SYZ){x,y,z};
		prez[x]+=z;prez[y+1]-=z;
	}
	for(int i=1;i<=n;i++)
		prez[i]+=prez[i-1];
	sort(syz+1,syz+1+m,cmp);
	for(auto&x:mp)//x.first是键,x.second是值
		reverse(x.second.begin(),x.second.end());
	for(int i=1,j=1;i<=n;i++){//i是左端点
		while(j<=m&&syz[j].x==i)
			change(1,syz[j].z),change(syz[j].y+1,-syz[j].z),j++;
		if(mp[sum[i-1]^K].size()<B)
		for(int x:mp[sum[i-1]^K]){//x是右端点
			if(x<i)break;
			int temp=ask(x);
			cc[i]+=temp;
			cc[x+1]-=temp;
		}
	}
	for(auto&x:mp)
	if(x.second.size()>=B)
		solve(x.first);
	for(int i=1;i<=n;i++)
		printf("%lld%c",(cc[i]+=cc[i-1])%=mod," \n"[i==n]);
}

现在我们考虑怎么样拆分一个三元组,以及如何处理错误统计的 \(l,r\)

对于 \((x,y,z)\) ,在处理 \(cc_l\) 时,我们不想让 \(y<r\) 的那些区间统计上 \(z\)。我么需要新开一个数组,统计上需要减去的这些 \(z\)。(此时树状数组的 \(c\) 数组已经没用了我们不如再次利用 \(c\))我们令 \(c_x+=z*cnt(r)\),这些 \(r\) 要满足 \(r>y\)。同时令 \(c_{y+1}-=z*cnt(r)\)。这个意思是:在扫描到 \(l\ge x\)\(l\) 时,统计的答案要减去 \(c_x\),因为多出来的 \(r\) 不应该统计上去。统计到 \(y\) 后面的 \(l\) 时,不用减去这些了,因为本来就没有统计上(不明白为什么没统计上的话,可以看一下\(prez\)​)。

在处理 \(cc_{r+1}\) 时,我们应当减去 \(x\le l\le y<r\) 对应的 \(z\)。我们在扫描三元组 \((x,y,z)\) 时,令 \(c[y+1]+=z*cnt(l)\),这些 \(l\) 满足 \(x\le l \le y\)。意思是,扫描到 \(r\ge y+1\)\(r\) 时,统计的答案要少减去 \(c[y+1]\)。因为,前面对应的那些三元组,贡献要减少 \(c[y+1]\) ,因为 \(r\) 越界了,那些 \(l\) 不会和这个 \(r\) 产生贡献。

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int mod=1<<30;
int n,m,K,B,a[150010],sum[150010];
int sX[150010],sY[150010];
LL c[150010],cc[150010],prez[150010];
unordered_map<int,vector<int>>mp;
struct SYZ
{int x,y,z;}syz[150010];
inline int read()
{
	int x=0,w=0;char ch=0;
	while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
	while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return w?-x:x;
}
bool cmp(SYZ n1,SYZ n2)
{
	if(n1.x^n2.x)return n1.x<n2.x;
	return n1.y<n2.y;
}
void change(int x,int y)
{for(;x<=n;x+=x&-x)c[x]+=y;}
int ask(int x,int y=0)
{for(;x;x-=x&-x)y+=c[x];return y;}
void solve(int Y)//Y=sum[r]
{
	int X=K^Y;//X=sum[l-1]
	for(int i=1;i<=n;i++){
		sX[i]=sX[i-1]+(sum[i-1]==X);
		sY[i]=sY[i-1]+(sum[i]==Y);
	}
	memset(c,0,sizeof c);
	for(int i=1;i<=m;i++){
		int x=syz[i].x,y=syz[i].y,z=syz[i].z;
		c[x]+=1ll*z*(sY[n]-sY[y]);
		c[y+1]-=1ll*z*(sY[n]-sY[y]);
	}
	for(int i=1;i<=n;i++){
		c[i]+=c[i-1];
		if(sum[i-1]==X)
			cc[i]+=prez[i]*(sY[n]-sY[i-1])-c[i];
	}
	memset(c,0,sizeof c);
	for(int i=1;i<=m;i++){
		int x=syz[i].x,y=syz[i].y,z=syz[i].z;
		c[y+1]+=1ll*z*(sX[y]-sX[x-1]);
	}
	LL temp=0;
	for(int i=1;i<=n;i++){
		c[i]+=c[i-1];
		if(sum[i-1]==X)temp+=prez[i];
		if(sum[i]==Y)
			cc[i+1]-=temp-c[i];
	}
}
int main()
{
	n=read();m=read();K=read();B=sqrt(n);
	mp[0].push_back(0);
	for(int i=1;i<=n;i++)
		mp[sum[i]=sum[i-1]^(a[i]=read())].push_back(i);
	for(int i=1;i<=m;i++){
		int x=read(),y=read(),z=read();
		syz[i]=(SYZ){x,y,z};
		prez[x]+=z;prez[y+1]-=z;
	}
	for(int i=1;i<=n;i++)
		prez[i]+=prez[i-1];
	sort(syz+1,syz+1+m,cmp);
	for(auto&x:mp)//x.first是键,x.second是值
		reverse(x.second.begin(),x.second.end());
	for(int i=1,j=1;i<=n;i++){//i是左端点
		while(j<=m&&syz[j].x==i)
			change(1,syz[j].z),change(syz[j].y+1,-syz[j].z),j++;
		if(mp[sum[i-1]^K].size()<B)
		for(int x:mp[sum[i-1]^K]){//x是右端点
			if(x<i)break;
			int temp=ask(x);
			cc[i]+=temp;
			cc[x+1]-=temp;
		}
	}
	for(auto&x:mp)
	if(x.second.size()>=B)
		solve(x.first);
	for(int i=1;i<=n;i++)
		printf("%lld%c",(cc[i]+=cc[i-1])%=mod," \n"[i==n]);
}
posted @ 2024-05-01 16:08  zYzYzYzYz  阅读(50)  评论(0编辑  收藏  举报