线段树
重构一下线段树的博客,关于线段树的相关定义以及证明过段时间再补
首先是一个简单线段树,这里叫做伪线段树,其实本质就是一个二叉树,仅能支持单点操作:
单点修改 + 区间查询
// 单点修改查询
// http://ybt.ssoier.cn:8088/problem_show.php?pid=1549
// https://www.luogu.com.cn/problem/P1198
// 用一维数组来存,当作完全二叉树来存
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10;
long long int m, p, n, last, t;
char op;
struct node
{
int l, r, v;
} tr[N * 4];
void pushup(int u) // 更新每个区间的最大值
{
tr[u].v = max(tr[u * 2].v, tr[2 * u + 1].v);
}
void build(int u, int l, int r) // 建立线段树
{
tr[u] = {l, r};
if (l == r)
return;
int mid = l + r >> 1;
build(2 * u, l, mid), build(2 * u + 1, mid + 1, r); // 向左建立,向右建立
}
int query(int u, int l, int r) // 查询
{
if (tr[u].l >= l && tr[u].r <= r)
return tr[u].v;
int mid = tr[u].l + tr[u].r >> 1, v = 0;
if (l <= mid)
v = query(2 * u, l, r);
if (r > mid)
v = max(v, query(2 * u + 1, l, r));
return v;
}
int modify(int u, int x, int v) // 修改
{
if (tr[u].l == x && tr[u].r == x)
tr[u].v = v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid)
modify(2 * u, x, v);
else
modify(2 * u + 1, x, v);
pushup(u);
}
}
int main()
{
cin >> m >> p;
build(1, 1, m);
while (m--)
{
cin >> op >> t;
if (op == 'Q')
{
last = query(1, n - t + 1, n);
cout << last << endl;
}
else
{
modify(1, n + 1, (last + t) % p);
n++;
}
}
return 0;
}
对于线段树的区间求和操作,我们只需要在线段树上对一个节点向上更新一下区间和就可以:
// //区间最大字段和
#include <bits/stdc++.h>
using namespace std;
const int N = 2e6 + 10;
int n, m, x, y, k, a[N];
struct node
{
int l, r, all, lmax, rmax, sum;
} tr[N * 4];
void pushup(node &w, node &l, node &r)
{
w.sum = l.sum + r.sum;
w.lmax = max(l.lmax, l.sum + r.lmax), w.rmax = max(r.sum + l.rmax, r.rmax);
w.all = max(max(l.all, r.all), l.rmax + r.lmax);
}
void pushup(int u)
{
pushup(tr[u], tr[2 * u], tr[2 * u + 1]);
}
void build(int u, int l, int r)
{
if (l == r)
tr[u] = {l, r, a[l], a[l], a[l], a[l]};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(2 * u, l, mid), build(2 * u + 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, int v)
{
if (tr[u].l == x && tr[u].r == x)
tr[u] = {x, x, v, v, v, v};
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid)
modify(2 * u, x, v);
else
modify(2 * u + 1, x, v);
pushup(u);
}
}
node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid)
return query(2 * u, l, r);
else if (l > mid)
return query(2 * u + 1, l, r);
else
{
auto left = query(2 * u, l, r), right = query(2 * u + 1, l, r);
node res;
pushup(res, left, right);
return res;
}
}
int main()
{
cin >> n >> m;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(1, 1, n);
while (m--)
{
cin >> k >> x >> y;
if (k == 1)
{
if (x > y)
swap(x, y);
cout << query(1, x, y).all << endl;
}
else
modify(1, x, y);
}
return 0;
}
此外,线段树还能够支持区间的最大公约数,也就向上更新一下就可以了:
// 区间最大公约数
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
char op;
int n, m, a[N];
struct node
{
int l, r, sum, d;
} tr[4 * N];
void pushup(node &w, node &l, node &r)
{
w.sum = l.sum + r.sum;
w.d = __gcd(l.d, r.d);
}
void pushup(int u)
{
pushup(tr[u], tr[2 * u], tr[2 * u + 1]);
}
void build(int u, int l, int r)
{
if (l == r)
tr[u] = {l, r, a[r] - a[r - 1], a[r] - a[r - 1]};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(2 * u, l, mid), build(2 * u + 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, int v)
{
if (tr[u].l == x && tr[u].r == x)
tr[u] = {x, x, tr[u].sum + v, tr[u].sum + v};
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid)
modify(2 * u, x, v);
else
modify(2 * u + 1, x, v);
pushup(u);
}
}
node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid)
return query(2 * u, l, r);
else if (l > mid)
return query(2 * u + 1, l, r);
else
{
auto left = query(2 * u, l, r), right = query(2 * u + 1, l, r);
node res;
pushup(res, left, right);
return res;
}
}
int main()
{
int l, r, d;
cin >> n >> m;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(1, 1, n);
while (m--)
{
cin >> op >> l >> r;
if (op == 'Q')
{
auto left = query(1, 1, l), right = query(1, l + 1, r);
cout << abs(__gcd(left.sum, right.d)) << endl;
}
else
{
cin >> d;
modify(1, l, d);
if (r + 1 <= n)
modify(1, r + 1, -d);
}
}
return 0;
}
区间修改 + 区间查询
接下来是支持区间修改的,对于区间操作,我们可以做个标记,然后只修改父节点,如果要查询子节点的时候我们直接向下更新标记,子区间再进行修改即可
区间加模板
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6 + 10;
struct node
{
int l, r, sum, lazy;
} tr[N * 4];
int a[N];
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
left.sum += (left.r - left.l + 1) * root.lazy, left.lazy += root.lazy;
right.sum += (right.r - right.l + 1) * root.lazy, right.lazy += root.lazy;
root.lazy = 0;
}
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void build(int u, int l, int r)
{
if (l == r)
tr[u] = {l, r, a[l], 0};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
int query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1, sum = 0;
if (l <= mid)
sum = query(u << 1, l, r);
if (r > mid)
sum = sum + query(u << 1 | 1, l, r);
return sum;
}
void modify(int u, int l, int r, int k)
{
if (tr[u].l >= l && tr[u].r <= r)
{
tr[u].sum += (tr[u].r - tr[u].l + 1) * k;
tr[u].lazy += k;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid)
modify(u << 1, l, r, k);
if (r > mid)
modify(u << 1 | 1, l, r, k);
pushup(u);
}
signed main()
{
std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(1, 1, n);
while (m--)
{
int op, l, r, k;
cin >> op >> l >> r;
if (op == 1)
cin >> k, modify(1, l, r, k);
else
cout << query(1, l, r) << endl;
}
return 0;
}
区间乘
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 2e5 + 10;
int n, p, m, a[N], mod, x, y, z, k;
struct node
{
int l, r, sum, add, mul;
} tr[4 * N];
void pushup(int u)
{
tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % mod;
}
void val(node &t, int add, int mul)
{
t.sum = (t.sum * mul + (t.r - t.l + 1) * add) % mod;
t.mul = t.mul * mul % mod, t.add = (t.add * mul + add) % mod;
}
void pushdown(int u)
{
val(tr[u << 1], tr[u].add, tr[u].mul), val(tr[u << 1 | 1], tr[u].add, tr[u].mul);
tr[u].mul = 1, tr[u].add = 0;
}
void build(int u, int l, int r)
{
if (l == r)
tr[u] = {l, r, a[r], 0, 1};
else
{
tr[u] = {l, r, 0, 0, 1};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int l, int r, int add, int mul)
{
if (tr[u].l >= l && tr[u].r <= r)
val(tr[u], add, mul);
else
{
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid)
modify(u << 1, l, r, add, mul);
if (r > mid)
modify(u << 1 | 1, l, r, add, mul);
pushup(u);
}
}
int query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1, res = 0;
if (l <= mid)
res = query(u << 1, l, r) % mod;
if (r > mid)
res = (res + query(u << 1 | 1, l, r)) % mod;
return res;
}
signed main()
{
cin >> n >> m >> p;
mod = p;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(1, 1, n);
while (m--)
{
cin >> z >> x >> y;
if (z == 1)
cin >> k, modify(1, x, y, 0, k);
else if (z == 2)
cin >> k, modify(1, x, y, k, 1);
else
cout << query(1, x, y) % mod << endl;
}
return 0;
}
区间最大公约数
在前面, 我们已经介绍了一种不带区间修改的线段树方法, 现在我们来讨论一下如何对于一个带修的线段树进行区间查询, 首先对于最大公约数, 我们有如下公式: \(\text{gcd}(a, b, c) = \text{gcd}(a, b - a, c - b)\)
一个显然的方法就是我们用差分的形式在线段树上表示出来, 所以这样我们就将一个区间修改变成了一个单点修改, 那么对于这个修改, 很明显我们可以直接将变化后的值与原来的最大公约数再做一遍 \(gcd\) 即可, 这里给出一个证明: 假设 \(\text{gcd}(a, b, c) = x\) 此时我们将 \(b\) 变成 \(b + k\), 那么有: \(\text{gcd}(a, b, c) = \text{gcd}(x, b)\)
这里的 \(b\) 是已经修改过的. 时间复杂度 \(O(mlog^2n)\)
至此就完成了, 后面在区间最值还会用到, 不再给出证明
区间异或
此外,线段树还支持区间修改以及区间异或,这里与之前有所不同,如果按照之前的做法,我们进行区间修改的操作是要具体到某个叶节点的,因为这样我们才能进行区间的整体修改,但是这样的复杂度很明显,对于具体的区间修改,我们的复杂度是
\(O(n)\) 的,考虑如何进行优化:
首先对于一个 \(01\) 串,如果我们按位取反,很明显我们的0会变成1,1会变成0,设我们原来有 \(x\) 个1,那么取反之后我们就有 \(size - x\) 个1
那么由此联想到,我们可以把线段树的每个叶节点的值拆分成二进制模式,然后对其位数进行操作,对于异或的x,我们把他拆位,如果某个位是1,那么就打个标记代表要进行取反,那么这个区间中这个位的总数就变成了区间长度-区间内的数\(\oplus\) 这个位=1的个数,对于求和操作,把每个位的个数乘上 \(2^i\) 然后相加即可
区间异或
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 2e5 + 10, mod = 1e9 + 7;
struct node
{
int l, r, should[25], tag[25];
//should是每个位为1的个数
//tag是每个位是否需要取反
} tr[N << 1];
int a[N];
void pushup(int u)
{
//统计区间每个位的总数
for (int i = 1; i <= 21; i++)
tr[u].should[i] = tr[u << 1].should[i] + tr[u << 1 | 1].should[i];
}
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
for (int i = 1; i <= 21; i++)
{
if (root.tag[i]) //若需要取反,则取反
{
left.should[i] = (left.r - left.l + 1) - left.should[i];
right.should[i] = (right.r - right.l + 1) - right.should[i];
}
left.tag[i] ^= root.tag[i];
right.tag[i] ^= root.tag[i];
root.tag[i] = 0;
}
}
void build(int u, int l, int r)
{
tr[u] = {l, r};
if (l == r)
{
for (int i = 1; i <= 21; i++)
if (a[l] & (1 << (i - 1)))
tr[u].should[i] = 1; //标记为1
}
else
{
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int l, int r, int x)
{
if (tr[u].l >= l && tr[u].r <= r)
{
for (int i = 1; i <= 21; i++)
{
if (x & (1 << (i - 1)))
//区间长度-目前的位=1的个数
tr[u].should[i] = (tr[u].r - tr[u].l + 1) - tr[u].should[i];
tr[u].tag[i] ^= (x & (1 << (i - 1))) ? 1 : 0;
}
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid)
modify(u << 1, l, r, x);
if (r > mid)
modify(u << 1 | 1, l, r, x);
pushup(u);
}
int query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
int ans = 0;
for (int i = 1; i <= 21; i++)
{
int now = (1 << (i - 1));
ans += (now * tr[u].should[i]);
}
return ans;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1, res = 0;
if (l <= mid)
res += query(u << 1, l, r);
if (r > mid)
res += query(u << 1 | 1, l, r);
return res;
}
signed main()
{
std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int n;
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(1, 1, n);
int m;
cin >> m;
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int l, r;
cin >> l >> r;
cout << query(1, l, r) << '\n';
}
else
{
int x, l, r;
cin >> l >> r >> x;
modify(1, l, r, x);
}
}
return 0;
}
同时进行操作
对于区间同时进行多种操作, 我们考虑添加多个懒标记进行操作, 同时要注意次序问题, 这里拿出 Hdu 4578 举例:
题目要求区间加, 乘, 修改成相同的数, 我们考虑用 \(3\) 个 \(Lazy - tag\) 来进行标记, 这里主要难点在于求出区间的和, 平方和, 立方和
考虑推公式:
代码:
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6 + 10, mod = 10007;
struct tree{
int l, r, tag[3], sum[3];
}tr[N << 1];
int a[N];
void pushup(int u){
tr[u].sum[0] = (tr[u << 1].sum[0] + tr[u << 1 | 1].sum[0]) % mod;
tr[u].sum[1] = (tr[u << 1].sum[1] + tr[u << 1 | 1].sum[1]) % mod;
tr[u].sum[2] = (tr[u << 1].sum[2] + tr[u << 1 | 1].sum[2]) % mod;
}
void make1(tree &x, int val){
x.tag[0] = 0;
x.tag[1] = 1;
x.tag[2] = val;
x.sum[0] = (x.r - x.l + 1) * x.tag[2] % mod;
x.sum[1] = (x.r - x.l + 1) * x.tag[2] % mod * x.tag[2] % mod;
x.sum[2] = (x.r - x.l + 1) * x.tag[2] % mod * x.tag[2] % mod * x.tag[2] % mod;
}
// ok, let's fix everything now
void make2(tree &x, int add){
// 这里注意, 要倒序更新, 类似与背包, 上一个更改的值会影响下一个的值
x.sum[2] = ((x.sum[2] + (x.r - x.l + 1) * add % mod * add % mod * add % mod) % mod +
(3 * x.sum[0] % mod * add % mod * add % mod + 3 * x.sum[1] % mod * add % mod) % mod) % mod;
x.sum[1] = ((x.sum[1] + (x.r - x.l + 1) * add % mod * add % mod) % mod + 2 * x.sum[0] % mod * add % mod) % mod;
x.sum[0] = ((x.r - x.l + 1) * add % mod + x.sum[0] % mod) % mod;
x.tag[0] = (x.tag[0] + add) % mod;
}
void make3(tree &x, int mul){
x.sum[0] = x.sum[0] * mul % mod;
x.sum[1] = x.sum[1] * mul % mod * mul % mod;
x.sum[2] = x.sum[2] * mul % mod * mul % mod * mul % mod;
x.tag[0] = x.tag[0] * mul % mod;
x.tag[1] = x.tag[1] * mul % mod;
}
void pushdown(int u){
// 这里的顺序也不可变 , 由于我们后面的区间加操作与求和相关, 所以如果有相乘操作要先进行操作
if(tr[u].tag[2] != 0){
make1(tr[u << 1], tr[u].tag[2]);
make1(tr[u << 1 | 1], tr[u].tag[2]);
tr[u].tag[2] = 0;
}
if(tr[u].tag[1] != 1){
make3(tr[u << 1], tr[u].tag[1]);
make3(tr[u << 1 | 1], tr[u].tag[1]);
tr[u].tag[1] = 1;
}
if(tr[u].tag[0] != 0){
make2(tr[u << 1], tr[u].tag[0]);
make2(tr[u << 1 | 1], tr[u].tag[0]);
tr[u].tag[0] = 0;
}
}
void build(int u, int l, int r ){
if(l == r){
tr[u] = {l, r, {0, 1, 0}, {0, 0, 0}};
return;
}
tr[u] = {l, r, {0, 1, 0}, {0, 0, 0}};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int l, int r, int add, int mul, int change){
if(tr[u].l >= l && tr[u].r <= r){
if(change != 0){
make1(tr[u], change);
}
else if(add != 0){
make2(tr[u], add);
}
else if(mul != 1){
make3(tr[u], mul);
}
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modify(u << 1, l, r, add, mul, change);
if(r > mid) modify(u << 1 | 1, l, r, add, mul, change);
pushup(u);
}
int query(int u, int l, int r, int op){
if(tr[u].l >= l && tr[u].r <= r){
return tr[u].sum[op];
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1, res = 0;
if(l <= mid) res = query(u << 1, l, r, op);
if(r > mid) res = (res + query(u << 1 | 1, l, r, op)) % mod;
pushup(u);
return res;
}
signed main()
{
std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int n, m;
while(cin >> n >> m && n != 0 && m != 0){
build(1, 1, n);
while(m--){
int op, l, r, p; cin >> op >> l >> r >> p;
if(op == 1) modify(1, l, r, p, 1, 0);
else if(op == 2) modify(1, l, r, 0, p, 0);
else if(op == 3) modify(1, l, r, 0, 1, p);
else cout << query(1, l, r, p - 1) << '\n';
}
}
return 0;
}
区间最值操作
上述的操作基本都是对区间的每一个数都做出改变, 通过如此操作我们的总和也较容易维护, 但是如果出现这一种操作: 对于一个区间 \([L,R]\) 中的所有数做: \(a_i = \text{min}(a_i, x)\), 这样的操作也是可以通过线段树来进行 \(log\) 级别的维护的, 具体方法如下:
线段树上额外维护这些值: 区间的最大值 \(mx\), 区间的次大值 \(se\), 区间最大值的个数 \(num\), 区间和 \(sum\)
那么当我们要进行一个操作的时候, 如果有:
- \(x >= mx\) 那么很明显不需要更新
- \(se <= x < mx\), 那么此时符合条件的只有最大值, 我们直接把最大值改成 \(x\), 然后更新一下总和即可, \(sum = sum - num * (mx - x)\)
- \(x < se\) 那么这个区间是更新不了, 递归左右儿子更新即可
此时就完成了, 这么做的复杂度是 \(log\), 先不做证明了, 后面补
这里给出一个非常经典的例题: 最假女选手:
注意, 在更新最大和次大以及最小和次小的时候注意两者是否可能相等, 若相等还需要进一步更改
#include <bits/stdc++.h>
#define ls u << 1
#define rs u << 1 | 1
#define LL long long
#define inf 1000000000
using namespace std;
const int N = 5e5 + 10, mod = 1e9 + 7;
int a[N];
struct node{
int l, r, tag1[2], tag2[2], num[2], vis[2], add;
LL sum;
}tr[N << 2];
void make(int u, int x){
tr[u].sum = tr[u].sum + 1LL * (tr[u].r - tr[u].l + 1) * x;
tr[u].tag1[0] += x, tr[u].tag1[1] += x;
tr[u].tag2[0] += x, tr[u].tag2[1] += x;
tr[u].add += x;
}
void make1(int u, int x){
if(x >= tr[u].tag1[0]) return;
tr[u].sum = tr[u].sum - 1LL * tr[u].num[0] * (tr[u].tag1[0] - x);
if(tr[u].tag2[0] == tr[u].tag1[0]) tr[u].tag2[0] = x;
if(tr[u].tag2[1] == tr[u].tag1[0]) tr[u].tag2[1] = x;
tr[u].tag1[0] = tr[u].vis[0] = x;
}
void make2(int u, int x){
if(x <= tr[u].tag2[0]) return;
tr[u].sum = tr[u].sum + 1LL * tr[u].num[1] * (x - tr[u].tag2[0]);
if(tr[u].tag1[0] == tr[u].tag2[0]) tr[u].tag1[0] = x;
if(tr[u].tag1[1] == tr[u].tag2[0]) tr[u].tag1[1] = x;
tr[u].tag2[0] = tr[u].vis[1] = x;
}
void pushup(int u){
tr[u].sum = tr[ls].sum + tr[rs].sum;
tr[u].tag1[0] = max(tr[ls].tag1[0], tr[rs].tag1[0]);
if(tr[ls].tag1[0] == tr[rs].tag1[0]){
tr[u].tag1[1] = max(tr[ls].tag1[1], tr[rs].tag1[1]);
tr[u].num[0] = tr[ls].num[0] + tr[rs].num[0];
}
else{
tr[u].tag1[1] = min(tr[ls].tag1[0], tr[rs].tag1[0]);
tr[u].tag1[1] = max({tr[u].tag1[1], tr[ls].tag1[1], tr[rs].tag1[1]});
tr[u].num[0] = tr[ls].tag1[0] > tr[rs].tag1[0] ? tr[ls].num[0] : tr[rs].num[0];
}
tr[u].tag2[0] = min(tr[ls].tag2[0], tr[rs].tag2[0]);
if(tr[ls].tag2[0] == tr[rs].tag2[0]){
tr[u].tag2[1] = min(tr[ls].tag2[1], tr[rs].tag2[1]);
tr[u].num[1] = tr[ls].num[1] + tr[rs].num[1];
}
else{
tr[u].tag2[1] = max(tr[ls].tag2[0], tr[rs].tag2[0]);
tr[u].tag2[1] = min({tr[u].tag2[1], tr[ls].tag2[1], tr[rs].tag2[1]});
tr[u].num[1] = tr[ls].tag2[0] < tr[rs].tag2[0] ? tr[ls].num[1] : tr[rs].num[1];
}
}
void pushdown(int u){
if(tr[u].add != 0){
make(ls, tr[u].add), make(rs, tr[u].add);
tr[u].add = 0;
}
make1(ls, tr[u].tag1[0]), make1(rs, tr[u].tag1[0]);
make2(ls, tr[u].tag2[0]), make2(rs, tr[u].tag2[0]);
}
void build(int u, int l, int r){
if(l == r){
tr[u] = {l, r, {a[l], -inf}, {a[l], inf}, {1, 1}, {-inf, inf}, 0, 1LL * a[l]};
return;
}
tr[u] = {l, r, {-inf, -inf}, {inf, inf}, {0, 0}, {-inf, inf}, 0, 1LL * 0};
int mid = l + r >> 1;
build(ls, l, mid), build(rs, mid+1 ,r);
pushup(u);
}
void modifyadd(int u, int l, int r, int x){
if(tr[u].l >= l && tr[u].r <= r){
make(u, x);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modifyadd(ls, l, r, x);
if(r > mid) modifyadd(rs, l, r, x);
pushup(u);
}
void modifymin(int u, int l, int r, int x){
if(tr[u].tag1[0] <= x) return;
if(tr[u].l >= l && tr[u].r <= r && x > tr[u].tag1[1]){
make1(u, x);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modifymin(ls, l, r, x);
if(r > mid) modifymin(rs, l, r, x);
pushup(u);
}
void modifymax(int u, int l, int r, int x){
if(tr[u].tag2[0] >= x) return;
if(tr[u].l >= l && tr[u].r <= r && x < tr[u].tag2[1]){
make2(u, x);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modifymax(ls, l, r, x);
if(r > mid) modifymax(rs, l, r, x);
pushup(u);
}
LL querysum(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r){
return tr[u].sum;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL res = 0;
if(l <= mid) res += querysum(ls, l, r);
if(r > mid) res += querysum(rs, l, r);
return res;
}
int querymax(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r){
return tr[u].tag1[0];
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1, res = -inf;
if(l <= mid) res = max(res, querymax(ls, l, r));
if(r > mid) res = max(res, querymax(rs, l, r));
return res;
}
int querymin(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r){
return tr[u].tag2[0];
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1, res = inf;
if(l <= mid) res = min(res, querymin(ls, l, r));
if(r > mid) res = min(res, querymin(rs, l, r));
return res;
}
void debug(int i){
cout << "----now----" << endl;
cout << tr[i].l << " " << tr[i].r << endl;
cout << tr[i].tag1[0] << " " << tr[i].tag1[1] << endl;
cout << tr[i].tag2[0] << " " << tr[i].tag2[1] << endl;
cout << tr[i].sum << endl;
cout << "----end----" << endl;
}
signed main()
{
// std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int n; cin >> n;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
int m; cin >> m;
while(m--){
int op, l, r; cin >> op >> l >> r;
if(op == 1){
int x; cin >> x;
modifyadd(1, l, r, x);
}
else if(op == 2){
int x; cin >> x;
modifymax(1, l, r, x);
}
else if(op == 3){
int x; cin >> x;
modifymin(1, l, r, x);
}
else if(op == 4) printf("%lld\n", querysum(1, l, r));
else if(op == 5) printf("%d\n", querymax(1, l, r));
else if(op == 6) printf("%d\n", querymin(1, l, r));
}
return 0;
}
主席树/可持久化线段树
旨在记录之前操作的状态, 其方法很好理解, 可以发现我们在进行一个单点更新的时候, 其实最多只会经过 \(log_n\) 的节点, 那么我们只需要向下图一样, 新建我们要用到的节点就可以了, 类似于动态开点一样, 然后我们每次记录不同的版本, 每个版本的入口就是根节点也就是 \(1'\), 其余的节点我们按照初始状态设置, 例如我们只需要经过左儿子, 那么我就不需要新建一个右儿子, 用原来的就可以, 形式化的就是 tr[new] = tr[old]
, 然后修改左儿子即可
【模板】可持久化线段树 1(可持久化数组)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6 + 10, mod = 1e9 + 7;
struct Seg{
int l, r, val;
}tr[N * 10];
int cnt;
int root[N], last[N], a[N];
void build(int u, int l, int r){
if(l == r){
tr[u].val = a[l];
return;
}
tr[u].l = ++cnt, tr[u].r = ++cnt;
int mid = l + r >> 1;
build(tr[u].l, l, mid), build(tr[u].r, mid + 1, r);
}
void modify(int &u, int old, int l, int r, int x, int v){
u = ++cnt;
tr[u] = tr[old], tr[u].val = v;
if(l == r) return;
int mid = l + r >> 1;
if(x <= mid) modify(tr[u].l, tr[old].l, l, mid, x, v);
else modify(tr[u].r, tr[old].r, mid + 1, r, x, v);
}
int query(int u, int pl, int pr, int l, int r){
if(pl >= l && pr <= r){
return tr[u].val;
}
int mid = pl + pr >> 1, res = 0;
if(l <= mid) res = query(tr[u].l, pl, mid, l, r);
if(r > mid) res = query(tr[u].r, mid + 1, pr, l, r);
return res;
}
signed main()
{
std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int n, q; cin >> n >> q;
for(int i = 1; i <= n; i++) cin >> a[i];
root[0] = ++cnt;
build(root[0], 1, n);
for(int i = 1; i <= q; i++){
int v, op, loc; cin >> v >> op >> loc;
root[i] = root[v];
if(op == 1){
int x; cin >> x;
modify(root[i], root[v], 1, n, loc, x);
} else{
cout << query(root[v], 1, n, loc, loc) << '\n';
}
}
return 0;
}
ABC253 - F
由于操作二可以把之前所有的操作一的影响消除掉, 所以我们使用主席树来维护操作时间状态, 每次查询时, 我们知道上一次进行操作二的状last[x]
以及更改的值a[x]
, 那么结果就是a[x]
加上从上次操作二之后的操作一的和, 用主席树相减可以抵消
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e6 + 10, mod = 1e9 + 7;
struct Seg{
int l, r, val;
}tr[N << 2];
int cnt;
int root[N], last[N], a[N];
void build(int u, int l, int r){
if(l == r) return;
tr[u].l = ++cnt, tr[u].r = ++cnt;
int mid = l + r >> 1;
build(tr[u].l, l, mid), build(tr[u].r, mid + 1, r);
}
void modify(int &u, int old, int l, int r, int x, int v){
u = ++cnt;
tr[u] = tr[old], tr[u].val += v;
if(l == r) return;
int mid = l + r >> 1;
if(x <= mid) modify(tr[u].l, tr[old].l, l, mid, x, v);
else modify(tr[u].r, tr[old].r, mid + 1, r, x, v);
}
int query(int u, int old, int pl, int pr, int l, int r){
if(pl >= l && pr <= r){
return tr[u].val - tr[old].val;
}
int mid = pl + pr >> 1, res = 0;
if(l <= mid) res += query(tr[u].l, tr[old].l, pl, mid, l, r);
if(r > mid) res += query(tr[u].r, tr[old].r, mid + 1, pr, l, r);
return res;
}
signed main()
{
std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int n, m, q; cin >> n >> m >> q;
root[0] = ++cnt;
build(root[0], 1, m);
for(int i = 1; i <= q; i++){
int op; cin >> op;
root[i] = root[i - 1];
if(op == 1){
int l, r, x; cin >> l >> r >> x;
modify(root[i], root[i], 1, m, l, x);
if(r + 1 <= m){
modify(root[i], root[i], 1, m, r + 1, -x);
}
} else if(op == 2){
int row, x; cin >> row >> x;
last[row] = i, a[row] = x;
} else{
int x, y; cin >> x >> y;
cout << a[x] + query(root[i], root[last[x]], 1, m, 1, y) << '\n';
}
}
return 0;
}