【模板】线段树,区间加区间乘
洛谷 P3373
// Segment tree with lazy propagation supporting, modulo p:
//   range multiply, range add, and range sum.
// Input: n m p, then the n initial values, then m operations:
//   1 x y k : multiply every element of [x, y] by k
//   2 x y k : add k to every element of [x, y]
//   3 x y   : print the sum of [x, y] mod p
#include<cstdio>
#include<algorithm>
#define LL long long
#define ls (cur<<1)
#define rs (cur<<1|1)
// Midpoint of the current node's interval (relies on a variable named `cur`).
#define mid ((a[cur].l+a[cur].r)>>1)
// Interval length as long long so that len*value is computed in 64-bit.
#define len(x) ((LL)(a[x].r-a[x].l+1))
using namespace std;
// 4*n nodes are needed; this bound covers n up to 200000.
const int maxn=800010;
LL n,m,p,k,x,y,z;

// One tree node. Every element v under the node is pending the update
// v -> v*mul + del (multiply first, then add); sum already reflects that
// update. mul, del and sum are all kept reduced into [0, p).
struct tree{
    int l,r;
    LL mul,del,sum;
}a[maxn];

// Read a (possibly negative) decimal integer from stdin into k.
// Stops skipping at EOF instead of looping forever; k is 0 if no digits follow.
void read(LL &k){
    k=0; int f=1; int c=getchar();   // int, so EOF is distinguishable
    while (c!=EOF&&(c<'0'||c>'9')){
        if (c=='-') f=-1;
        c=getchar();
    }
    while (c>='0'&&c<='9'){
        k=k*10+(c-'0');
        c=getchar();
    }
    k*=f;
}

// Reduce v into [0, p), also for negative v (C++ % keeps the dividend's sign).
LL modp(LL v){
    v%=p;
    return v<0?v+p:v;
}

// Apply the affine update v -> v*mulv + delv to every element under `node`:
// compose it into the node's lazy tags and update the node's sum.
// Expects mulv, delv in [0, p).
// NOTE(review): assumes (p-1)^2 + n*(p-1) fits in long long (true for the
// Luogu limits) — confirm if p can be larger.
void applyTag(int node,LL mulv,LL delv){
    a[node].mul=a[node].mul*mulv%p;
    a[node].del=(a[node].del*mulv+delv)%p;
    a[node].sum=(a[node].sum*mulv+len(node)*delv)%p;
}

// Build the tree over [l, r] rooted at cur, reading leaf values from stdin
// in left-to-right order.
void build(int cur,int l,int r){
    a[cur].l=l; a[cur].r=r; a[cur].mul=1; a[cur].del=0;
    if (l<r){
        build(ls,l,mid);
        build(rs,mid+1,r);
        a[cur].sum=(a[ls].sum+a[rs].sum)%p;
    }
    else{
        read(a[cur].sum);
        a[cur].sum=modp(a[cur].sum);
    }
}

// Push cur's pending update down to its two children and clear cur's tags.
void pushdown(int cur){
    if (len(cur)==1||(a[cur].del==0&&a[cur].mul==1)) return;
    applyTag(ls,a[cur].mul,a[cur].del);
    applyTag(rs,a[cur].mul,a[cur].del);
    a[cur].mul=1; a[cur].del=0;
}

// Multiply every element of [l, r] by mul (expected already reduced into [0, p)).
void multiply(int cur,int l,int r,LL mul){
    if (l<=a[cur].l&&a[cur].r<=r){
        applyTag(cur,mul,0);
        return;
    }
    pushdown(cur);
    if (l<=mid) multiply(ls,l,r,mul);
    if (r>mid) multiply(rs,l,r,mul);
    a[cur].sum=(a[ls].sum+a[rs].sum)%p;
}

// Add delta to every element of [l, r] (delta expected already reduced into [0, p)).
void add(int cur,int l,int r,LL delta){
    if (l<=a[cur].l&&a[cur].r<=r){
        applyTag(cur,1,delta);
        return;
    }
    pushdown(cur);
    if (l<=mid) add(ls,l,r,delta);
    if (r>mid) add(rs,l,r,delta);
    a[cur].sum=(a[ls].sum+a[rs].sum)%p;
}

// Return the sum of [l, r] modulo p.
LL query(int cur,int l,int r){
    if (l<=a[cur].l&&a[cur].r<=r) return a[cur].sum%p;
    pushdown(cur);
    LL ret=0;
    if (l<=mid) ret+=query(ls,l,r);
    if (r>mid) ret+=query(rs,l,r);
    return ret%p;
}

int main(){
    read(n); read(m); read(p);
    build(1,1,n);
    for (int i=1;i<=m;i++){
        read(k);
        if (k==3){
            read(x); read(y);
            printf("%lld\n",query(1,x,y));
        }
        else{
            read(x); read(y); read(z);
            z=modp(z);   // reduce before use: keeps tag products in range
            if (k==1) multiply(1,x,y,z);
            else add(1,x,y,z);
        }
    }
    return 0;
}