仔细理解线段树和树状数组

一,树状数组

很简单的一种数据结构,用数字二进制的规律 lowbit = x&(x^(x-1))求出数字二进制末尾0的个数为2的幂次的值,这个数就管这么多个数,巧妙的用一个数代表一个区间的和

解决单值修改&区间查询,和区间修改&单值查询问题很赞

分为一维和二维,多维

HDU 2430 区间最值查询

#include <iostream>
using namespace std;
#define MAXN 50003
int C[2][MAXN], num[MAXN];
int Lowbit[MAXN];
int nCount;
void Init() {
    for (int i = 1; i <= nCount; i++) {
        C[0][i] = C[1][i] = num[i];
        for (int j = 1; j < Lowbit[i]; j <<= 1) {
            C[0][i] = min(C[0][i], C[0][i - j]);
            C[1][i] = max(C[1][i], C[1][i - j]);
        }
    }
}
int get_Min(int p) {
    int nMin = 0x3fffffff;
    while (p > 0) {
        nMin = min(C[0][p], nMin);
        p -= Lowbit[p];
    }
    return nMin;
}
void Modify(int p, int val, int f) {
    while (p <= nCount) {
        C[f][p] = min(C[f][p], val);
        p += Lowbit[p];
    }
}
int Query(int l, int r) {
    int ans[2] = { num[r], num[r] };
    while (1) {
        ans[0] = min(ans[0], num[r]);
        ans[1] = max(ans[1], num[r]);
        if (r == l)
            break;
        for (r--; r - l >= Lowbit[r]; r -= Lowbit[r]) {
            ans[0] = min(ans[0], C[0][r]);
            ans[1] = max(ans[1], C[1][r]);
        }
    }
    return ans[1] - ans[0];
}
int main() {
//    freopen("data4.txt", "r", stdin);
    int N, Q;
    for (int i = 1; i <= 50002; i++)
        Lowbit[i] = i & (i ^ (i - 1));
    while (~scanf("%d%d", &N, &Q)) {
        nCount = N;
        memset(C, 0, sizeof(C));
        for (int i = 1; i <= N; i++)
            scanf("%d", &num[i]);
        Init();
        int a, b;
        while (Q--) {
            scanf("%d%d", &a, &b);
            printf("%d\n", Query(a, b));
        }
    }
    return 0;
}

 

 

#include <iostream>
using namespace std;
#define lowbit(x) (x&((x)^(x-1)))
#define MAXN 1005

int c[MAXN][MAXN];
int m, n;
void init(int _m, int _n) {
    memset(c, 0, sizeof(c));
    m = _m, n = _n;
}
void update(int i, int j, int v) {
    for (; i <= m; i += lowbit(i))
        for (; j <= n; c[i][j] += v, j += lowbit(j))
            ;
}
int query(int i, int j) {
    int ret;
    for (ret = 0; i; i -= lowbit(i))
        for (; j; ret += c[i][j], j -= lowbit(j))
            ;
    return ret;
}
int main() {
//    freopen("data3.txt", "r", stdin);
    int T;
    cin >> T;
    char oper[2];
    int x1, y1, x2, y2;
    int N, t;
    while (T--) {
        scanf("%d%d", &N, &t);
        init(N, N);
        while (t--) {
            scanf("%s", oper);
            if (oper[0] == 'C') {
                scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
                update(x1, y1, 1);
                update(x1, y2 + 1, -1);
                update(x2 + 1, y1, -1);
                update(x2 + 1, y2 + 1, 1);
            } else {
                scanf("%d%d", &x1, &y1);
                printf("%d\n", 1 & (query(x1, y1)));
            }
        }
        if (T)
            puts("");
    }
    return 0;
}

 

关于区间最值,ST的RMQ方法

#include<iostream>
#include<cmath>
using namespace std;
#define MAXN 310
#define LOGMAXN int(log(1000000+1.0)/log(2.0))
int M1[MAXN][20]; // 20 = LOGMAXN 一维最大范围
int A[MAXN];
int N, M;
int A2[MAXN][MAXN];
int M2[MAXN][MAXN][9][9]; //9= LOGMAXN_ROW 9= LOGMAXN_column 二维最大范围

inline int MMIN(int a, int b) {//定义为求区间最小值
    if (a > b)
        return b;
    else
        return a;
}

// *********** 一维 *************//
void Init() {
    int i, j;
    for (i = 0; i < N; i++)
        M1[i][0] = A[i];
    for (j = 1; 1 << j <= N; j++) {
        for (i = 0; i + (1 << (j - 1)) - 1 < N; i++) {
                M1[i][j] = max( M1[i][j - 1] ,M1[i + (1 << (j - 1))][j - 1]);
        }
    }
}
int Modify(int low, int high) {//满足2^k<=(r-l+1) 的最大k
    int k = log(high - low + 1.0) / log(2.0);
    return max( M1[low][k] ,M1[high + 1 - (1 << k)][k]);
}

// ************** 二维 *****************//
void Init_2D() {
    int i, j, p, q;
    for (i = 0; i < N; i++)
        for (j = 0; j < M; j++)
            M2[i][j][0][0] = A2[i][j];
    for (p = 0; 1 << p <= N; p++) {
        for (q = 0; 1 << q <= M; q++) {
            if (p == 0 && q == 0)
                continue;
            for (i = 0; i + (1 << p) - 1 < N; i++) {
                for (j = 0; j + (1 << q) - 1 < M; j++) {
                    if (p == 0) {
                        M2[i][j][p][q] = MMIN(M2[i][j][p][q - 1], M2[i][j + (1
                                << (q - 1))][p][q - 1]);
                    } else {
                        M2[i][j][p][q] = MMIN(M2[i][j][p - 1][q], M2[i + (1
                                << (p - 1))][j][p - 1][q]);
                    }
                }
            }
        }
    }
}

int Modify(int low_x, int low_y, int high_x, int high_y) {
    int k_x = log(high_x - low_x + 1.0) / log(2.0);
    int k_y = log(high_y - low_y + 1.0) / log(2.0);
    return MMIN(MMIN(MMIN(M2[low_x][low_y][k_x][k_y], M2[high_x - (1 << k_x)
            + 1][high_y - (1 << k_y) + 1][k_x][k_y]), M2[low_x][high_y - (1
            << k_y) + 1][k_x][k_y]),
            M2[high_x - (1 << k_x) + 1][low_y][k_x][k_y]);
}

 

树状数组求第k小数 log(n) POJ 2985

#include<cstdio>
#include<cstring>
#include <iostream>
using namespace std;

#define MAXN 200009
#define lowbit(x) x&((x)^(x-1))
int num[MAXN], c[MAXN], father[MAXN];
int m, n;
void init() {
    for (int i = 1; i <= n; i++)
        father[i] = i, num[i] = 1;
}
int find(int x) {
    if (x == father[x])
        return x;
    return father[x] = find(father[x]);
}
void update(int x, int d) {
    for (; x <= n; x += lowbit(x))
        c[x] += d;
}
int find_kth(int k) {
    int ans = 0, cnt = 0, i;
    for (i = 20; i >= 0; i--) { //二进制枚举趋近第K小
        ans += (1 << i);
        if (ans >= n || cnt + c[ans] >= k)
            ans -= (1 << i);
        else
            cnt += c[ans]; //cnt用来累加比当前ans小的个数
    }
    return ans + 1;
}
int main() {
//    freopen("data3.txt", "r", stdin);
    int i, q, x, y, k;
    scanf("%d%d", &n, &m);
    init();
    update(1, n); //初始状态值为1的数有n个
    int cnt = n;
    for (i = 1; i <= m; i++) {
        scanf("%d", &q);
        if (!q) {
            scanf("%d%d", &x, &y);
            x = find(x);
            y = find(y);
            if (x == y)
                continue;
            update(num[x], -1);
            update(num[y], -1);
            update(num[x] + num[y], 1);
            father[y] = x;
            num[x] += num[y];
            cnt--; //合并集合
        } else {
            scanf("%d", &k);
            k = cnt - k + 1; //转换为找第k小的数
            printf("%d\n", find_kth(k));
        }
    }
    return 0;
}

 

 

二,线段树

作为ACM里比较爱出难题的数据结构,只能说多练来深入理解巧妙运用吧

应用一 :点的区间最值查询(单点更新)

struct Segtree {
    LL mx[N * 4];
    void PushUp(int ind) {
        mx[ind] = max(mx[LL(ind)], mx[RR(ind)]);
    }
    void build(int lft, int rht, int ind) {
        mx[ind] = -INF;
        if (lft != rht) {
            int mid = MID(lft,rht);
            build(lft, mid, LL(ind));
            build(mid + 1, rht, RR(ind));
        }
    }
    void updata(int pos, LL valu, int lft, int rht, int ind) {
        if (lft == rht)
            mx[ind] = max(mx[ind], valu);
        else {
            int mid = MID(lft,rht);
            if (pos <= mid)
                updata(pos, valu, lft, mid, LL(ind));
            else
                updata(pos, valu, mid + 1, rht, RR(ind));
            PushUp(ind);
        }
    }
    LL query(int st, int ed, int lft, int rht, int ind) {
        if (st <= lft && rht <= ed)
            return mx[ind];
        else {
            LL mx1 = -INF, mx2 = -INF;
            int mid = MID(lft,rht);
            if (st <= mid)
                mx1 = query(st, ed, lft, mid, LL(ind));
            if (ed > mid)
                mx2 = query(st, ed, mid + 1, rht, RR(ind));
            return max(mx1, mx2);
        }
    }
} seg;

 

应用二 :区段加,区段乘,区段变另一值

成段更新需要用到延迟标记(或者说懒惰标记),简单的说就是每次更新的时候不要更新到底,用延迟标记使得更新延迟到下次需要更新或者询问到的时候再传递下去。

 

 

/*
 * HDU 4614 统计区间个数(1 or 0),则把+=全改为=即为全更改为add
 * 线段树从 1开始,区段加,区段乘,区段变另一值
 * 区间增加,区间查询为线段树;区间增加,单值更新也可为树状数组
 */

#include<iostream>
#define MAXN 100001
#define LL long long
using namespace std;

struct Node {
    int left, right;
    int add;LL sum;
} tree[MAXN * 4];
//LL a[MAXN]; //叶子节点初始值

void push_down(int num) {
    if (tree[num].add != -1) {
        tree[num << 1].sum += (LL) tree[num].add
                * (tree[num * 2].right - tree[num * 2].left + 1);
        tree[num << 1 | 1].sum += (LL) tree[num].add
                * (tree[num * 2 + 1].right - tree[num * 2 + 1].left + 1);
        tree[num << 1].add += tree[num].add;
        tree[num << 1 | 1].add += tree[num].add;
        tree[num].add = -1;
    }
}

void push_up(int num) {
    tree[num].sum += (LL) tree[num * 2].sum + tree[num * 2 + 1].sum;
}

void build(int s, int t, int num) {
    tree[num].left = s;
    tree[num].right = t;
    tree[num].add = -1;
    /*
     if(s==t)
     tree[num].sum=a[s];
     */
    if (s != t) {
        build(s, (s + t) / 2, num << 1);
        build((s + t) / 2 + 1, t, num << 1 | 1);
        //push_up(num);
    }
}

void update(int s, int t, int add, int num) {

    if (s <= tree[num].left && tree[num].right <= t) {
        tree[num].add += add;
        tree[num].sum += add * (tree[num].right - tree[num].left + 1);
        return;
    }
    push_down(num);
    if (s <= (tree[num].left + tree[num].right) / 2)
        update(s, t, add, num * 2);
    if (t > (tree[num].left + tree[num].right) / 2)
        update(s, t, add, num * 2 + 1);
    push_up(num);

}

LL query(int s, int t, int num) {
    LL i = 0, j = 0;
    if (s <= tree[num].left && tree[num].right <= t) {
        return (tree[num].sum);
    }
    push_down(num);
    if (s <= (tree[num].left + tree[num].right) / 2) {
        i = query(s, t, num * 2);
    }
    if (t > (tree[num].left + tree[num].right) / 2) {
        j = query(s, t, num * 2 + 1);
    }
    return (i + j);
}

int main() {
    //freopen("data.txt","r",stdin);
    int n, m, i, T;
    cin >> T;
    while (T--) {
        scanf("%d%d", &n, &m);
        int a, b, c;
        build(1, n, 1);
        for (i = 0; i < m; i++) {
            scanf("%d%d%d", &a, &b, &c);
            if (a == 1) {
                int low = b + 1, high = n, mid; //b从0开始,线段树操作需要+1
                int Max, Min, tmp;
                if (n - low + 1 - query(low, n, 1) == 0) {
                    printf("Can not put any one.\n");
                    continue;
                }
                while (low < high) {
                    mid = (low + high) >> 1;
                    tmp = mid - (b + 1) + 1 - query(b + 1, mid, 1);
                    if (tmp == 0) {
                        low = mid + 1;
                    } else
                        high = mid;
                }
                Min = low;
                tmp = n - Min + 1 - query(Min, n, 1);
                if (tmp < c)
                    c = tmp;
                low = Min, high = n;
                while (low < high) {
                    mid = (low + high) >> 1;
                    tmp = mid - Min + 1 - query(Min, mid, 1);
                    if (tmp >= c) {
                        high = mid;
                    } else {
                        low = mid + 1;
                    }
                }
                Max = low;
                printf("%d %d\n", Min - 1, Max - 1);
                update(Min, Max, 1, 1);
            } else {
                printf("%d\n", query(b + 1, c + 1, 1));
                update(b + 1, c + 1, 0, 1); //区间全部更新为0
            }
        }
        printf("\n");
    }
    return 0;
}

附一个简洁的写法

#define LL(x) (x<<1)
#define RR(x) (x<<1|1)
#define MID(a,b) (a+((b-a)>>1))

struct Segtree {
    int mx[N * 4], delay[N * 4];
    void fun(int ind, int valu) {
        mx[ind] += valu;
        delay[ind] += valu;
    }
    void PushDown(int ind) {
        if (delay[ind]) {
            fun(LL(ind), delay[ind]);
            fun(RR(ind), delay[ind]);
            delay[ind] = 0;
        }
    }
    void PushUp(int ind) {
        mx[ind] = max(mx[LL(ind)], mx[RR(ind)]);
    }
    void build(int lft, int rht, int ind) {
        mx[ind] = delay[ind] = 0;
        if (lft != rht) {
            int mid = MID(lft,rht);
            build(lft, mid, LL(ind));
            build(mid + 1, rht, RR(ind));
        }
    }
    void updata(int st, int ed, int valu, int lft, int rht, int ind) {
        if (st <= lft && rht <= ed)
            fun(ind, valu);
        else {
            PushDown(ind);
            int mid = MID(lft,rht);
            if (st <= mid)
                updata(st, ed, valu, lft, mid, LL(ind));
            if (ed > mid)
                updata(st, ed, valu, mid + 1, rht, RR(ind));
            PushUp(ind);
        }
    }
} seg;

 

应用三 :区间合并

学习中。。这样的问题一般都是问你区间中满足条件最长的序列,你关键需要知道怎样操作对线段树左右儿子进行合并。

 

一般都要定义三个数组lm(定义从区间左边第一个点开始的满足条件的序列), rm(定义以区间右边最后一个点结束的满足条件的序列), sm(定义整个区间满足条件的最长序列)。

POJ3667

#include <cstdio>
#include <cmath>
#include <iostream>
#include <algorithm>
using namespace std;

#define LEFT 2*u,l,mid
#define RIGHT 2*u+1,mid+1,r
const int maxn = 50005;
int flag[4 * maxn];   ///标记

struct node {
    int lm; ///从左边第一个点开始最长的连续空hotel
    int rm; ///以右边最后一个结束的最长的连续空hotel
    int sm; ///整段区间最大的连续空hotel
} tree[4 * maxn];

void push_up(int u, int l, int r) {
    tree[u].lm = tree[2 * u].lm;
    tree[u].rm = tree[2 * u + 1].rm;
    tree[u].sm = max(tree[2 * u].sm, tree[2 * u + 1].sm);
    int mid = (l + r) >> 1;
    if (tree[2 * u].lm == mid - l + 1)
        tree[u].lm += tree[2 * u + 1].lm; ///!!这里注意,当左孩子左边连续的达到整个区间时,要加上右孩子的左边区间
    if (tree[2 * u + 1].rm == r - mid)
        tree[u].rm += tree[2 * u].rm;   ///!!考虑右区间,同上
    int t = tree[2 * u].rm + tree[2 * u + 1].lm;
    if (t > tree[u].sm)
        tree[u].sm = t;
}

void push_down(int u, int l, int r) {
    if (flag[u] == -1)
        return;
    if (flag[u]) {
        flag[2 * u] = flag[2 * u + 1] = flag[u];
        tree[2 * u].lm = tree[2 * u].rm = tree[2 * u].sm = 0;
        tree[2 * u + 1].lm = tree[2 * u + 1].rm = tree[2 * u + 1].sm = 0;
        flag[u] = -1;
    } else {
        flag[2 * u] = flag[2 * u + 1] = flag[u];
        int mid = (l + r) >> 1;
        tree[2 * u].lm = tree[2 * u].rm = tree[2 * u].sm = mid - l + 1;
        tree[2 * u + 1].lm = tree[2 * u + 1].rm = tree[2 * u + 1].sm = r - mid;
        flag[u] = -1;
    }
}

void build(int u, int l, int r) {
    flag[u] = -1;
    if (l == r) {
        tree[u].lm = tree[u].rm = tree[u].sm = 1;
        return;
    }
    int mid = (l + r) >> 1;
    build(LEFT);
    build(RIGHT);
    push_up(u, l, r);
}

void Update(int u, int l, int r, int tl, int tr, int c) {
    if (tl <= l && r <= tr) {
        tree[u].sm = tree[u].lm = tree[u].rm = (c == 1 ? 0 : r - l + 1);
        flag[u] = c;
        return;
    }
    push_down(u, l, r);   ///再次遇见此段区间时,延迟标记同步向下更新
    int mid = (l + r) >> 1;
    if (tr <= mid)
        Update(LEFT, tl, tr, c);
    else if (tl > mid)
        Update(RIGHT, tl, tr, c);
    else {
        Update(LEFT, tl, mid, c);    ///注意区间分隔开,tl,tr跨越两个左右区间
        Update(RIGHT, mid + 1, tr, c);
    }
    push_up(u, l, r);     ///递归的时候同步向上更新
}

int Query(int u, int l, int r, int num) {
    if (l == r)
        return l;
    push_down(u, l, r);     ///延迟标记向下传递
    int mid = (l + r) >> 1;
    if (tree[2 * u].sm >= num)
        return Query(LEFT, num);
    else if (tree[2 * u].rm + tree[2 * u + 1].lm >= num && tree[2 * u].rm >= 1)
        return mid - tree[2 * u].rm + 1;   ///满足条件时,返回左边rm连续的hotel第一个房间标号
    else
        return Query(RIGHT, num);
}

int main() {
//    freopen("data3.txt", "r", stdin);
    int n, m;
    while (~scanf("%d%d", &n, &m)) {
        build(1, 1, n);
        while (m--) {
            int p, u, v;
            scanf("%d", &p);
            if (p == 1) {
                scanf("%d", &u);
                if (tree[1].sm < u)  ///特判一下是否有这么多个连续的空hotel,没有则直接输出,不用操作
                        {
                    puts("0");
                    continue;
                }
                int p = Query(1, 1, n, u);
                printf("%d\n", p);
                Update(1, 1, n, p, p + u - 1, 1);
            } else {
                scanf("%d%d", &u, &v);
                Update(1, 1, n, u, u + v - 1, 0);
            }
        }
    }
    return 0;
}

 

 

 

应用四 :扫描线。

配合离散化求些面积和,面积最值类的问题。

把每个矩阵的上下边先存储起来,每条边在我们意想看来都是一条扫描线,一条接一条的从下往上扫描,每次扫描到的区间覆盖到的区间cover值会发生变化,但是n条扫描线只在总区间上变化。

那多矩形面积和举例,我按x轴记录所有的与y轴平行的线段的y1,y2。然后对y轴所有点离散化,建立线段树,来维护每个区间被覆盖的次数。接下来就是一条一条的扫描与y轴平行的线段,二分找到对应离散化后的区间,更新对应区间的覆盖次数c 和 区间实际覆盖长度,主要在push_up里。

 

// POJ 1151 矩形面积合并,区间插入和同区间删除不向下更新
//  POJ 2482
//void push_up(int root) {
//    tree[root].c = tree[root].m + max(tree[root << 1].c, tree[root << 1 | 1].c);
//}
#include<iostream>
#include<algorithm>
using namespace std;
#define N 210
struct node {
    int st, ed;
    int c;
    double lf, rf;
    double m;
} tree[N * 4];
struct Line {
    double x, y1, y2;
    bool operator <(const Line &tmp) const {
        return x < tmp.x;
    }
    bool operator ==(const Line &tmp) const {
        return x == tmp.x;
    }
    int f;
} line[N];
double y[N];
void build(int root, int l, int r) {
    tree[root].st = l;
    tree[root].ed = r;
    tree[root].c = 0;
    tree[root].m = 0;
    tree[root].lf = y[l];
    tree[root].rf = y[r];
    if (l + 1 < r) {
        int mid = (l + r) >> 1;
        build(root << 1, l, mid);
        build(root << 1 | 1, mid, r);
    }
}
void push_up(int root) {
    if (tree[root].c > 0)
        tree[root].m = tree[root].rf - tree[root].lf;
    else {
        if (tree[root].st + 1 == tree[root].ed)
            tree[root].m = 0;
        else
            tree[root].m = tree[root << 1].m + tree[root << 1 | 1].m;
    }
}
void update(int root, int st, int ed, int add) {
    if (st <= tree[root].st && tree[root].ed <= ed) {
        tree[root].c += add;
        push_up(root);
        return;
    }
    int mid = (tree[root].ed + tree[root].st) >> 1;
    if (st < mid)
        update(root << 1, st, ed, add);
    if (ed > mid)
        update(root << 1 | 1, st, ed, add);
    push_up(root);
}
int lower_bound(double *arr, int lhs, int rhs, double value) {
    int m;
    while (lhs < rhs) {
        m = lhs + (rhs - lhs) / 2;
        if (arr[m] >= value)
            rhs = m;
        else
            lhs = m + 1;
    }
    return lhs;
}
int main() {
//    freopen("data4.txt", "r", stdin);
    double x1, y1, x2, y2;
    int n, cas = 1;
    while (scanf("%d", &n) && n) {
        for (int i = 0; i < n; i++) {
            scanf("%lf%lf%lf%lf", &x1, &y1, &x2, &y2);
            line[i * 2].f = 1;
            line[i * 2].x = x1;
            line[i * 2].y1 = y1;
            line[i * 2].y2 = y2;

            line[i * 2 + 1].f = -1;
            line[i * 2 + 1].x = x2;
            line[i * 2 + 1].y1 = y1;
            line[i * 2 + 1].y2 = y2;

            y[i * 2] = y1;
            y[i * 2 + 1] = y2;
        }
        n <<= 1;
        sort(line, line + n);
        sort(y, y + n);
        int num = unique(y, y + n) - y;
        build(1, 0, num - 1);
        double area = 0;
        for (int i = 0; i < n - 1; i++) {
            int l = lower_bound(y, 0, num - 1, line[i].y1);
            int r = lower_bound(y, 0, num - 1, line[i].y2);
            update(1, l, r, line[i].f);
            area += tree[1].m * (line[i + 1].x - line[i].x);
        }
        printf("Test case #%d\n", cas++);
        printf("Total explored area: %.2lf\n\n", area);
    }
    return 0;
}

 

http://www.cnblogs.com/kane0526/archive/2013/03/11/2952952.html

posted @ 2013-11-06 18:27  匡时@下一站.info  阅读(410)  评论(0编辑  收藏  举报