洛谷题单指南-线段树-P3373 【模板】线段树 2
原题链接:https://www.luogu.com.cn/problem/P3373
题意解读:对于序列a[n],支持三种操作:1.对区间每个数乘上一个数 2.对区间每个数加上一个数 3.求区间和
解题思路:由于支持乘、加两种区间修改操作,是线段树的另一种典型应用:多个懒标记
显然,这里需要两个懒标记,mul表示对子节点区间每个数乘mul,add表示对子节点区间每个数加上add,节点定义如下:
struct Node
{
int l, r;
LL sum; //区间和
LL mul; //懒标记,子节点区间每个数乘上mul,默认值为1
LL add; //懒标记,子节点区间每个数加上add,默认值为0
} tr[N * 4];
下面就要考虑sum、mul、add如何修改的问题
对于一个节点u,
如果要对其区间每个数乘mul,则有tr[u].sum = tr[u].sum * mul
如果要对其区间每个数加add,则有tr[u].sum = tr[u].sum + (tr[u].r - tr[u].l + 1) * add
在区间更新时,可以把乘和加统一成一个操作:tr[u].sum = tr[u].sum * mul + (tr[u].r - tr[u].l + 1) * add(加操作时mul设置为1,乘操作时add设置为0)
上面解决了sum修改的问题,接下来,就要看mul、add如何修改,关键在于要考虑mul、add的优先级?
1、先加后乘
假设先执行加法,后执行乘法,那么对于懒标记mul,add,意味着对其区间每一个数x都执行(x + add) * mul,
如果再来一个加add'操作,区间每一个数变成(x + add) * mul + add',不难分析,无法通过将add、mul进行更新得到形如(x + add) * mul的形式,
所以先加后乘不可行。
2、先乘后加
假设先执行乘法,后执行加法,那么对于懒标记mul,add,意味着对其区间每一个数x都执行x * mul + add,
如果再来一个加add'操作,区间每一个数变成x * mul + add + add',显然通过将add += add',即可以通过x * mul + add得到正确的结果;
如果再来一个乘mul'操作,区间每一个数变成(x * mul + add) * mul' = x * mul * mul' + add * mul',显然通过将mul *= mul', add * mul',即可以通过x * mul + add得到正确的结果。
确定了操作优先级,也就确定了懒标记的更新方式,可以将乘和加统一处理:
对于一个节点u,对其区间每个数乘mul,加add,如果只加则mul=1,如果只乘则add=0,懒标记更新方式为:
tr[u].mul = tr[u].mul * mul
void addtag(int u, LL mul, LL add)
{
tr[u].sum = (tr[u].sum * mul + (tr[u].r - tr[u].l + 1) * add) % m;
tr[u].mul = tr[u].mul * mul % m;
tr[u].add = (tr[u].add * mul + add) % m;
}
最后要注意的还是开long long。
100分代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 100005;
struct Node
{
int l, r;
LL sum; //区间和
LL mul; //懒标记,子节点区间每个数乘上mul,默认值为1
LL add; //懒标记,子节点区间每个数加上add,默认值为0
} tr[N * 4];
LL a[N];
int n, q, m;
void pushup(int u)
{
tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % m;
}
void build(int u, int l, int r)
{
tr[u] = {l, r, 0, 1, 0};
if(l == r) tr[u].sum = a[l];
else
{
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void addtag(int u, LL mul, LL add)
{
tr[u].sum = (tr[u].sum * mul + (tr[u].r - tr[u].l + 1) * add) % m;
tr[u].mul = tr[u].mul * mul % m;
tr[u].add = (tr[u].add * mul + add) % m;
}
void pushdown(int u)
{
addtag(u << 1, tr[u].mul, tr[u].add);
addtag(u << 1 | 1, tr[u].mul, tr[u].add);
tr[u].mul = 1;
tr[u].add = 0;
}
LL query(int u, int l, int r)
{
if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
else if(tr[u].l > r || tr[u].r < l) return 0;
else
{
pushdown(u);
return (query(u << 1, l, r) + query(u << 1 | 1, l, r)) % m;
}
}
void update(int u, int l, int r, LL mul, LL add)
{
if(tr[u].l >= l && tr[u].r <= r) addtag(u, mul, add);
else if(tr[u].l > r || tr[u].r < l) return;
else
{
pushdown(u);
update(u << 1, l, r, mul, add);
update(u << 1 | 1, l, r, mul, add);
pushup(u);
}
}
int main()
{
cin >> n >> q >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
int op, x, y, k;
while(q--)
{
cin >> op >> x >> y;
if(op == 1)
{
cin >> k;
update(1, x, y, k, 0); //乘k加0
}
else if(op == 2)
{
cin >> k;
update(1, x, y, 1, k); //乘1加k
}
else cout << query(1, x, y) << endl;
}
}