线段树

记录
15:48 2024-2-4

1.线段树

线段树是处理区间用的数据结构,最经典的例子是RMQ(Range Minimum Query),查询某个区间上的最大值

这里init/update是从叶子节点向上更新的 自底向上

点击查看代码
#include<cstdio>
#include<algorithm>
using namespace std;
const int INF = 0x3f3f3f3f;
const int MAX_N = 1 << 17;

// 存储线段树的全局数组
int n, dat[2 * MAX_N - 1];

//初始化
void init(int n_) {
    // 为了简单起见,把元素个数扩大到2的幂
    n = 1;
    while (n < n_) n *= 2;
    //把所有值都设为INT_MAX
    for(int i = 0; i < 2 * n - 1; i++) dat[i] = INF;
}

// 把第k个值(0-indexed)更新为a
void upadate(int k, int a) {
    // n + n/2 + n/4 + ... = 2n 所以0到n-1都是树枝
    //叶子节点 (从0到n-1都是树枝)
    k += n - 1;
    dat[k] = a;
    // 向上更新
    while (k > 0){
        k = (k - 1) / 2;
        dat[k] = min(dat[k * 2 + 1], dat[k * 2 + 2]);
    }
}

// RMQ range minimum query
// 求[a, b)最小值
// 后面的参数是为了计算起来方便而传入的
// k是节点的编号,1,r表示这个节点对应的是[l, r]区间
// 在外部调用时,用query(a, b, 0, 0, n)
int query(int a, int b, int k, int l, int r) {

    //如果[a, b)和[l,r)不相交,则返回INT_MAX
    if (r <= a || b <= l) return INF;

    // 如果[a, b) 完全包含[l,r)则返回当前节点的值
    if(a <= l && r <= b) return dat[k];
    else {
        //否则返回两个儿子中值的较小者
        int vl = query(a, b, k * 2 + 1, l, (l + r) / 2);
        int vr = query(a, b, k * 2 + 2, (l + r) / 2, r);
        return min(vl, vr);
    }
}

这里的build和change是自上而下的

点击查看代码
struct SegmentTree {
	int l, r;
	int dat;
} t[SIZE * 4]; // struct数组存储线段树

void build(int p, int l, int r) {
	t[p].l = l, t[p].r = r; // 节点p代表区间[l,r]
	if (l == r) { t[p].dat = a[l]; return; } // 叶节点
	int mid = (l + r) / 2; // 折半
	build(p*2, l, mid); // 左子节点[l,mid],编号p*2
	build(p*2+1, mid+1, r); // 右子节点[mid+1,r],编号p*2+1
	t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上传递信息
}

build(1, 1, n); // 调用入口

void change(int p, int x, int v) {
	if (t[p].l == t[p].r) { t[p].dat = v; return; } // 找到叶节点
	int mid = (t[p].l + t[p].r) / 2;
	if (x <= mid) change(p*2, x, v); // x属于左半区间
	else change(p*2+1, x, v); // x属于右半区间
	t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上更新信息
}

change(1, x, v); // 调用入口

int ask(int p, int l, int r) {
	if (l <= t[p].l && r >= t[p].r) return t[p].dat; // 完全包含,直接返回
	int mid = (t[p].l + t[p].r) / 2;
	int val = 0;
	if (l <= mid) val = max(val, ask(p*2, l, r)); // 左子节点有重叠
	if (r > mid) val = max(val, ask(p*2+1, l, r)); // 右子节点有重叠
	return val;
}

cout << ask(1, l, r) << endl; // 调用入口

1. 区间增加 + 区间查询

线段树维护的是区间的值,如果对区间增加,需要对区间相关的节点全部更新才行。

为了保证线段树的高效性,将该问题转化: 对每个节点维护两个数据 (线段树本身就是维护区间数据的数据结构)

  1. 给这个节点对应的区间内的所有元素共同加上的值
    这里意思就是说 操作是要把[a,b]之前共同加上一个值 v,如果某个区间[l,r] 有 a < l < r < b 说明这个区间内都要加上值,这里可以只记录这个值v,因为求的话乘上区间范围就行了
  2. 在这个节点对应的区间中除去2(↑)之外其他的值的和
    这里意思是说[a,b]和区间[l,r]有交集,这个区间有部分要加上值v,这里记录要把这部分值全部算出来,因为没法通过区间范围求出来

如果对于父亲节点同时加了一个值, 那么这个值就不会在儿子节点被重复考虑。 在递归计算和时再把这一部分的值加到结果里面就可以了。
这就是重点了

点击查看代码
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long ll;

const int MAXN = 100005;
const int DAT_SIZE = (1 << 18) - 1;
int N, Q;
ll a[MAXN];
// 线段树
// a 给这个节点对应的区间内的所有元素共同加上的值
// b 在这个节点对应的区间中除去a之外其他的值的和
ll datA[DAT_SIZE], datB[DAT_SIZE];

// 对区间[a, b)同时加x
// k是节点的编号, 对应的区间是[l, r)
void add(int a, int b, int x, int k, int l, int r) {
    if (a <= l && r <= b) {
        // a.给这个节点对应的区间内的所有元素共同加上的值
        datA[k] += x;
    } else if (l < b && a < r) {
        // b.在这个节点对应的区间中除去a之外其他的值的和
        datB[k] += (min(b, r) - max(a, l)) * x;
        add(a, b, x, k * 2 + 1, l, (l + r) / 2);
        add(a, b, x, k * 2 + 2, (l + r) / 2, r );
    }
}

// 计算[a, b)的和
// k是节点的编号 , 对应的区间是[l, r)
ll sum(int a, int b, int k, int l, int r) {
    if (b <= l||r <= a) {
        return 0;
    } else if (a <= l && r <= b) {
        //a l r b 全部包含
        return datA[k] * (r - l) + datB[k];
    } else {
        // 交集区间上同时加上的值 + 子集区间上加上的值
        // 如果对于父亲节点同时加了一个值, 那么这个值就不会在儿子节点被重复考虑。 
        // 在递归计算和时再把这一部分的值加到结果里面就可以了。
        ll res = (min(b, r) - max(a, l)) * datA[k];
        res += sum(a, b, k * 2 + 1, l, (l + r) / 2);
        res += sum(a, b, k * 2 + 2, (l + r) / 2, r);
        return res;
    }
}

int main() {
    cin >> N >> Q;
    for(int i = 0; i < N; i++) {
        scanf("%lld", &a[i]);
        add(i, i + 1, a[i], 0, 0, N);
    }
    char c[2];
    int l, r, v;
    for(int i = 0; i < Q; i++) {
        // %s 它会读入一个不含空格、TAB和回车符的字符串,存入字符数组
        // %c 会读入\n
        scanf("%s%d%d", c, &l, &r);
        if(c[0] == 'C') {
            scanf("%d", &v);
            add(l - 1, r, v, 0, 0, N);
        } else {
            printf("%lld\n" , sum(l - 1, r, 0, 0, N));
        }
    }
}

2. 延迟标记

在区间上维持一个延迟标记

延迟标记的含义为“该节点曾经被修改, 但其子节点尚未被更新”

在更新或查询的时候向下传递延迟标记

点击查看代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int SIZE=100010;
struct SegmentTree{
	int l,r;
	long long sum,add;
	#define l(x) tree[x].l
	#define r(x) tree[x].r
	#define sum(x) tree[x].sum
	#define add(x) tree[x].add
}tree[SIZE*4];
int a[SIZE],n,m;

void build(int p,int l,int r)//No.p, [l,r]
{
	l(p)=l,r(p)=r;
	if(l==r) { sum(p)=a[l]; return; }
	int mid=(l+r)/2;
	build(p*2,l,mid);
	build(p*2+1,mid+1,r);
	sum(p)=sum(p*2)+sum(p*2+1);
}

void spread(int p)
{
	if(add(p))
	{
		sum(p*2)+=add(p)*(r(p*2)-l(p*2)+1);
		sum(p*2+1)+=add(p)*(r(p*2+1)-l(p*2+1)+1);
		add(p*2)+=add(p);
		add(p*2+1)+=add(p);
		add(p)=0;
	}
}

void change(int p,int l,int r,int z)
{
	if(l<=l(p)&&r>=r(p))
	{
		sum(p)+=(long long)z*(r(p)-l(p)+1);
		add(p)+=z;
		return;
	}
	spread(p);
	int mid=(l(p)+r(p))/2;
	if(l<=mid) change(p*2,l,r,z);
    if(r>mid) change(p*2+1,l,r,z);
	sum(p)=sum(p*2)+sum(p*2+1);
}

long long ask(int p,int l,int r)
{
	if(l<=l(p)&&r>=r(p)) return sum(p);
	spread(p);
    int mid=(l(p)+r(p))/2;
	long long ans=0;
    if(l<=mid) ans+=ask(p*2,l,r);
    if(r>mid) ans+=ask(p*2+1,l,r);
    return ans;
}

int main()
{
	cin>>n>>m;
	for(int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	build(1,1,n);
	while(m--)
	{
		char op[2]; int x,y,z;
		scanf("%s%d%d",op,&x,&y);
		if(op[0]=='C')
		{
			scanf("%d",&z);
			change(1,x,y,z);
		}
		else printf("%lld\n",ask(1,x,y));
	}
}

3. 扫描线

感觉这个扫描线 只是告诉我们可以利用线段树来解决问题,并且利用有些题的特性,不需要做延迟标记

O(n^2) 没有利用线段树

点击查看代码
#include<vector>
#include<map>
#include<algorithm>
#include<cstdio>
#include<cstring>
#define rep(i, a, n) for (auto i = a; i < (n); ++i)  // repeat
#define repe(i, a, n) for (auto i = a; i <= (n); ++i) // repeat and equal
#define revrep(i, a, n) for (auto i = n; i > (a); --i) // reverse repeat
#define revrepe(i, a, n) for (auto i = n; i >= (a); --i)
#define all(a) a.begin(), a.end()
#define sz(a) (int)(a.size());
#define mem(a,b) memset(a,b,sizeof(a))
#define lb(x) ((x) & -(x)) // lowbit
#define pb push_back
#define qb pop_back
#define pf push_front
#define qf pop_front
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<vi> vvi;

template<class T> inline bool chmax(T &a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T &a, T b) { if (b < a) { a = b; return 1; } return 0; }

const int MAXN = 1e5 + 5;
struct P {
    double x, y1, y2;
    int k;
    bool operator < (const P &p) const {
        return x < p.x;
    }
} a[MAXN << 1];
double raw[MAXN << 1];
map<double, int> val;
// 表示某个区间是否被覆盖
int c[MAXN << 1];

int N, ind;
int main() {
    while(scanf("%d", &N) && N != 0) {
        for(int i = 1; i <= N; i++) {
            int k = i << 1;
            double y1, y2;
            scanf("%lf%lf%lf%lf", &a[k - 1].x, &y1, &a[k].x, &y2); 
            raw[k - 1] = a[k - 1].y1 = a[k].y1 = y1;
            raw[k] = a[k - 1].y2 = a[k].y2 = y2;
            // .k 可以表示是左边还是右边 和这里k (i << 2)是不一样的1
            a[k - 1].k = 1;
            a[k].k = -1;
        }
        N <<= 1;
        // 离散化y
        sort(raw + 1, raw + N + 1);
        int m = unique(raw + 1, raw + N + 1) - (raw + 1);
        // 记录下y对应的位置
        for(int i = 1; i <= m; i++) val[raw[i]] = i;

        sort(a + 1, a + N + 1);
        memset(c, 0, sizeof(c));
        double result = 0;
        for(int i = 1; i < N; i++) {
            int l = val[a[i].y1], r = val[a[i].y2];
            for(int j = l; j < r; j++) c[j] += a[i].k;

            double len = 0;
            for(int j = 1; j < m; j++) 
                if(c[j] > 0) len += raw[j + 1] - raw[j];
                
            result += len * (a[i + 1].x - a[i].x);
        }
        printf("Test case #%d\nTotal explored area: %.2f\n\n",++ind, result);
    }
}

利用线段树 O(nlogn)

点击查看代码
#include<vector>
#include<map>
#include<algorithm>
#include<cstdio>
#include<cstring>
#define rep(i, a, n) for (auto i = a; i < (n); ++i)  // repeat
#define repe(i, a, n) for (auto i = a; i <= (n); ++i) // repeat and equal
#define revrep(i, a, n) for (auto i = n; i > (a); --i) // reverse repeat
#define revrepe(i, a, n) for (auto i = n; i >= (a); --i)
#define all(a) a.begin(), a.end()
#define sz(a) (int)(a.size());
#define mem(a,b) memset(a,b,sizeof(a))
#define lb(x) ((x) & -(x)) // lowbit
#define pb push_back
#define qb pop_back
#define pf push_front
#define qf pop_front
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<vi> vvi;

template<class T> inline bool chmax(T &a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T &a, T b) { if (b < a) { a = b; return 1; } return 0; }

const int MAXN = 1e5 + 5;
struct P {
    double x, y1, y2;
    int k;
    bool operator < (const P &p) const {
        return x < p.x;
    }
} a[MAXN << 1];
double raw[MAXN << 1];
map<double, int> val;
// 线段树
struct T {
  // cnt 表示节点自身被覆盖的次数
  int l, r, cnt;
  // len表示该节点代表的区间被矩形覆盖的长度
  double len;
} t[MAXN << 3];  // 线段树的大小是 需要为数组大小的四倍 数组大小是 MAXN << 1;
 
void build(int p, int l, int r) {
	t[p].l = l, t[p].r = r, t[p].cnt = 0, t[p].len = 0;
	if (l == r) return;
  
  int mid = (t[p].l + t[p].r) / 2;
  // 2 * p               2 * p + 1
	build(p << 1, l, mid), build((p << 1) | 1, mid + 1, r);
}

void change(int p, int l, int r, double k) {
  // left side
  int ls = p << 1;
  // right side
  int rs = ls | 1;
  int mid = (t[p].l + t[p].r) / 2;

	if (l <= t[p].l && r >= t[p].r) return t[p].len = ((t[p].cnt += k) ? raw[t[p].r+1] - raw[t[p].l] : (t[p].l == t[p].r ? 0 : t[ls].len + t[rs].len)), void();
	if (l <= mid) change(ls, l, r, k);
	if (r > mid) change(rs, l, r, k);
	t[p].len = (t[p].cnt ? raw[t[p].r+1] - raw[t[p].l] : t[ls].len + t[rs].len);
}

int N, ind;
int main() {
    while(scanf("%d", &N) && N != 0) {
        for(int i = 1; i <= N; i++) {
            int k = i << 1;
            double y1, y2;
            scanf("%lf%lf%lf%lf", &a[k - 1].x, &y1, &a[k].x, &y2); 
            raw[k - 1] = a[k - 1].y1 = a[k].y1 = y1;
            raw[k] = a[k - 1].y2 = a[k].y2 = y2;
            // .k 可以表示是左边还是右边 和这里k (i << 2)是不一样的1
            a[k - 1].k = 1;
            a[k].k = -1;
        }
        N <<= 1;
        // 离散化y
        sort(raw + 1, raw + N + 1);
        int m = unique(raw + 1, raw + N + 1) - (raw + 1);
        // 记录下y对应的位置
        for(int i = 1; i <= m; i++) val[raw[i]] = i;

        sort(a + 1, a + N + 1);
        build(1, 1, m - 1);
        double result = 0;
        for (int i = 1; i < N; i++) {
          int l = val[a[i].y1], r = val[a[i].y2] - 1;
          change(1, l, r, a[i].k);
          result += t[1].len * (a[i+1].x - a[i].x);
        }

        printf("Test case #%d\nTotal explored area: %.2f\n\n",++ind, result);
    }
}

4.动态开点与线段树合并

动态开点,在数据结构中不记录节点对应的区间了,而是在操作过程中作为参数进行传递

点击查看代码
// 动态开点的线段树
struct SegmentTree {
    int lc, rc; // 左右子节点的编号
	int dat;
} tr[SIZE * 2];
int root, tot;

int build() { // 新建一个节点
	tot++;
	tr[tot].lc = tr[tot].rc = tr[tot].dat = 0;
	return tot;
}

// 在main函数中
tot = 0;
root = build(); // 根节点

// 单点修改,在val位置加delta,维护区间最大值
void insert(int p, int l, int r, int val, int delta) {
    if (l == r) {
        tr[p].dat += delta;
        return;
    }
    int mid = (l + r) >> 1; // 代表的区间[l,r]作为递归参数传递
    if (val <= mid) {
        if (!tr[p].lc) tr[p].lc = build(); // 左子树不存在,动态开点
        insert(tr[p].lc, l, mid, val, delta);
    }
    else {
        if (!tr[p].rc) tr[p].rc = build(); // 右子树不存在,动态开点
        insert(tr[p].rc, mid + 1, r, val, delta);
    }
    tr[p].dat = max(tr[tr[p].lc].dat, tr[tr[p].rc].dat);
}

// 调用
insert(root, 1, n, val, delta);

// 合并两棵线段树
int merge(int p, int q, int l, int r) {
    if (!p) return q; // p,q之一为空
    if (!q) return p;
    if (l == r) { // 到达叶子
        tr[p].dat += tr[q].dat;
        return p;
    }
    int mid = (l + r) >> 1;
    tr[p].lc = merge(tr[p].lc, tr[q].lc, l, mid); // 递归合并左子树
    tr[p].rc = merge(tr[p].rc, tr[q].rc, mid + 1, r); // 递归合并右子树
    tr[p].dat = max(tr[tr[p].lc].dat, tr[tr[p].rc].dat); // 更新最值
    return p; // 以p为合并后的节点,相当于删除q
}

例题

POJ--3468 A Simple Problem with Integers(线段树/树状数组)

posted @ 2024-02-05 00:33  57one  阅读(9)  评论(0编辑  收藏  举报