线段树与可持久化线段树
发现博客里没有线段树相关板子,遂决定跟主席树一起写一篇方便自己查看
可持久化线段树据说是去年的铜牌题,还有不到一个月就昆明站了,补一下(
内容参考:
线段树 从入门到进阶
OI Wiki
可持久化线段树(含主席树)原理与实现
线段树
线段树的构建
线段树是一种二叉树,也就是对于一个线段,我们会用一个二叉树来表示。例如一个区间 1~4,那么其性质:节点 i 的权值 = 左儿子权值 + 右儿子权值,即可理解为 1~4 的和就是等于 1~2 的和加 2~3 的和
根据这个思路,我们就可以建树了,设一个结构体 tree,tree[i].l 和 tree[i].r 分别表示这个点代表的线段的左右下标,tree[i].sum 表示这个节点表示的线段的和
根据二叉树的性质:编号为 n 的父亲节点的左儿子和右儿子编号分别是 n * 2 和 n * 2 + 1
得到式子:tree[i].sum = tree[i * 2].sum + tree[i * 2 + 1].sum
inline void build(int i, int l, int r) {
tree[i].l = l; tree[i].r = r;
if (l == r) {
tree[i].sum = a[l];
return;
}
int mid = (l + r) >> 1;
build(i * 2, l, mid);
build(i * 2 + 1, mid + 1, r);
tree[i].sum = tree[i * 2].sum + tree[i * 2 + 1].sum;
return;
}
无 pushdown 线段树
单点修改,区间查询
1. 单点修改
修改区间单点即修改区间第 pos 位上的数据,首先从根节点开始,看 pos 在左儿子区域还是右儿子区域,在哪边就往那边继续搜直到到达目标位置,修改后返回更新路过的所有点
inline void add(int i, int pos, int k) {
if (tree[i].l == tree[i].r) {
tree[i].sum += k;
return;
}
if (pos <= tree[i * 2].r) add(i * 2, pos, k);
else add(i * 2 + 1, pos, k);
tree[i].sum = tree[i * 2].sum + tree[i * 2 + 1].sum;
return;
}
2. 区间查询
线段树的查询方式为:
-
如果这个区间被完全包括在目标区间里面,直接返回这个区间的值
-
如果这个区间的左儿子和目标区间有交集,那么搜索左儿子
-
如果这个区间的右儿子和目标区间有交集,那么搜索右儿子
inline int search(int i, int l, int r) {
if (tree[i].l >= l && tree[i].r <= r) return tree[i].sum;
if (tree[i].r < l || tree[i].l > r) return 0;
int sum = 0;
if (tree[i * 2].r >= l) sum += search(i * 2, l, r);
if (tree[i * 2 + 1].l <= r) sum += search(i * 2 + 1, l, r);
return sum;
}
例题:P3374 【模板】树状数组 1
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e7 + 10;
ll n, m, x, y, k, a[maxn];
struct node {
ll l, r, sum, lz;
}tree[maxn];
inline void build(ll i, ll l, ll r) {
tree[i].l = l; tree[i].r = r;
if (l == r) {
tree[i].sum = a[l];
return;
}
ll mid = (l + r) >> 1;
build(i * 2, l, mid);
build(i * 2 + 1, mid + 1, r);
tree[i].sum = tree[i * 2].sum + tree[i * 2 + 1].sum;
}
inline void add(ll i, ll dis, ll k) {
if (tree[i].l == tree[i].r) {
tree[i].sum += k;
return;
}
if (dis <= tree[i * 2].r) add(i * 2, dis, k);
else add(i * 2 + 1, dis, k);
tree[i].sum = tree[i * 2].sum + tree[i * 2 + 1].sum;
return;
}
inline ll search(ll i, ll l, ll r) {
if (tree[i].r < l || tree[i].l > r) return 0;
if (tree[i].l >= l && tree[i].r <= r) return tree[i].sum;
ll sum = 0;
if (tree[i * 2].r >= l) sum += search(i * 2, l, r);
if (tree[i * 2 + 1].l <= r) sum += search(i * 2 + 1, l, r);
return sum;
}
int main() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
build(1, 1, n);
for (int i = 1; i <= m; i++) {
int num;
scanf("%d", &num);
if (num == 1) {
scanf("%lld%lld", &x, &k);
add(1, x, k);
}
if (num == 2) {
scanf("%lld%lld", &x, &y);
ll ans = search(1, x, y);
printf("%lld\n", ans);
}
}
return 0;
}
区间修改,单点查询
1. 区间修改
区间修改和单点查询,我们的思路就变为:如果把这个区间加上 k,相当于把这个区间涂上一个 k 的标记,然后单点查询从上往下跑的时候把沿路的标记加起来就好
这里面给区间贴标记的方式与上面的区间查找类似,原则还是那三条,只不过第一条:如果这个区间被完全包括在目标区间里面,直接返回这个区间的值 变为了 如果这个区间被完全包括在目标区间里面,则将这个区间标记
inline void add(int i, int l, int r, int k) {
if (tree[i].l >= l && tree[i].r <= r) {
tree[i].sum += k;
return;
}
if (tree[i * 2].r >= l) add(i * 2, l, r, k);
if (tree[i * 2 + 1].l <= r) add(i * 2 + 1, l, r, k);
}
2. 单点查询
需要注意的是加上沿途的标记就好了
inline void search(int i, int pos) {
ans += tree[i].num;
if (tree[i].l == tree[i].r) return;
if (pos <= tree[i * 2].r) search(i * 2, pos);
if (pos >= tree[i * 2 + 1].l) search(i * 2 + 1, pos);
}
例题:P3368 【模板】树状数组 2
#include <bits/stdc++.h
#define ll long long
using namespace std;
const int maxn = 1e7 + 10;
ll n, m, x, y, z, ans, a[maxn];
struct node {
ll l, r, sum, lz;
}tree[maxn];
inline ll read() {
ll x = 0, k = 1;
char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') k = -1;
for (; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
return x * k;
}
inline void build(ll i, ll l, ll r) {
tree[i].l = l; tree[i].r = r;
if (l == r) {
tree[i].sum = a[l];
return;
}
ll mid = (l + r) >> 1;
build(i * 2, l, mid);
build(i * 2 + 1, mid + 1, r);
tree[i].sum = tree[i * 2].sum + tree[i * 2 + 1].sum;
}
inline void add(ll i, ll l, ll r, ll k) {
if (tree[i].r <= r && tree[i].l >= l) {
tree[i].sum += k;
return;
}
if (tree[i].r < l || tree[i].l > r) return;
if (tree[i << 1].r >= l) add(i << 1, l, r, k);
if (tree[i << 1 | 1].l <= r) add(i << 1 | 1, l, r, k);
//tree[i].sum = tree[i << 1].sum + tree[i << 1 | 1].sum;
return;
}
inline void search(int i, int dis) {
ans += tree[i].sum;
if (tree[i].l == tree[i].r) return;
if (dis <= tree[i * 2].r) search(i * 2, dis);
if (dis >= tree[i * 2 + 1].l) search(i * 2 + 1, dis);
}
int main() {
n = read(); m = read();
build(1, 1, n);
for (int i = 1; i <= n; i++) a[i] = read();
for (int i = 1; i <= m; i++) {
int num;
num = read();
if (num == 1) {
x = read(); y = read(); z = read();
add(1, x, y, z);
}
if (num == 2) {
x = read();
ans = 0;
search(1, x);
printf("%d\n", ans + a[x]);
}
}
return 0;
}
有 pushdown 线段树
如果问题涉及到了区间修改,区间查询这样一组的操作,就需要一个 lazytag 懒惰标记,这样在查询的时候就可以将区间的标记推下去
区间修改,区间查询
1. 区间修改
区间修改的时候,我们按照如下原则:
-
如果当前区间被完全覆盖在目标区间里,将这个区间的 sum 加上 k * (tree[i].r - tree[i].l + 1)
-
如果没有完全覆盖,则先下传懒标记
-
如果这个区间的左儿子和目标区间有交集,那么搜索左儿子
-
如果这个区间的右儿子和目标区间有交集,那么搜索右儿子
inline void push_down(int i) {
if (tree[i].lz != 0) {
tree[i * 2].lz += tree[i].lz;
tree[i * 2 + 1].lz += tree[i].lz;
int mid = (tree[i].l + tree[i].r) / 2;
tree[i * 2].data += tree[i].lz * (mid - tree[i * 2].l + 1);
tree[i * 2 + 1].data += tree[i].lz * (tree[i * 2 + 1].r - mid);
tree[i].lz = 0;
}
return;
}
inline void add(int i, int l, int r, int k) {
if (tree[i].r <= r && tree[i].l >= l) {
tree[i].sum += k * (tree[i].r - tree[i].l + 1);
tree[i].lz += k;
return;
}
push_down(i);
if (tree[i * 2].r >= l) add(i * 2, l, r, k);
if (tree[i * 2 + 1].l <= r) add(i * 2 + 1, l, r, k);
tree[i].sum = tree[i * 2].sum + tree[i * 2 + 1].sum;
return;
}
2. 区间查询
相较于无 pushdown 代码,当前区间不完全包含目标区间时,多了一步下传懒惰标记操作
inline int search(int i, int l, int r) {
if (tree[i].l >= l && tree[i].r <= r) return tree[i].sum;
if (tree[i].r < l || tree[i].l > r) return 0;
push_down(i);
int sum = 0;
if (tree[i * 2].r >= l) sum += search(i * 2, l, r);
if (tree[i * 2 + 1].l <= r) sum += search(i * 2 + 1, l, r);
return sum;
}
乘法/根号线段树
乘法线段树
如果一个线段树又加又乘,则当 lazytag 下标传递的时候,我们需要考虑是先加再乘还是先乘再加,所以要对 lazytag 做这样一个处理:将 lazytage 分为两种,分别是加法的 plz 和乘法的 mlz
处理方式如下:
-
mlz 的处理:pushdown 时乘上父亲的 mlz
-
plz 的处理:把原先的 plz 乘上父亲的 mlz 再加上父亲的 plz
inline void pushdown(long long i) {
long long k1 = tree[i].mlz, k2 = tree[i].plz;
tree[i << 1].sum = (tree[i << 1].sum * k1 + k2 * (tree[i << 1].r - tree[i << 1].l + 1)) % p;
tree[i << 1 | 1].sum = (tree[i << 1 | 1].sum * k1 + k2 * (tree[i << 1 | 1].r - tree[i << 1 | 1].l + 1)) % p;
tree[i << 1].mlz = (tree[i << 1].mlz * k1) % p;
tree[i << 1 | 1].mlz = (tree[i << 1 | 1].mlz * k1) % p;
tree[i << 1].plz = (tree[i << 1].plz * k1 + k2) % p;
tree[i << 1 | 1].plz = (tree[i << 1 | 1].plz * k1 + k2) % p;
tree[i].plz = 0; tree[i].mlz = 1;
return;
}
根号线段树
其实,根号线段树和除法线段树一样。乍一看感觉直接用 lazytage 标记除了多少,但是实际上,会出现精度问题
我们对于每个区间,维护最大值和最小值,然后每次修改时,如果这个区间的最大值的根号和最小值的根号一样,说明这个区间整体根号不会产生误差,就直接修改(除法同理)
其中,lazytage 把除法当成减法,记录的是这个区间里每个元素减去的值
inline void Sqrt(int i, int l, int r) {
if (tree[i].l >= l && tree[i].r <= r && (tree[i].minn - (long long)sqrt(tree[i].minn)) == (tree[i].maxx - (long long)sqrt(tree[i].maxx))) {
long long u = tree[i].minn - (long long)sqrt(tree[i].minn);
tree[i].lz += u;
tree[i].sum -= (tree[i].r - tree[i].l + 1) * u;
tree[i].minn -= u;
tree[i].maxx -= u;
return;
}
if (tree[i].r < l || tree[i].l > r) return;
push_down(i);
if (tree[i * 2].r >= l) Sqrt(i * 2, l, r);
if (tree[i * 2 + 1].l <= r) Sqrt(i * 2 + 1, l, r);
tree[i].sum = tree[i * 2].sum + tree[i * 2 + 1].sum;
tree[i].minn = min(tree[i * 2].minn, tree[i * 2 + 1].minn);
tree[i].maxx = max(tree[i * 2].maxx, tree[i * 2 + 1].maxx);
return;
}
然后 pushdown 没什么变化,就是要记得 tree[i].minn、tree[i].maxx 也要记得 -lazytag
一些例题
区间加法:P3372 【模板】线段树 1
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 1e6 + 10;
ll n, m, x, y, k, a[maxn];
struct node {
ll l, r, sum, lz;
}tree[maxn];
inline ll read() {
ll x = 0, k = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') k = -1;
for (; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
return x * k;
}
inline void build(ll i, ll l, ll r) {
tree[i].l = l; tree[i].r = r;
if (l == r) {
tree[i].sum = a[l];
return;
}
ll mid = (l + r) >> 1;
build(i * 2, l, mid);
build(i * 2 + 1, mid + 1, r);
tree[i].sum = tree[i * 2].sum + tree[i * 2 + 1].sum;
}
inline void push_down(ll i) {
if (tree[i].lz != 0) {
tree[i * 2].lz += tree[i].lz;
tree[i * 2 + 1].lz += tree[i].lz;
ll mid = (tree[i].l + tree[i].r) / 2;
tree[i * 2].sum += tree[i].lz * (mid - tree[i * 2].l + 1);
tree[i * 2 + 1].sum += tree[i].lz * (tree[i * 2 + 1].r - mid);
tree[i].lz = 0;
}
return;
}
inline void add(ll i, ll l, ll r, ll k) {
if (tree[i].r <= r && tree[i].l >= l) {
tree[i].sum += k * (tree[i].r - tree[i].l + 1);
tree[i].lz += k;
return;
}
if (tree[i].r < l || tree[i].l > r) return;
push_down(i);
if (tree[i * 2].r >= l) add(i * 2, l, r, k);
if (tree[i * 2 + 1].l <= r) add(i * 2 + 1, l, r, k);
tree[i].sum = tree[i * 2].sum + tree[i * 2 + 1].sum;
return;
}
inline ll search(ll i, ll l, ll r) {
if (tree[i].l >= l && tree[i].r <= r) return tree[i].sum;
push_down(i);
ll sum = 0;
if (tree[i * 2].r >= l) sum += search(i * 2, l, r);
if (tree[i * 2 + 1].l <= r) sum += search(i * 2 + 1, l, r);
return sum;
}
int main() {
n = read(); m = read();
for (int i = 1; i <= n; i++) a[i] = read();
build(1, 1, n);
for (int i = 1; i <= m; i++) {
register ll num;
num = read();
if (num == 1) {
x = read(); y = read(); k = read();
add(1, x, y, k);
}
if (num == 2) {
x = read(); y = read();
ll ans = search(1, x, y);
printf("%lld\n", ans);
}
}
return 0;
}
区间乘法:P3373 【模板】线段树 2
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 4e5 + 10;
ll n, m, p, x, y, z, a[maxn];
struct node{
ll l, r, sum, plz, mlz;
}tree[maxn];
inline ll read() {
ll x = 0, k = 1;
char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') k = -1;
for (; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
return x * k;
}
inline void build(ll i, ll l, ll r) {
tree[i].l = l; tree[i].r = r;
tree[i].mlz = 1;
if (l == r) {
tree[i].sum = a[l] % p;
return;
}
ll mid = (l + r) >> 1;
build(i << 1, l, mid);
build(i << 1 | 1, mid + 1, r);
tree[i].sum = (tree[i << 1].sum + tree[i << 1 | 1].sum) % p;
return;
}
inline void push_down(ll i) {
ll k1 = tree[i].mlz, k2 = tree[i].plz;
tree[i << 1].sum = (tree[i << 1].sum * k1 + k2 * (tree[i << 1].r - tree[i << 1].l + 1)) % p;
tree[i << 1 | 1].sum = (tree[i << 1 | 1].sum * k1 + k2 * (tree[i << 1 | 1].r - tree[i << 1 | 1].l + 1)) % p;
tree[i << 1].mlz = (tree[i << 1].mlz * k1) % p;
tree[i << 1 | 1].mlz = (tree[i << 1 | 1].mlz * k1) % p;
tree[i << 1].plz = (tree[i << 1].plz * k1 + k2) % p;
tree[i << 1 | 1].plz = (tree[i << 1 | 1].plz * k1 + k2) % p;
tree[i].plz = 0;
tree[i].mlz = 1;
return;
}
inline void multiply(ll i, ll l, ll r, ll k) {
if (tree[i].r < l || tree[i].l > r) return;
if (tree[i].l >= l && tree[i].r <= r) {
tree[i].sum = (tree[i].sum * k) % p;
tree[i].mlz = (tree[i].mlz * k) % p;
tree[i].plz = (tree[i].plz * k) % p;
return;
}
push_down(i);
if (tree[i << 1].r >= l) multiply(i << 1, l, r, k);
if (tree[i << 1 | 1].l <= r) multiply(i << 1 | 1, l, r, k);
tree[i].sum = (tree[i << 1].sum + tree[i << 1 | 1].sum) % p;
return;
}
inline void add(ll i, ll l, ll r, ll k) {
if (tree[i].r <= r && tree[i].l >= l) {
tree[i].sum += (k * (tree[i].r - tree[i].l + 1)) % p;
tree[i].plz = (tree[i].plz + k) % p;
return;
}
if (tree[i].r < l || tree[i].l > r) return;
push_down(i);
if (tree[i << 1].r >= l) add(i << 1, l, r, k);
if (tree[i << 1 | 1].l <= r) add(i << 1 | 1, l, r, k);
tree[i].sum = (tree[i << 1].sum + tree[i << 1 | 1].sum) % p;
return;
}
inline ll search(ll i, ll l, ll r) {
if (tree[i].r < l || tree[i].l > r) return 0;
if (tree[i].l >= l && tree[i].r <= r) return tree[i].sum;
push_down(i);
ll sum = 0;
if (tree[i << 1].r >= l) sum += search(i << 1, l, r) % p;
if (tree[i << 1 | 1].l <= r) sum += search(i << 1 | 1, l, r) % p;
return sum % p;
}
int main() {
n = read(); m = read(); p = read();
for (int i = 1; i <= n; i++) a[i] = read();
build(1, 1, n);
for (int i = 1; i <= m; i++) {
int num;
num = read();
if (num == 1) {
x = read(); y = read(); z = read();
z %= p;
multiply(1, x, y, z);
}
if (num == 2) {
x = read(); y = read(); z = read();
z %= p;
add(1, x, y, z);
}
if (num == 3) {
x = read(); y = read();
z = search(1, x, y);
printf("%lld\n", z);
}
}
return 0;
}
可持久化线段树
介绍
可持久化线段树是一类线段树的实现方式,用于保存线段树的历史版本。比如,对于一个序列进行更改,你可以选择每一次更改前将上次更改后的序列复制一份,再在新序列中更改,以保留这个序列的所有历史版本。主席树是以下标为时间加入元素的一类可持久化线段树,常见的用法是静态查询区间第 k 大值
修改
线段树的修改应该是修改一条链,也就是两个相邻版本的的线段树只会有一条链有差别。那我们考虑动态开点,将后面版本的线段树那些没有更改的儿子指向前一个版本的线段树的节点,这样每次只会新建一条链的节点,空间会降低很多
查询
每次修改后存储这一次修改的新线段树的根,然后按照根下去正常查询即可
代码&例题:P3834 【模板】可持久化线段树 2
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn = 2e5; // 数据范围
int tot, n, m;
int sum[(maxn << 5) + 10], rt[maxn + 10], ls[(maxn << 5) + 10],
rs[(maxn << 5) + 10];
int a[maxn + 10], ind[maxn + 10], len;
inline int getid(const int &val) { // 离散化
return lower_bound(ind + 1, ind + len + 1, val) - ind;
}
int build(int l, int r) { // 建树
int root = ++tot;
if (l == r) return root;
int mid = l + r >> 1;
ls[root] = build(l, mid);
rs[root] = build(mid + 1, r);
return root; // 返回该子树的根节点
}
int update(int k, int l, int r, int root) { // 插入操作
int dir = ++tot;
ls[dir] = ls[root], rs[dir] = rs[root], sum[dir] = sum[root] + 1;
if (l == r) return dir;
int mid = l + r >> 1;
if (k <= mid)
ls[dir] = update(k, l, mid, ls[dir]);
else
rs[dir] = update(k, mid + 1, r, rs[dir]);
return dir;
}
int query(int u, int v, int l, int r, int k) { // 查询操作
int mid = l + r >> 1,
x = sum[ls[v]] - sum[ls[u]]; // 通过区间减法得到左儿子中所存储的数值个数
if (l == r) return l;
if (k <= x) // 若 k 小于等于 x ,则说明第 k 小的数字存储在在左儿子中
return query(ls[u], ls[v], l, mid, k);
else // 否则说明在右儿子中
return query(rs[u], rs[v], mid + 1, r, k - x);
}
inline void init() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) scanf("%d", a + i);
memcpy(ind, a, sizeof ind);
sort(ind + 1, ind + n + 1);
len = unique(ind + 1, ind + n + 1) - ind - 1;
rt[0] = build(1, len);
for (int i = 1; i <= n; ++i) rt[i] = update(getid(a[i]), 1, len, rt[i - 1]);
}
int l, r, k;
inline void work() {
while (m--) {
scanf("%d%d%d", &l, &r, &k);
printf("%d\n", ind[query(rt[l - 1], rt[r], 1, len, k)]); // 回答询问
}
}
int main() {
init();
work();
return 0;
}