维护序列 LibreOJ - 10129

原题链接
考察:线段树
思路:
  很明显是设置两个懒标记,tag_1标记修改的累加和,tag_2标记修改的累乘积.但是直接这么写会WA,原因是具有运算优先级.假设修改区间为s,s*c1+c2 与 s*c2+c1是完全不同的结果.
  需要统一运算优先级.也就是规定先乘还是后乘.如果我们后乘,即(s+c)*d这样比较难扩展.比如(s+c)*d+t ,如果当前区间"裂开"分给子节点的话,比较难形成(s+c)*d的形式.
  考虑先*,那么在进行2操作直接tag_1+=c,进行1操作时需要tag_1+=tag_1*c,tag_2*=c.
  push_down操作时,也是需要将子节点转化为 s*c+b的形式.所以子节点的

  1. tag_1 += tr[u].tag_2*tag_1+tr[u].tag_1.
  2. tag_2*= tr[u].tag_2;
  3. sum = tr[u].tag_2*sum+tr[u].tag_1*len.


  更详细的解释在GO
  luogu关于本题写得最好的一篇题解. GOOO!

Code

#include <iostream> 
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010;
struct Node{
	int l,r,sum;
	LL tag_1,tag_2;//+ *
}tr[N<<2];
int n,p,a[N],m;
int get(int u)
{
	return tr[u].r-tr[u].l+1;
}
void push_up(int u)
{
	tr[u].sum = (tr[u<<1].sum+tr[u<<1|1].sum)%p;
}
void push_down(int u)
{//+ *
    tr[u<<1].sum = (tr[u<<1].sum*tr[u].tag_2+tr[u].tag_1*get(u<<1)%p)%p;
    tr[u<<1|1].sum = (tr[u<<1|1].sum*tr[u].tag_2+tr[u].tag_1*get(u<<1|1)%p)%p;
	tr[u<<1].tag_1 = (tr[u<<1].tag_1*tr[u].tag_2%p+tr[u].tag_1)%p;
	tr[u<<1|1].tag_1 = (tr[u<<1|1].tag_1*tr[u].tag_2%p+tr[u].tag_1)%p;
	tr[u].tag_1 = 0;
	tr[u<<1].tag_2 = tr[u].tag_2*tr[u<<1].tag_2%p;
	tr[u<<1|1].tag_2 = tr[u].tag_2*tr[u<<1|1].tag_2%p;
	tr[u].tag_2 = 1;
}
void build(int u,int l,int r)
{
	tr[u] = {l,r,0,0,1};
	if(l==r) { tr[u].sum = a[l]; return;} 
	int mid = l+r>>1;
	build(u<<1,l,mid); build(u<<1|1,mid+1,r);
	push_up(u);
}
void modify(int u,int l,int r,int mul,int add)
{
	if(tr[u].l>=l&&tr[u].r<=r)
	{
		tr[u].sum = ((LL)tr[u].sum*mul+add*get(u)%p)%p;
		tr[u].tag_1 = tr[u].tag_1*mul+add;
		tr[u].tag_2 = tr[u].tag_2*mul;
		return;
	}
	int mid = tr[u].l+tr[u].r>>1;
	push_down(u);
	if(l<=mid) modify(u<<1,l,r,mul,add);
	if(mid<r) modify(u<<1|1,l,r,mul,add);
	push_up(u);
}
int query(int u,int l,int r)
{
	if(tr[u].l>=l&&tr[u].r<=r) return tr[u].sum;
	push_down(u);
	int mid = tr[u].l+tr[u].r>>1;
	LL res = 0;
	if(l<=mid) res+=query(u<<1,l,r);
	res%=p;
	if(mid<r) res+=query(u<<1|1,l,r);
	res%=p;
	return res;
}
int main()
{
	scanf("%d%d",&n,&p);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	scanf("%d",&m);
	build(1,1,n);
	while(m--)
	{
		int op,l,r,c;
		scanf("%d%d%d",&op,&l,&r);
		if(op==3)
		{
			printf("%d\n",query(1,l,r));
			continue;
		}
		scanf("%d",&c);
		if(op==1) modify(1,l,r,c,0);
		if(op==2) modify(1,l,r,1,c);
	}
	return 0;
}
posted @ 2021-05-19 20:52  acmloser  阅读(102)  评论(0编辑  收藏  举报