【UOJ 56】线段树区间加和乘
【题目描述】:
如题,已知一个数列,你需要进行下面两种操作:
1.将某区间每一个数加上x
2.将某区间每一个数乘上x
3.求出某区间每一个数的和
【输入描述】:
第一行包含三个整数N、M、P,分别表示该数列数字的个数、操作的总个数和模数。
第二行包含N个用空格分隔的整数,其中第i个数字表示数列第i项的初始值。
接下来M行每行包含3或4个整数,表示一个操作,具体如下:
操作1: 格式:1 x y k 含义:将区间[x,y]内每个数乘上k(>0)
操作2: 格式:2 x y k 含义:将区间[x,y]内每个数加上k(>0)
操作3: 格式:3 x y 含义:输出区间[x,y]内每个数的和对P取模所得的结果
【输出描述】:
输出包含若干行整数,即为所有操作3的结果。
【样例输入】:
5 5 38
1 5 4 2 3
2 1 4 1
3 2 5
1 2 4 2
2 3 5 5
3 1 4
【样例输出】:
17
2
【时间限制、数据范围及描述】:
时间:1s 空间128M
对于30%的数据:N<=10,M<=10
对于70%的数据:N<=1,000,M<=10,000
对于100%的数据:N<=100,000,M<=100,000
题解:新学区间乘操作 √
#include<cstdio> #include<iostream> #include<cmath> #include<cstdlib> #include<cstring> #include<algorithm> #include<bits/stdc++.h> typedef long long ll; using namespace std; const int N=100003; int yc,n,T,x,y; ll z,mod,tim[N*4]; ll a[N*4],sum[N*4],add[N*4]; void pushup(int rt) { sum[rt]=(sum[rt*2]+sum[rt*2+1])%mod; } void build(int l,int r,int rt){ tim[rt]=1; add[rt]=0; if(l==r){ sum[rt]=a[l]; return ; } int m=(l+r)/2; build(l,m,rt*2); build(m+1,r,rt*2+1); pushup(rt); } void pushdown(int rt,int ln,int rn){ if(add[rt]==0 && tim[rt]==1) return; if(ln+rn==1) return ; sum[rt*2]=(sum[rt*2]*tim[rt]+add[rt]*ln)%mod; sum[rt*2+1]=(sum[rt*2+1]*tim[rt]+add[rt]*rn)%mod; add[rt*2]=(add[rt*2]*tim[rt]+add[rt])%mod; add[rt*2+1]=(add[rt*2+1]*tim[rt]+add[rt])%mod; add[rt]=0; tim[rt*2]=(tim[rt*2]*tim[rt])%mod; tim[rt*2+1]=(tim[rt*2+1]*tim[rt])%mod; tim[rt]=1; } void update1(int L,int R,ll c,int l,int r,int rt){ int m=(l+r)/2; pushdown(rt,m-l+1,r-m); if(L<=l && r<=R) { sum[rt]=(sum[rt]*c)%mod; tim[rt]*=c; return ; } if(L<=m) update1(L,R,c,l,m,rt*2); if(R>m) update1(L,R,c,m+1,r,rt*2+1); pushup(rt); } void update2(int L,int R,ll c,int l,int r,int rt){ int m=(l+r)/2; pushdown(rt,m-l+1,r-m); if(L<=l && r<=R) { sum[rt]=(sum[rt]+c*(r-l+1))%mod; add[rt]+=c; return ; } if(L<=m) update2(L,R,c,l,m,rt*2); if(R>m) update2(L,R,c,m+1,r,rt*2+1); pushup(rt); } ll query(int L,int R,int l,int r,int rt){ if(L<=l && r<=R) return sum[rt]; int m=(l+r)/2; pushdown(rt,m-l+1,r-m); ll ans=0; if(L<=m) ans=(ans+query(L,R,l,m,rt*2))%mod; if(R>m) ans=(ans+query(L,R,m+1,r,rt*2+1))%mod; return ans; } int main(){ scanf("%d %d %d",&n,&T,&mod); for(int i=1;i<=n;i++) tim[i]=1; for(int i=1;i<=n;i++) scanf("%lld",&a[i]); build(1,n,1); while(T--){ scanf("%d",&yc); if(yc==1){ scanf("%d %d %lld",&x,&y,&z); update1(x,y,z,1,n,1); } if(yc==2){ scanf("%d %d %lld",&x,&y,&z); update2(x,y,z,1,n,1); } if(yc==3){ scanf("%d %d",&x,&y); cout<<query(x,y,1,n,1)<<endl; } } return 0; }