【BZOJ】1798: [Ahoi2009]Seq 维护序列seq 线段树多标记(区间加+区间乘)

【题意】给定序列,支持区间加和区间乘,查询区间和取模。n<=10^5。

【算法】线段树

【题解】线段树多重标记要考虑标记与标记之间的相互影响。

对于sum*b+a,+c直接加上即可。

*c后就是(sum*b+a)*c=sum*b*b+a*c,也就是加法的部分也要乘。

所以,每次在乘法的时候要把加法标记也乘上。下传时先传乘法。

注意乘法初始值为1,但是取模后可能为0。

#include<cstdio>
#include<cstring>
#include<cctype>
#include<algorithm>
using namespace std;
int read(){
    int s=0,t=1;char c;
    while(!isdigit(c=getchar()))if(c=='-')t=-1;
    do{s=s*10+c-'0';}while(isdigit(c=getchar()));
    return s*t;
}
const int maxn=100010;
struct tree{int l,r,a,b,sum;}t[maxn*4];
int n,MOD,a[maxn];
int M(int x){return x>=MOD?x-MOD:x;}
void up(int k){t[k].sum=M(t[k<<1].sum+t[k<<1|1].sum);}
void modify_a(int k,int x){t[k].sum=M(t[k].sum+1ll*(t[k].r-t[k].l+1)*x%MOD);t[k].a=M(t[k].a+x);}//
void modify_b(int k,int x){t[k].sum=1ll*t[k].sum*x%MOD;t[k].b=1ll*t[k].b*x%MOD;t[k].a=1ll*t[k].a*x%MOD;}
void down(int k){
    if(t[k].b!=1){// 0
        modify_b(k<<1,t[k].b);modify_b(k<<1|1,t[k].b);
        t[k].b=1;
    }
    if(t[k].a){
        modify_a(k<<1,t[k].a);modify_a(k<<1|1,t[k].a);
        t[k].a=0;
    }
}
void build(int k,int l,int r){
    t[k].l=l;t[k].r=r;t[k].a=0;t[k].b=1;
    if(l==r){t[k].sum=a[l];return;}
    int mid=(l+r)>>1;
    build(k<<1,l,mid);build(k<<1|1,mid+1,r);
    up(k);
}
void add(int k,int l,int r,int x){
    if(l<=t[k].l&&t[k].r<=r){modify_a(k,x);return;}
    down(k);
    int mid=(t[k].l+t[k].r)>>1;//
    if(l<=mid)add(k<<1,l,r,x);
    if(r>mid)add(k<<1|1,l,r,x);
    up(k);
}
void mul(int k,int l,int r,int x){
    if(l<=t[k].l&&t[k].r<=r){modify_b(k,x);return;}
    down(k);
    int mid=(t[k].l+t[k].r)>>1;//
    if(l<=mid)mul(k<<1,l,r,x);
    if(r>mid)mul(k<<1|1,l,r,x);
    up(k);
}
int query(int k,int l,int r){
    if(l<=t[k].l&&t[k].r<=r){return t[k].sum;}
    down(k);
    int mid=(t[k].l+t[k].r)>>1,sum=0;
    if(l<=mid)sum=query(k<<1,l,r);
    if(r>mid)sum=M(sum+query(k<<1|1,l,r));
    return sum;
}
int main(){
    n=read();MOD=read();
    for(int i=1;i<=n;i++)a[i]=read()%MOD;
    build(1,1,n);
    int m=read();
    while(m--){
        int k=read(),x=read(),y=read();
        if(k==3){printf("%d\n",query(1,x,y));continue;}
        int z=read();
        if(k==1)mul(1,x,y,z%MOD);else add(1,x,y,z%MOD);
    }
    return 0;
}
View Code

 

posted @ 2018-03-02 15:08  ONION_CYC  阅读(294)  评论(0编辑  收藏  举报