题解:AT_abc357_f [ABC357F] Two Sequence Queries

题意

维护一个数据结构,支持两个数列的区间求和,和查询区间内两数列各元素积的和。

分析

线段树万岁!

这道题要维护两个序列,所以线段树中要同时存储两个区间和。但还要在维护一个信息,是该区间内两序列元素积的和。大概长这样:

struct no
{
	int l,r;
	int da,db,ab;
	int ta,tb;
}t[maxn<<2];

其他的更新就不讲了,主要说一说积的和信息的更新。

当更新一个序列时,该信息要传递的信息其实是另一盒序列的和乘上该序列的懒标记,这点很好想。

然后就没有什么问题了。但是注意这道题要取模,你少取一个就废了。

Code

#include<bits/stdc++.h>
//#include<atcoder/modint>
#define int long long
using namespace std;
//using mint=atcoder::modint998244353;
inline int read()
{
	int w=1,s=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
	while(isdigit(ch)){s=s*10+(ch-'0');ch=getchar();}
	return w*s;
}
const int maxn=1e6+10;
const int mod=998244353;
struct no
{
	int l,r;
	int da,db,ab;
	int ta,tb;
}t[maxn<<2];
int a[maxn],b[maxn],n,Q;
void upd(int p)
{
	t[p].da=(t[p*2].da+t[p*2+1].da)%mod;
	t[p].db=(t[p*2].db+t[p*2+1].db)%mod;
	t[p].ab=(t[p*2].ab+t[p*2+1].ab)%mod;
}
void build(int p,int l,int r)
{
	t[p].l=l,t[p].r=r;
	if(l==r)
	{
		t[p].da=a[l]%mod;t[p].db=b[l]%mod;
		t[p].ab=a[l]*b[l]%mod;
		return ;
	}
	int mid=(l+r)>>1;
	build(p*2,l,mid);
	build(p*2+1,mid+1,r);
	upd(p);
}
void spread(int p)
{
	if(t[p].ta)
	{
		t[p*2].da+=t[p].ta*(t[p*2].r-t[p*2].l+1)%mod;t[p*2].da%=mod;
		t[p*2+1].da+=t[p].ta*(t[1+p*2].r-t[1+p*2].l+1)%mod;t[p*2+1].da%=mod;
		t[p*2].ta+=t[p].ta;t[p*2].ta%=mod;
		t[p*2+1].ta+=t[p].ta;t[p*2+1].ta%=mod;
		t[p*2].ab+=t[p].ta*t[p*2].db%mod;t[p*2].ab%=mod;
		t[p*2+1].ab+=t[p].ta*t[p*2+1].db%mod;t[p*2+1].ab%=mod;
	}
	if(t[p].tb)
	{
		t[p*2].db+=t[p].tb*(t[p*2].r-t[p*2].l+1)%mod;t[p*2].db%=mod;
		t[p*2+1].db+=t[p].tb*(t[1+p*2].r-t[1+p*2].l+1)%mod;t[p*2+1].db%=mod;
		t[p*2].tb+=t[p].tb;t[p*2].tb%=mod;
		t[p*2+1].tb+=t[p].tb;t[p*2+1].tb%=mod;
		t[p*2].ab+=t[p].tb*t[p*2].da%mod;t[p*2].ab%=mod;
		t[p*2+1].ab+=t[p].tb*t[p*2+1].da%mod;t[p*2+1].ab%=mod;
	}
	t[p].ta=0;
	t[p].tb=0;
}
void changea(int p,int l,int r,int k)
{
	if(t[p].l>=l&&t[p].r<=r)
	{
		t[p].da+=k*(t[p].r-t[p].l+1)%mod;t[p].da%=mod;
		t[p].ta+=k;t[p].ta%=mod;
		t[p].ab+=k*t[p].db%mod;t[p].ab%=mod;
		return ;
	}
	spread(p);
	int mid=(t[p].l+t[p].r)>>1;
	if(l<=mid)changea(p*2,l,r,k);
	if(mid<r) changea(p*2+1,l,r,k);
	upd(p);
}
void changeb(int p,int l,int r,int k)
{
	if(t[p].l>=l&&t[p].r<=r)
	{
		t[p].db+=k*(t[p].r-t[p].l+1)%mod;t[p].db%=mod;
		t[p].tb+=k;t[p].tb%=mod;
		t[p].ab+=k*t[p].da%mod;t[p].ab%=mod;
		return ;
	}
	spread(p);
	int mid=(t[p].l+t[p].r)>>1;
	if(l<=mid)changeb(p*2,l,r,k);
	if(mid<r) changeb(p*2+1,l,r,k);
	upd(p);
}
int ask(int p,int l,int r)
{
	if(t[p].l>=l&&t[p].r<=r)
	{
		return t[p].ab%mod;
	}
	spread(p);
	int mid=(t[p].l+t[p].r)>>1,sum=0;
	if(l<=mid)sum=(sum+ask(p*2,l,r)%mod)%mod;
	if(mid<r) sum=(sum+ask(p*2+1,l,r)%mod)%mod;
	return sum%mod;
}
signed main()
{
//  freopen("xxx.in","r",stdin);
//	freopen("xxx.out","w",stdout);
	cin>>n>>Q;
	for(int i=1;i<=n;i++)a[i]=read();
	for(int i=1;i<=n;i++)b[i]=read();
	build(1,1,n);
	while(Q--)
	{
		int opt=read(),l=read(),r=read();
		if(opt==3)
		{
			printf("%lld\n",ask(1,l,r)%mod);
			continue;
		} 
		int x=read();
		if(opt==1)changea(1,l,r,x);
		if(opt==2)changeb(1,l,r,x);
	}
	return 0;
}

posted @ 2024-07-17 16:56  Redamancy_Lydic  阅读(3)  评论(0编辑  收藏  举报