维护序列 LibreOJ - 10129
原题链接
考察:线段树
思路:
很明显是设置两个懒标记,tag_1标记修改的累加和,tag_2标记修改的累乘积.但是直接这么写会WA,原因是具有运算优先级.假设修改区间为s,s*c1+c2 与 s*c2+c1是完全不同的结果.
需要统一运算优先级.也就是规定先乘还是后乘.如果我们后乘,即(s+c)*d这样比较难扩展.比如(s+c)*d+t ,如果当前区间"裂开"分给子节点的话,比较难形成(s+c)*d的形式.
考虑先*,那么在进行2操作直接tag_1+=c,进行1操作时需要tag_1+=tag_1*c,tag_2*=c.
push_down操作时,也是需要将子节点转化为 s*c+b的形式.所以子节点的
- tag_1 += tr[u].tag_2*tag_1+tr[u].tag_1.
- tag_2*= tr[u].tag_2;
- sum = tr[u].tag_2*sum+tr[u].tag_1*len.
更详细的解释在GO
luogu关于本题写得最好的一篇题解. GOOO!
Code
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010;
struct Node{
int l,r,sum;
LL tag_1,tag_2;//+ *
}tr[N<<2];
int n,p,a[N],m;
int get(int u)
{
return tr[u].r-tr[u].l+1;
}
void push_up(int u)
{
tr[u].sum = (tr[u<<1].sum+tr[u<<1|1].sum)%p;
}
void push_down(int u)
{//+ *
tr[u<<1].sum = (tr[u<<1].sum*tr[u].tag_2+tr[u].tag_1*get(u<<1)%p)%p;
tr[u<<1|1].sum = (tr[u<<1|1].sum*tr[u].tag_2+tr[u].tag_1*get(u<<1|1)%p)%p;
tr[u<<1].tag_1 = (tr[u<<1].tag_1*tr[u].tag_2%p+tr[u].tag_1)%p;
tr[u<<1|1].tag_1 = (tr[u<<1|1].tag_1*tr[u].tag_2%p+tr[u].tag_1)%p;
tr[u].tag_1 = 0;
tr[u<<1].tag_2 = tr[u].tag_2*tr[u<<1].tag_2%p;
tr[u<<1|1].tag_2 = tr[u].tag_2*tr[u<<1|1].tag_2%p;
tr[u].tag_2 = 1;
}
void build(int u,int l,int r)
{
tr[u] = {l,r,0,0,1};
if(l==r) { tr[u].sum = a[l]; return;}
int mid = l+r>>1;
build(u<<1,l,mid); build(u<<1|1,mid+1,r);
push_up(u);
}
void modify(int u,int l,int r,int mul,int add)
{
if(tr[u].l>=l&&tr[u].r<=r)
{
tr[u].sum = ((LL)tr[u].sum*mul+add*get(u)%p)%p;
tr[u].tag_1 = tr[u].tag_1*mul+add;
tr[u].tag_2 = tr[u].tag_2*mul;
return;
}
int mid = tr[u].l+tr[u].r>>1;
push_down(u);
if(l<=mid) modify(u<<1,l,r,mul,add);
if(mid<r) modify(u<<1|1,l,r,mul,add);
push_up(u);
}
int query(int u,int l,int r)
{
if(tr[u].l>=l&&tr[u].r<=r) return tr[u].sum;
push_down(u);
int mid = tr[u].l+tr[u].r>>1;
LL res = 0;
if(l<=mid) res+=query(u<<1,l,r);
res%=p;
if(mid<r) res+=query(u<<1|1,l,r);
res%=p;
return res;
}
int main()
{
scanf("%d%d",&n,&p);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
scanf("%d",&m);
build(1,1,n);
while(m--)
{
int op,l,r,c;
scanf("%d%d%d",&op,&l,&r);
if(op==3)
{
printf("%d\n",query(1,l,r));
continue;
}
scanf("%d",&c);
if(op==1) modify(1,l,r,c,0);
if(op==2) modify(1,l,r,1,c);
}
return 0;
}