Luogu P11217 Solution
闲话
一道质量不错的签到题。
题解
由于有区间操作,考虑使用线段树。
不难通过线段树维护需要区间增减操作的序列,同时不难通过二分寻找最后一个循环之后最后一个不会导致 youyou 死亡的垃圾桶。实际上,本题的难点在于 \(1\le q\le 10^6\),暴力实现上述思路的时间复杂度为 \(O(q\log^2 n)\),结合线段树的较大常数,通过此题需要卡常。
然而,线段树的结构决定了上述思路可以在 \(O(q\log n)\) 的时间复杂度内实现。
假设线段树维护了一个长度为 \(n\) 的序列 \(a\)。给出值 \(x\),你需要找出最大的 \(i\),满足 \(1\le i\le n\) 且 \(\sum_{j=1}^i a_j<x\)。
由于线段树的每个节点有左子节点和右子节点,则你可以确定你的目标在左子节点内或右子节点内。
具体地说,若 \(x\) 小于左子节点内元素之和,则你的目标在左子节点内,否则可能在右子节点内。
按照上述方法操作,遍历到线段树的叶子节点时一定可以找出答案。
时间复杂度 \(O(n+q\log n)\)。
代码
#include <cctype>
#include <cstdio>
using namespace std;
const int N = 2e5 + 10;
using ll = long long;
int n, q, a[N];
ll w, tr[N << 2], tag[N << 2];
inline void build(int x, int l, int r)
{
if (l == r)
return tr[x] = a[l], void();
int mid = (l + r) >> 1;
build(x << 1, l, mid);
build(x << 1 | 1, mid + 1, r);
tr[x] = tr[x << 1] + tr[x << 1 | 1];
}
void psh(int x, int l, int r)
{
int mid = (l + r) >> 1;
tr[x << 1] += tag[x] * (mid - l + 1);
tr[x << 1 | 1] += tag[x] * (r - mid);
tag[x << 1] += tag[x], tag[x << 1 | 1] += tag[x];
tag[x] = 0;
}
void update(int x, int l, int r, int lb, int rb, ll v)
{
if (l >= lb and r <= rb)
{
tr[x] += v * (r - l + 1);
tag[x] += v;
return;
}
if (tag[x])
psh(x, l, r);
int mid = (l + r) >> 1;
if (lb <= mid)
update(x << 1, l, mid, lb, rb, v);
if (rb > mid)
update(x << 1 | 1, mid + 1, r, lb, rb, v);
tr[x] = tr[x << 1] + tr[x << 1 | 1];
}
int lwr(int x, int l, int r, ll v)
{
if (l == r)
return tr[x] >= v ? l - 1 : l;
if (tag[x])
psh(x, l, r);
int mid = (l + r) >> 1;
ll res = 0;
if (tr[x << 1] < v)
res = lwr(x << 1 | 1, mid + 1, r, v - tr[x << 1]);
else
res = lwr(x << 1, l, mid, v);
tr[x] = tr[x << 1] + tr[x << 1 | 1];
return res;
}
ll run()
{
ll tmp = w, tsm = tr[1], res = 0, rd = 1;
// printf("tsm %lld\n", tsm);
while (tmp > tsm)
tmp -= tsm, res += n, tsm <<= 1, rd <<= 1;
// printf("t %lld %lld\n", res, (tmp + rd - 1) / rd);
return res + lwr(1, 1, n, (tmp + rd - 1) / rd);
}
template <typename _Tp> inline void read(_Tp &x)
{
char ch;
while (ch = getchar(), !isdigit(ch) and ~ch)
;
x = (ch ^ 48);
while (ch = getchar(), isdigit(ch))
x = (x << 3) + (x << 1) + (ch ^ 48);
}
template <typename _Tp, typename... _Args> inline void read(_Tp &x, _Args &...args)
{
read(x);
read(args...);
}
template <typename _Tp> inline void print(_Tp x)
{
if (x > 9)
print(x / 10);
putchar((x % 10) ^ 48);
}
int main()
{
// freopen("wxyt.in", "r", stdin);
// freopen("wxyt.out", "w", stdout);
read(n, q, w);
for (int i = 1; i <= n; i++)
{
read(a[i]);
}
build(1, 1, n);
for (int i = 1, l, r, v; i <= q; i++)
{
read(l, r, v);
update(1, 1, n, l, r, v);
print(run());
putchar('\n');
}
}