理解 ”加法+乘法“ 线段树 - Luogu P3373 【模板】线段树 2
step 0 - 引子
在看了众多题解后,我仍然没明白这样一个问题:为什么会想到延迟标记下传要分加法和乘法,而做乘法时还要把加法的 tag 一并乘了?如何想到这种联系?
显然,在尝试一个 lazytag 不行之后,我们肯定能想到同时维护两个延迟标记。最开始是天真的,尝试把加法和乘法分离来写——结果发现优先度的问题我似乎并不能解决:该先做哪个运算?本来想得到 \(((a+b)\cdot c+d)\cdot e\) 的,结果得到了 \((a+b+d)\cdot c\cdot e\)……
经过几番思考,看了几轮题解,我似乎有点思路了:其实我们就是要避开优先级!也就是说在运算的过程中先做哪个运算与原数无关,原数在运算结束时才起效。这样,我们才能保证以 \(O(1)\) 的复杂度正确地计算某个结点的值。如果不能避开,那延迟标记就做不了,还不如写一棵普普通通的线段树/kk。
假设我们的原数等于 \(a\)。我们对它进行运算,得到答案 \(\newcommand\opt[1]{\ {\rm opt}_{#1}\ } ((((a \opt 1 x_1)\opt 2 x_2)\opt 3 x_3)\opt 4 \dots \opt {n-1} x_{n-1})\opt n x_n\),其中 \(\opt i\) 指的是第 \(i\) 次运算的类型,在本文中,表示 \(+\) 或 \(\times\)。
step 1 - 只加 与 只乘
我们先讨论只有加法与只有乘法两种情况,以便理解优先度的概念。
-
假设 \(\forall i\in[1,n],\opt i=+\)。通过模拟一系列程序运算的过程,我们知道答案是 \(((((a+x_1)+x_2)+x_3)+\dots +x_{n-1})+x_n\),\(x_i\) 表示的是第 \(i\) 次加的数 \(x\)。但是由于括号的原因,有个先后顺序,得先算括号内的,再算括号外的,因此计算每次都要涉及原数 \(a\)。想要避开优先度?我们可以通过多年的数学经验,由结合律得 \((x_1+x_2+\dots+x_{n-1}+x_n)+a\)。这样一来,由于加法的交换律,我们知道运算的先后与 \(a\) 无关,\(a\) 只在运算末尾起效了!
-
同样的,假设 \(\forall i\in[1,n],\opt i=\times\),由于乘法也有结合律和交换律,满足 \(((((a\times x_1)\times x_2)\times x_3)\times \dots \times x_{n-1})\times x_n=(x_1\times x_2\times\dots\times x_{n-1}\times x_n)\times a\),运算的先后也与 \(a\) 无关。于是避开了优先度。
step 2 - 又加又乘
若以 \(a\) 为主元,把原式拆开得到的多项式 \(S\) 是一次的,形如 \(S=B+K\cdot a\)(\(B\) 是一个常数多项式,\(K\) 是一个与 \(a\) 无关的单项式),我们称这个式子的一次项系数 \(K\) 为该式子的系数,记作 \(K_S\);\(B\) 称为常数,记作 \(B_S\)。
由 step 1 情况 1 得:\(B\) 与 \(a\) 无关,可以避开优先度。由情况 2 得:\(K\) 的运算先后也与 \(a\) 无关,可以避开优先度。又,假设我们知道 \(B,K,a\),则一定可以得到唯一确定的 \(S\) 的值。综上,若 \(S\) 为一次多项式,则可以用延迟标记做。下面我们来证明 \(B\) 是一个常数多项式,\(K\) 是一个与 \(a\) 无关的单项式。
设 \(S_i\) 表示经过 \(i\) 次运算后的 \(a\)。通过定义,\(S_0=a\),则 \(K_{S_0}=1,B_{S_0}=0\).
对式子 \(S_i\) 进行运算,得到 \(S_{i+1}=S_i\opt i x_i\)。若 \(\opt i=+\),则 \(K_{S_{i+1}}=K_{S_i},B_{S_{i+1}}=B_{S_i}+x_i\);否则 \(K_{S_{i+1}}=K_{S_i}\times x_i,B_{S_{i+1}}=B_{S_i}\times x_i\)。
由数学归纳法得: \(K_{S_n}=\prod \limits_{i=1}^n {x_i}^{[\opt i=\times]}\),\(B_{S_n}\) 也只与 \(x\) 有关。那么 \(S_n\) 就是一个一次多项式。
扩展一下,若把 \(a\) 也理解成一个满足性质的 \(S\),即 \(a=K_{a'} \cdot a'+B_{a'}\)。那么对于 \(a'\) 来说, \(K_a = K_a\times K_{a'},B_a=K_a\times B_{a'}+B_a\)。
step 3 - 线段树
回到线段树。定义函数 \(R(i)\) 表示节点 \(i\) 的真实值,\(v_i\) 表示节点的储存值,\(l_i\) 表示它覆盖的长度。
我们发现若节点 \(a\) 被完全覆盖,由于 \(v_a=R(a)\),该节点很好计算。若加法,\(v_a+l_a\cdot x\to v_a\);否则,\(v_a\cdot x\to v_a\)。
下传标记,因为 \(K,B\) 与 \(v\) 无关,我们正好可以利用 \(K\) 和 \(B\)。假设当前节点 \(i\) 的标记为 \(K_i,B_i\),即 \(R(i)=K_i\cdot v_i+B_i\)。则在接收父节点 \(f\) 标记后,\(K_i\cdot K_f\to K_i\),然后 \(K_f\cdot B_i+B_f\to B_i\),最后 \(R(i)\to v_i\)。\(K,B\) 的维护在 step 2 已给出。
step 4 - 总结(思维方式)
若线段树节点的真实值与储存值满足多项式关系,且每次运算不会使改多项式次数变化,则可以下传延迟标记。(不过我目前没有想到有没有高于一次的多项式关系的运算还不升次或降次……)
step 5 - 上代码!
code:
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+1;
int n,m,P,A[N],tree[N*4];
int markA[N*4],markM[N*4];
void dfs(int l=1,int r=n,int root=1){
markM[root]=1;
if(l==r){tree[root]=A[l]%P;return ;}
int mid=(l+r)>>1;
dfs(l,mid,root<<1);
dfs(mid+1,r,root<<1|1);
tree[root]=(tree[root<<1]+tree[root<<1|1])%P;
}
void pushdown(int p,int l,int r){
int len=r-l+1;
tree[p<<1]=(tree[p<<1]*markM[p]+markA[p]*(len-len/2))%P;
tree[p<<1|1]=(tree[p<<1|1]*markM[p]+markA[p]*(len/2))%P;
markM[p<<1]=markM[p<<1]*markM[p]%P;
markM[p<<1|1]=markM[p<<1|1]*markM[p]%P;
markA[p<<1]=(markA[p<<1]*markM[p]+markA[p])%P;
markA[p<<1|1]=(markA[p<<1|1]*markM[p]+markA[p])%P;
markM[p]=1;
markA[p]=0;
}
inline void UpdA(int d,int L,int R,int l=1,int r=n,int p=1){
if(L>r||l>R)return ;
if(L<=l&&r<=R){
tree[p]=(tree[p]+d*(r-l+1))%P;
if(r>l)markA[p]=(markA[p]+d)%P;
return ;
}int mid=(l+r)>>1;
pushdown(p,l,r);
UpdA(d,L,R,mid+1,r,p<<1|1);
UpdA(d,L,R,l,mid,p<<1);
tree[p]=(tree[p<<1]+tree[p<<1|1])%P;
}
void UpdM(int d,int L,int R,int l=1,int r=n,int p=1){
if(L>r||l>R)return ;
if(L<=l&&r<=R){
tree[p]=tree[p]*d%P;
if(r>l){
markM[p]=markM[p]*d%P;
markA[p]=markA[p]*d%P;
}return ;
}int mid=(l+r)>>1;
pushdown(p,l,r);
UpdM(d,L,R,mid+1,r,p<<1|1);
UpdM(d,L,R,l,mid,p<<1);
tree[p]=(tree[p<<1]+tree[p<<1|1])%P;
}
int query(int L,int R,int l=1,int r=n,int p=1){
if(L>r||l>R)return 0;
if(L<=l&&r<=R)return tree[p]%P;
int mid=(l+r)>>1;
pushdown(p,l,r);
return (query(L,R,mid+1,r,p<<1|1)+query(L,R,l,mid,p<<1))%P;
}
signed main(){
scanf("%lld%lld%lld",&n,&m,&P);
for(int i=1;i<=n;i++)scanf("%lld",A+i);
dfs();
for(int i=1;i<=m;i++){
int o,x,y,k;
scanf("%lld%lld%lld",&o,&x,&y);
if(o!=3){
scanf("%lld",&k);
if(o==2)UpdA(k,x,y);
else UpdM(k,x,y);
}else printf("%lld\n",query(x,y));
}
return 0;
}