【JSOI2014】【BZOJ5039】序列维护(线段树模板)

problem

已知一个数列,你需要进行下面三种操作:
1.将某区间每一个数乘上x
2.将某区间每一个数加上x
3.求出某区间每一个数的和

solution

区间修改+区间查询。
维护两个LazyTag

codes

#include<iostream>
#include<algorithm>
using namespace std;
const int maxn = 100010;
typedef long long LL;

int n, m;
LL a[maxn],mod;
struct node{
    int l, r;
    LL val, addmark, mulmark;
}sgt[maxn<<2];
void build(int p, int l, int r){
    sgt[p].l = l, sgt[p].r = r;
    sgt[p].mulmark=1, sgt[p].addmark=0;
    if(l == r){
        sgt[p].val = a[l];
    }else{
        int m = (l+r)/2;
        build(p*2,l,m);
        build(p*2+1,m+1,r);
        sgt[p].val = sgt[p*2].val+sgt[p*2+1].val;
    }
    sgt[p].val %= mod;
}
void pushdown(int p){
    if(sgt[p].addmark==0&&sgt[p].mulmark==1)return ;
    //初始化父节点
    LL t1 = sgt[p].addmark, t2 = sgt[p].mulmark;
    sgt[p].addmark = 0, sgt[p].mulmark = 1;
    //维护标记
    sgt[p*2].mulmark = (sgt[p*2].mulmark*t2)%mod;
    sgt[p*2+1].mulmark = (sgt[p*2+1].mulmark*t2)%mod;
    sgt[p*2].addmark = (sgt[p*2].addmark*t2+t1)%mod;
    sgt[p*2+1].addmark = (sgt[p*2+1].addmark*t2+t1)%mod;
    //更新当前值,我们规定乘法优先更新(加法优先会损失精度)
    int l = sgt[p].l, r = sgt[p].r, m = (l+r)/2;
    sgt[p*2].val=(sgt[p*2].val*t2+t1*(m-l+1))%mod;//先乘以乘法标记再加上已用乘法标记更新过的加法标记。
    sgt[p*2+1].val=(sgt[p*2+1].val*t2+t1*(r-m))%mod;
}
void add(int p, int l, int r, LL v){
    if(l <= sgt[p].l && sgt[p].r <= r){
        sgt[p].val = (sgt[p].val+(sgt[p].r-sgt[p].l+1)*v)%mod;
        sgt[p].addmark = (sgt[p].addmark+v)%mod;
        return ;
    }
    pushdown(p);
    int m = (sgt[p].l+sgt[p].r)/2;
    if(l <= m)add(p*2,l,r,v);
    if(r > m)add(p*2+1,l,r,v);
    sgt[p].val = (sgt[p*2].val+sgt[p*2+1].val)%mod;
}
void times(int p, int l, int r, LL v){
    if(l <= sgt[p].l && sgt[p].r <= r){
        sgt[p].val = (sgt[p].val*v)%mod;
        sgt[p].mulmark = (sgt[p].mulmark*v)%mod;
        sgt[p].addmark = (sgt[p].addmark*v)%mod;//原先的加法标记也要乘
        return ;
    }
    pushdown(p);
    int m = (sgt[p].l+sgt[p].r)/2;
    if(l <= m)times(p*2,l,r,v);
    if(r > m)times(p*2+1,l,r,v);
    sgt[p].val = (sgt[p*2].val+sgt[p*2+1].val)%mod;
}
LL query(int p, int l, int r){
    if(l <= sgt[p].l && sgt[p].r <= r)return sgt[p].val;
    pushdown(p); //pushdown
    LL m = (sgt[p].l+sgt[p].r)/2, ans = 0;
    if(l <= m)ans += query(p*2,l,r);
    if(r > m)ans += query(p*2+1,l,r);
    return ans%mod;
}

int main(){
    ios::sync_with_stdio(false);
    cin>>n>>mod;
    for(int i = 1; i <= n; i++)cin>>a[i];
    build(1,1,n);
    cin>>m;
    for(int i = 1; i <= m; i++){
        int op;  cin>>op;
        if(op == 1){
            LL x, y, z;  cin>>x>>y>>z;
            times(1,x,y,z);
        }else if(op == 2){
            LL x, y, z;  cin>>x>>y>>z;
            add(1,x,y,z);
        }else{
            LL x, y;  cin>>x>>y;
            cout<<query(1,x,y)%mod<<"\n";
        }
    }
    return 0;
}
posted @ 2018-05-30 21:22  gwj1139177410  阅读(150)  评论(0编辑  收藏  举报
选择