线段树
记录
15:48 2024-2-4
1.线段树
线段树是处理区间用的数据结构,最经典的例子是RMQ(Range Minimum Query),查询某个区间上的最大值
这里init/update是从叶子节点向上更新的 自底向上
点击查看代码
#include<cstdio>
#include<algorithm>
using namespace std;
const int INF = 0x3f3f3f3f;
const int MAX_N = 1 << 17;
// 存储线段树的全局数组
int n, dat[2 * MAX_N - 1];
//初始化
void init(int n_) {
// 为了简单起见,把元素个数扩大到2的幂
n = 1;
while (n < n_) n *= 2;
//把所有值都设为INT_MAX
for(int i = 0; i < 2 * n - 1; i++) dat[i] = INF;
}
// 把第k个值(0-indexed)更新为a
void upadate(int k, int a) {
// n + n/2 + n/4 + ... = 2n 所以0到n-1都是树枝
//叶子节点 (从0到n-1都是树枝)
k += n - 1;
dat[k] = a;
// 向上更新
while (k > 0){
k = (k - 1) / 2;
dat[k] = min(dat[k * 2 + 1], dat[k * 2 + 2]);
}
}
// RMQ range minimum query
// 求[a, b)最小值
// 后面的参数是为了计算起来方便而传入的
// k是节点的编号,1,r表示这个节点对应的是[l, r]区间
// 在外部调用时,用query(a, b, 0, 0, n)
int query(int a, int b, int k, int l, int r) {
//如果[a, b)和[l,r)不相交,则返回INT_MAX
if (r <= a || b <= l) return INF;
// 如果[a, b) 完全包含[l,r)则返回当前节点的值
if(a <= l && r <= b) return dat[k];
else {
//否则返回两个儿子中值的较小者
int vl = query(a, b, k * 2 + 1, l, (l + r) / 2);
int vr = query(a, b, k * 2 + 2, (l + r) / 2, r);
return min(vl, vr);
}
}
这里的build和change是自上而下的
点击查看代码
struct SegmentTree {
int l, r;
int dat;
} t[SIZE * 4]; // struct数组存储线段树
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r; // 节点p代表区间[l,r]
if (l == r) { t[p].dat = a[l]; return; } // 叶节点
int mid = (l + r) / 2; // 折半
build(p*2, l, mid); // 左子节点[l,mid],编号p*2
build(p*2+1, mid+1, r); // 右子节点[mid+1,r],编号p*2+1
t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上传递信息
}
build(1, 1, n); // 调用入口
void change(int p, int x, int v) {
if (t[p].l == t[p].r) { t[p].dat = v; return; } // 找到叶节点
int mid = (t[p].l + t[p].r) / 2;
if (x <= mid) change(p*2, x, v); // x属于左半区间
else change(p*2+1, x, v); // x属于右半区间
t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上更新信息
}
change(1, x, v); // 调用入口
int ask(int p, int l, int r) {
if (l <= t[p].l && r >= t[p].r) return t[p].dat; // 完全包含,直接返回
int mid = (t[p].l + t[p].r) / 2;
int val = 0;
if (l <= mid) val = max(val, ask(p*2, l, r)); // 左子节点有重叠
if (r > mid) val = max(val, ask(p*2+1, l, r)); // 右子节点有重叠
return val;
}
cout << ask(1, l, r) << endl; // 调用入口
1. 区间增加 + 区间查询
线段树维护的是区间的值,如果对区间增加,需要对区间相关的节点全部更新才行。
为了保证线段树的高效性,将该问题转化: 对每个节点维护两个数据 (线段树本身就是维护区间数据的数据结构)
- 给这个节点对应的区间内的所有元素共同加上的值
这里意思就是说 操作是要把[a,b]之前共同加上一个值 v,如果某个区间[l,r] 有 a < l < r < b 说明这个区间内都要加上值,这里可以只记录这个值v,因为求的话乘上区间范围就行了 - 在这个节点对应的区间中除去2(↑)之外其他的值的和
这里意思是说[a,b]和区间[l,r]有交集,这个区间有部分要加上值v,这里记录要把这部分值全部算出来,因为没法通过区间范围求出来
如果对于父亲节点同时加了一个值, 那么这个值就不会在儿子节点被重复考虑。 在递归计算和时再把这一部分的值加到结果里面就可以了。
这就是重点了
点击查看代码
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long ll;
const int MAXN = 100005;
const int DAT_SIZE = (1 << 18) - 1;
int N, Q;
ll a[MAXN];
// 线段树
// a 给这个节点对应的区间内的所有元素共同加上的值
// b 在这个节点对应的区间中除去a之外其他的值的和
ll datA[DAT_SIZE], datB[DAT_SIZE];
// 对区间[a, b)同时加x
// k是节点的编号, 对应的区间是[l, r)
void add(int a, int b, int x, int k, int l, int r) {
if (a <= l && r <= b) {
// a.给这个节点对应的区间内的所有元素共同加上的值
datA[k] += x;
} else if (l < b && a < r) {
// b.在这个节点对应的区间中除去a之外其他的值的和
datB[k] += (min(b, r) - max(a, l)) * x;
add(a, b, x, k * 2 + 1, l, (l + r) / 2);
add(a, b, x, k * 2 + 2, (l + r) / 2, r );
}
}
// 计算[a, b)的和
// k是节点的编号 , 对应的区间是[l, r)
ll sum(int a, int b, int k, int l, int r) {
if (b <= l||r <= a) {
return 0;
} else if (a <= l && r <= b) {
//a l r b 全部包含
return datA[k] * (r - l) + datB[k];
} else {
// 交集区间上同时加上的值 + 子集区间上加上的值
// 如果对于父亲节点同时加了一个值, 那么这个值就不会在儿子节点被重复考虑。
// 在递归计算和时再把这一部分的值加到结果里面就可以了。
ll res = (min(b, r) - max(a, l)) * datA[k];
res += sum(a, b, k * 2 + 1, l, (l + r) / 2);
res += sum(a, b, k * 2 + 2, (l + r) / 2, r);
return res;
}
}
int main() {
cin >> N >> Q;
for(int i = 0; i < N; i++) {
scanf("%lld", &a[i]);
add(i, i + 1, a[i], 0, 0, N);
}
char c[2];
int l, r, v;
for(int i = 0; i < Q; i++) {
// %s 它会读入一个不含空格、TAB和回车符的字符串,存入字符数组
// %c 会读入\n
scanf("%s%d%d", c, &l, &r);
if(c[0] == 'C') {
scanf("%d", &v);
add(l - 1, r, v, 0, 0, N);
} else {
printf("%lld\n" , sum(l - 1, r, 0, 0, N));
}
}
}
2. 延迟标记
在区间上维持一个延迟标记
延迟标记的含义为“该节点曾经被修改, 但其子节点尚未被更新”
在更新或查询的时候向下传递延迟标记
点击查看代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int SIZE=100010;
struct SegmentTree{
int l,r;
long long sum,add;
#define l(x) tree[x].l
#define r(x) tree[x].r
#define sum(x) tree[x].sum
#define add(x) tree[x].add
}tree[SIZE*4];
int a[SIZE],n,m;
void build(int p,int l,int r)//No.p, [l,r]
{
l(p)=l,r(p)=r;
if(l==r) { sum(p)=a[l]; return; }
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
sum(p)=sum(p*2)+sum(p*2+1);
}
void spread(int p)
{
if(add(p))
{
sum(p*2)+=add(p)*(r(p*2)-l(p*2)+1);
sum(p*2+1)+=add(p)*(r(p*2+1)-l(p*2+1)+1);
add(p*2)+=add(p);
add(p*2+1)+=add(p);
add(p)=0;
}
}
void change(int p,int l,int r,int z)
{
if(l<=l(p)&&r>=r(p))
{
sum(p)+=(long long)z*(r(p)-l(p)+1);
add(p)+=z;
return;
}
spread(p);
int mid=(l(p)+r(p))/2;
if(l<=mid) change(p*2,l,r,z);
if(r>mid) change(p*2+1,l,r,z);
sum(p)=sum(p*2)+sum(p*2+1);
}
long long ask(int p,int l,int r)
{
if(l<=l(p)&&r>=r(p)) return sum(p);
spread(p);
int mid=(l(p)+r(p))/2;
long long ans=0;
if(l<=mid) ans+=ask(p*2,l,r);
if(r>mid) ans+=ask(p*2+1,l,r);
return ans;
}
int main()
{
cin>>n>>m;
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
build(1,1,n);
while(m--)
{
char op[2]; int x,y,z;
scanf("%s%d%d",op,&x,&y);
if(op[0]=='C')
{
scanf("%d",&z);
change(1,x,y,z);
}
else printf("%lld\n",ask(1,x,y));
}
}
3. 扫描线
感觉这个扫描线 只是告诉我们可以利用线段树来解决问题,并且利用有些题的特性,不需要做延迟标记
O(n^2) 没有利用线段树
点击查看代码
#include<vector>
#include<map>
#include<algorithm>
#include<cstdio>
#include<cstring>
#define rep(i, a, n) for (auto i = a; i < (n); ++i) // repeat
#define repe(i, a, n) for (auto i = a; i <= (n); ++i) // repeat and equal
#define revrep(i, a, n) for (auto i = n; i > (a); --i) // reverse repeat
#define revrepe(i, a, n) for (auto i = n; i >= (a); --i)
#define all(a) a.begin(), a.end()
#define sz(a) (int)(a.size());
#define mem(a,b) memset(a,b,sizeof(a))
#define lb(x) ((x) & -(x)) // lowbit
#define pb push_back
#define qb pop_back
#define pf push_front
#define qf pop_front
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<vi> vvi;
template<class T> inline bool chmax(T &a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T &a, T b) { if (b < a) { a = b; return 1; } return 0; }
const int MAXN = 1e5 + 5;
struct P {
double x, y1, y2;
int k;
bool operator < (const P &p) const {
return x < p.x;
}
} a[MAXN << 1];
double raw[MAXN << 1];
map<double, int> val;
// 表示某个区间是否被覆盖
int c[MAXN << 1];
int N, ind;
int main() {
while(scanf("%d", &N) && N != 0) {
for(int i = 1; i <= N; i++) {
int k = i << 1;
double y1, y2;
scanf("%lf%lf%lf%lf", &a[k - 1].x, &y1, &a[k].x, &y2);
raw[k - 1] = a[k - 1].y1 = a[k].y1 = y1;
raw[k] = a[k - 1].y2 = a[k].y2 = y2;
// .k 可以表示是左边还是右边 和这里k (i << 2)是不一样的1
a[k - 1].k = 1;
a[k].k = -1;
}
N <<= 1;
// 离散化y
sort(raw + 1, raw + N + 1);
int m = unique(raw + 1, raw + N + 1) - (raw + 1);
// 记录下y对应的位置
for(int i = 1; i <= m; i++) val[raw[i]] = i;
sort(a + 1, a + N + 1);
memset(c, 0, sizeof(c));
double result = 0;
for(int i = 1; i < N; i++) {
int l = val[a[i].y1], r = val[a[i].y2];
for(int j = l; j < r; j++) c[j] += a[i].k;
double len = 0;
for(int j = 1; j < m; j++)
if(c[j] > 0) len += raw[j + 1] - raw[j];
result += len * (a[i + 1].x - a[i].x);
}
printf("Test case #%d\nTotal explored area: %.2f\n\n",++ind, result);
}
}
利用线段树 O(nlogn)
点击查看代码
#include<vector>
#include<map>
#include<algorithm>
#include<cstdio>
#include<cstring>
#define rep(i, a, n) for (auto i = a; i < (n); ++i) // repeat
#define repe(i, a, n) for (auto i = a; i <= (n); ++i) // repeat and equal
#define revrep(i, a, n) for (auto i = n; i > (a); --i) // reverse repeat
#define revrepe(i, a, n) for (auto i = n; i >= (a); --i)
#define all(a) a.begin(), a.end()
#define sz(a) (int)(a.size());
#define mem(a,b) memset(a,b,sizeof(a))
#define lb(x) ((x) & -(x)) // lowbit
#define pb push_back
#define qb pop_back
#define pf push_front
#define qf pop_front
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<vi> vvi;
template<class T> inline bool chmax(T &a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T &a, T b) { if (b < a) { a = b; return 1; } return 0; }
const int MAXN = 1e5 + 5;
struct P {
double x, y1, y2;
int k;
bool operator < (const P &p) const {
return x < p.x;
}
} a[MAXN << 1];
double raw[MAXN << 1];
map<double, int> val;
// 线段树
struct T {
// cnt 表示节点自身被覆盖的次数
int l, r, cnt;
// len表示该节点代表的区间被矩形覆盖的长度
double len;
} t[MAXN << 3]; // 线段树的大小是 需要为数组大小的四倍 数组大小是 MAXN << 1;
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r, t[p].cnt = 0, t[p].len = 0;
if (l == r) return;
int mid = (t[p].l + t[p].r) / 2;
// 2 * p 2 * p + 1
build(p << 1, l, mid), build((p << 1) | 1, mid + 1, r);
}
void change(int p, int l, int r, double k) {
// left side
int ls = p << 1;
// right side
int rs = ls | 1;
int mid = (t[p].l + t[p].r) / 2;
if (l <= t[p].l && r >= t[p].r) return t[p].len = ((t[p].cnt += k) ? raw[t[p].r+1] - raw[t[p].l] : (t[p].l == t[p].r ? 0 : t[ls].len + t[rs].len)), void();
if (l <= mid) change(ls, l, r, k);
if (r > mid) change(rs, l, r, k);
t[p].len = (t[p].cnt ? raw[t[p].r+1] - raw[t[p].l] : t[ls].len + t[rs].len);
}
int N, ind;
int main() {
while(scanf("%d", &N) && N != 0) {
for(int i = 1; i <= N; i++) {
int k = i << 1;
double y1, y2;
scanf("%lf%lf%lf%lf", &a[k - 1].x, &y1, &a[k].x, &y2);
raw[k - 1] = a[k - 1].y1 = a[k].y1 = y1;
raw[k] = a[k - 1].y2 = a[k].y2 = y2;
// .k 可以表示是左边还是右边 和这里k (i << 2)是不一样的1
a[k - 1].k = 1;
a[k].k = -1;
}
N <<= 1;
// 离散化y
sort(raw + 1, raw + N + 1);
int m = unique(raw + 1, raw + N + 1) - (raw + 1);
// 记录下y对应的位置
for(int i = 1; i <= m; i++) val[raw[i]] = i;
sort(a + 1, a + N + 1);
build(1, 1, m - 1);
double result = 0;
for (int i = 1; i < N; i++) {
int l = val[a[i].y1], r = val[a[i].y2] - 1;
change(1, l, r, a[i].k);
result += t[1].len * (a[i+1].x - a[i].x);
}
printf("Test case #%d\nTotal explored area: %.2f\n\n",++ind, result);
}
}
4.动态开点与线段树合并
动态开点,在数据结构中不记录节点对应的区间了,而是在操作过程中作为参数进行传递
点击查看代码
// 动态开点的线段树
struct SegmentTree {
int lc, rc; // 左右子节点的编号
int dat;
} tr[SIZE * 2];
int root, tot;
int build() { // 新建一个节点
tot++;
tr[tot].lc = tr[tot].rc = tr[tot].dat = 0;
return tot;
}
// 在main函数中
tot = 0;
root = build(); // 根节点
// 单点修改,在val位置加delta,维护区间最大值
void insert(int p, int l, int r, int val, int delta) {
if (l == r) {
tr[p].dat += delta;
return;
}
int mid = (l + r) >> 1; // 代表的区间[l,r]作为递归参数传递
if (val <= mid) {
if (!tr[p].lc) tr[p].lc = build(); // 左子树不存在,动态开点
insert(tr[p].lc, l, mid, val, delta);
}
else {
if (!tr[p].rc) tr[p].rc = build(); // 右子树不存在,动态开点
insert(tr[p].rc, mid + 1, r, val, delta);
}
tr[p].dat = max(tr[tr[p].lc].dat, tr[tr[p].rc].dat);
}
// 调用
insert(root, 1, n, val, delta);
// 合并两棵线段树
int merge(int p, int q, int l, int r) {
if (!p) return q; // p,q之一为空
if (!q) return p;
if (l == r) { // 到达叶子
tr[p].dat += tr[q].dat;
return p;
}
int mid = (l + r) >> 1;
tr[p].lc = merge(tr[p].lc, tr[q].lc, l, mid); // 递归合并左子树
tr[p].rc = merge(tr[p].rc, tr[q].rc, mid + 1, r); // 递归合并右子树
tr[p].dat = max(tr[tr[p].lc].dat, tr[tr[p].rc].dat); // 更新最值
return p; // 以p为合并后的节点,相当于删除q
}