[THUSCH2017] 大魔法师
前期准备
1.熟练的掌握区间修改线段树
2.对矩阵乘法有部分的了解,知道如何使用
3.对卡常十分精通
题目大意
题目给定 \(n\) 个三元组,每个三元组包含 \(A\)、\(B\)、\(C\) 三个元素,一共进行 \(m\) 次操作,分别是下面七种之一:
1.令给定区间内,\(A_i=A_i+B_i\)
2.令给定区间内,\(B_i=B_i+C_i\)
3.令给定区间内,\(C_i=C_i+A_i\)
4.令给定区间内,\(A_i=A_i+v\)
5.令给定区间内,\(B_i=B_i\times v\)
6.令给定区间内,\(C_i=v\)
7.查询区间内每个元素 \(A\)、\(B\)、\(C\) 累加得到的和。
其中 \(1 \le n \le 2.5\times 10^5\),\(1 \le m \le 2.5\times 10^5\) ,元素中每个值对 \(998244353\) 取模。
题目时间限制 \(5\) 秒!!!
思路
因为题目要求写一个动态的区间修改和区间最大值,所以自然地就可以想到区间修改线段树。
但是因为这道题目处理的是三元组,所以如果一个一个处理的话,线段树的 lazy
数组的会非常难写。
于是顺理成章的,就应该使用矩阵乘法给线段树进行优化。
每个三元组在运算的过程中都可以看做一个矩阵,而这 \(7\) 个操作就只需要推 \(6\) 个矩阵并写一个区间求和就结束了。
所以我么需要做的事情就是将这 \(7\) 个矩阵推出来,接着将矩阵套到线段树上就可以了。为了便于线段树的书写,我使用的重载运算符(operator
),即重新定义符号。
做法
操作 \(1\)、\(2\)、\(3\)
如果你已经做个一些题目的话,那么你应该可以顺理成章的推出后面 \(3\) 个式子
操作\(4\)、\(5\)、\(6\)
现在你会发现这个给定的 \(v\) 不知道应该塞到哪里了,于是我们就应该添加辅助的维度
可以将原来的 \(\begin{bmatrix}A_i & B_i &C_i\end{bmatrix}\) 替换为 \(\begin{bmatrix}A_i & B_i &C_i &1\end{bmatrix}\) 辅助增加
操作 \(7\)
这个操作其实就可以直接写一个线段树的区间求和就可以了
一些细节
- 在定义矩阵的结构体里面,应该将矩阵清零而不是直接进行操作。
- 注意有加法与乘法的线段树
pushdown
处理lazy
数组时的顺序。 - 十年 OI 一场空,不开
long long
见祖宗。 - 记得在每一次运算后都要写取模操作,否则会溢出。
AC Code
#include<bits/stdc++.h>
#define int long long
#define m(s1) memset(s1.a,0,sizeof(s1.a))
const int mod=998244353;
const int N=1000005;
inline int read(){
int x=0;
char ch=getchar();
while(ch>'9'||ch<'0') ch=getchar();
while(ch<='9'&&ch>='0') x=(x<<1)+(x<<3)+ch-48,ch=getchar();
return x;
}
struct node{
int a[3][3],n,m;
node(){memset(a,0,sizeof(a));}
friend node operator + (const node a,const node b){
node s; s.n=a.n,s.m=a.m;
m(s);
for(register int i=0;i<a.n;++i)
for(register int j=0;j<a.m;++j)
s.a[i][j]=(a.a[i][j]+b.a[i][j])%mod;
return s;
}
friend node operator * (const node a,const node b){
node s; s.n=a.n,s.m=b.m;
m(s);
for(register int i=0;i<a.n;++i)
for(register int k=0;k<a.m;++k)
for(register int j=0;j<b.m;++j)
s.a[i][j]=(s.a[i][j]+a.a[i][k]*b.a[k][j]%mod)%mod;
return s;
}
friend node operator * (const node a,const int b){
node s; s.n=a.n,s.m=a.m;
m(s);
for(register int i=0;i<a.n;++i)
for(register int j=0;j<a.m;++j)
s.a[i][j]=a.a[i][j]*b%mod;
return s;
}
}s[8],sum[N],lazy1[N],lazy2[N];
inline void pre(){
s[1].a[0][0]=s[1].a[1][1]=s[1].a[2][2]=s[1].a[1][0]=1;
s[1].n=s[1].m=3;
s[2].a[0][0]=s[2].a[1][1]=s[2].a[2][2]=s[2].a[2][1]=1;
s[2].n=s[2].m=3;
s[3].a[0][0]=s[3].a[1][1]=s[3].a[2][2]=s[3].a[0][2]=1;
s[3].n=s[3].m=3;
s[4].a[0][0]=-1;
s[4].n=1,s[4].m=3;
s[5].a[0][0]=s[5].a[2][2]=1;
s[5].a[1][1]=-1;
s[5].n=s[5].m=3;
s[6].a[0][0]=s[6].a[1][1]=1;
s[6].n=s[6].m=3;
s[7].a[0][2]=-1;
s[7].n=1,s[7].m=3;
}
int n,m;
inline void updata(int k,int l,int r){
int mid=(l+r)/2;
lazy2[k*2]=lazy2[k*2]*lazy2[k];
lazy2[k*2+1]=lazy2[k*2+1]*lazy2[k];
lazy1[k*2]=lazy1[k*2]*lazy2[k]+lazy1[k];
lazy1[k*2+1]=lazy1[k*2+1]*lazy2[k]+lazy1[k];
sum[k*2]=sum[k*2]*lazy2[k]+lazy1[k]*(mid-l+1);
sum[k*2+1]=sum[k*2+1]*lazy2[k]+lazy1[k]*(r-mid);
m(lazy1[k]),m(lazy2[k]);
lazy2[k].a[0][0]=lazy2[k].a[1][1]=lazy2[k].a[2][2]=1;
}
inline void pre_lazy(int k){
lazy1[k].n=1,lazy1[k].m=3;
lazy2[k].n=3,lazy2[k].m=3;
lazy2[k].a[0][0]=lazy2[k].a[1][1]=lazy2[k].a[2][2]=1;
}
void build(int k,int l,int r){
sum[k].n=1,sum[k].m=3;
pre_lazy(k);
if(l==r){
sum[k].a[0][0]=read();
sum[k].a[0][1]=read();
sum[k].a[0][2]=read();
return;
}
int mid=(l+r)/2;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
sum[k]=sum[k*2]+sum[k*2+1];
}
void up(int k,int l,int r,int ll,int rr,node &v,bool flag){
if(ll<=l&&rr>=r){
if(flag==0) lazy1[k]=lazy1[k]*v,lazy2[k]=lazy2[k]*v,sum[k]=sum[k]*v;
else lazy1[k]=lazy1[k]+v,sum[k]=sum[k]+v*(r-l+1);
return ;
}int mid=(l+r)/2;
updata(k,l,r);
if(ll<=mid) up(k*2,l,mid,ll,rr,v,flag);
if(mid<rr) up(k*2+1,mid+1,r,ll,rr,v,flag);
sum[k]=sum[k*2]+sum[k*2+1];
}
node ask(int k,int l,int r,int ll,int rr){
if(ll<=l&&rr>=r) return sum[k];
int mid=(l+r)/2;
updata(k,l,r);
node res; res.n=1,res.m=3;
if(ll<=mid) res=ask(k*2,l,mid,ll,rr);
if(mid<rr) res=res+ask(k*2+1,mid+1,r,ll,rr);
return res;
}
signed main(){
pre();
n=read();
build(1,1,n);
m=read();
for(register int i=1,op,l,r;i<=m;++i){
op=read(),l=read(),r=read();
if(op<=3) up(1,1,n,l,r,s[op],0);
if(op==4) s[4].a[0][0]=read(),up(1,1,n,l,r,s[4],1);
if(op==5) s[5].a[1][1]=read(),up(1,1,n,l,r,s[5],0);
if(op==6) s[7].a[0][2]=read(),up(1,1,n,l,r,s[6],0),up(1,1,n,l,r,s[7],1);
if(op==7){
node ans=ask(1,1,n,l,r);
printf("%lld %lld %lld\n",ans.a[0][0],ans.a[0][1],ans.a[0][2]);
}
}
return 0;
}