AtCoder Beginner Contest 357 F Solution
0. 闲话
上午做了三套地理卷子,只有一套上八十。
愁啊。
一看这次比赛的 F,更愁了。
怎么这次 ABC 又成 TemplateCoder 了。
没打真是亏大了。
遂写了这篇题解。
1. 题面
题面在这,自己看。
非常简洁且底层。
2. 思路
线段树板题,只是加了点东西。
这道题要维护两个数组,所以需要两个 tag 分别给 \(A\) 和 \(B\) 用。
build 的时候非常简单,只需要分别计算出每个 \(A_iB_i\) 然后存储区间和就可以了。
update 也差不多,只是在更新 \(A\) 的时候对应答案加上 \(x\) 乘 \(B\) 的对应区间和,更新 \(B\) 的时候对应答案加上 \(x\) 乘 \(A\) 的对应区间和就行了。
push 要考虑的事情就多了。
首先考虑,对于一个区间,设其目前 \(A\) 和为 \(a\),\(B\) 和为 \(b\),区间大小为 \(s\) ,现在要给区间内的每个 \(A_i\) 加上 \(x\) ,每个 \(B_i\) 加上 \(y\),原来该区间对应的答案为 \(r\),那么答案应该如何更新。
新的答案 \(s'\) 由四个部分组成。
第一个部分,原来的 \(s\)。
第二个部分,\(ay\),因为增加 \(y\) 带来的影响是由对应位置上的 \(A_i\) 决定的,它们的和为 \(a\),带来的总增量就是 \(ay\)。
第三个部分,\(bx\),由第二个部分同理可得,不进行赘述。
第四个部分,\(xys\)。因为,在计算第二个和第三个部分时,我们忽略了 \(x\) 和 \(y\) 互相之间的影响。实际上,假设区间内元素都是 \(0\),那么对区间内的每一个 \(A_i\) 和 \(B_i\) 执行上述操作就等同于给每个位置上的答案加上 \(xy\),总和即为 \(xys\)。
综上所述,
使用上方的公式进行下传不难。
记住要模,还要下传标记。
(我因为忘记下传标记错了两次,虽然是在赛后)
可以使用 AtCoder 给的 modint 库大幅降低码量。
3. 代码
#include <iostream>
#include <atcoder/modint>
using namespace std;
using namespace atcoder;
const int N = 2e5 + 10;
using ll = long long;
using mint = modint998244353;
struct st
{
mint as, bs, absm;
st operator+(const st &x) const
{
return {as + x.as, bs + x.bs, absm + x.absm};
}
};
st tr[N << 2];
mint taga[N << 2], tagb[N << 2];
ll a[N], b[N], opt, v;
int n, m;
void psh(int x, int l, int r)
{
int mid = (l + r) >> 1;
tr[x << 1].absm += tr[x << 1].as * tagb[x] + tr[x << 1].bs * taga[x] + taga[x] * tagb[x] * (mid - l + 1);
tr[x << 1 | 1].absm += tr[x << 1 | 1].as * tagb[x] + tr[x << 1 | 1].bs * taga[x] + taga[x] * tagb[x] * (r - mid);
tr[x << 1].as += taga[x] * (mid - l + 1);
tr[x << 1 | 1].as += taga[x] * (r - mid);
tr[x << 1].bs += tagb[x] * (mid - l + 1);
tr[x << 1 | 1].bs += tagb[x] * (r - mid);
taga[x << 1] += taga[x], taga[x << 1 | 1] += taga[x];
tagb[x << 1] += tagb[x], tagb[x << 1 | 1] += tagb[x];
taga[x] = tagb[x] = 0;
}
void build(int x, int l, int r)
{
if (l == r)
{
tr[x] = {a[l], b[l], a[l] * b[l]};
return;
}
int mid = (l + r) >> 1;
build(x << 1, l, mid);
build(x << 1 | 1, mid + 1, r);
tr[x] = tr[x << 1] + tr[x << 1 | 1];
}
void update(int x, int l, int r, int lb, int rb)
{
if (l >= lb and r <= rb)
{
tr[x].absm += v * (opt ? tr[x].as : tr[x].bs);
(opt ? tr[x].bs : tr[x].as) += v * (r - l + 1);
(opt ? tagb[x] : taga[x]) += v;
return;
}
if (taga[x] != 0 or tagb[x] != 0)
psh(x, l, r);
int mid = (l + r) >> 1;
if (lb <= mid)
update(x << 1, l, mid, lb, rb);
if (rb > mid)
update(x << 1 | 1, mid + 1, r, lb, rb);
tr[x] = tr[x << 1] + tr[x << 1 | 1];
}
st query(int x, int l, int r, int lb, int rb)
{
if (l >= lb and r <= rb)
{
return tr[x];
}
if (taga[x] != 0 or tagb[x] != 0)
psh(x, l, r);
int mid = (l + r) >> 1;
st ret{0, 0, 0};
if (lb <= mid)
ret = ret + query(x << 1, l, mid, lb, rb);
if (rb > mid)
ret = ret + query(x << 1 | 1, mid + 1, r, lb, rb);
tr[x] = tr[x << 1] + tr[x << 1 | 1];
return ret;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
scanf("%lld", a + i);
}
for (int i = 1; i <= n; i++)
{
scanf("%lld", b + i);
}
build(1, 1, n);
for (int i = 1, l, r; i <= m; i++)
{
scanf("%lld%d%d", &opt, &l, &r);
if (opt <= 2)
opt--, scanf("%lld", &v), update(1, 1, n, l, r);
else
printf("%lld\n", query(1, 1, n, l, r).absm);
}
}
不多不少,刚刚一百行。