W
H
X

线段树例题解析(合集)

目前有8道题,会持续更新(修正:已经停更啦)...

难度大致从易到难,(虽然都不难)

英文题面的题有题意简述 (提示:HDU的题目注意多组数据!不然被坑惨)

POJ3321 Apple Tree

题意:给定一棵树,有两种操作:1、把某个节点上的数^1(若是1改为0,是0改为1) 2、查询以某个节点为根节点的子树中1的个数

在一棵树上进行单点修改和查询,只要进行一遍dfs,记录每个点的dfs序即可把问题转化到链上用线段树进行维护

更改和模板一样,查询时以每棵子树遍历时第一个节点的dfs序和最后一个节点的边界作为查询区间(每颗子树中节点的DFS序是连续的)

#include <stdio.h>
#include <iostream>
using namespace std;
inline void read (int &x) {
    char ch = getchar(); x = 0;
     while (!isdigit(ch)) ch = getchar();
     while (isdigit(ch)) x = x * 10 + ch - 48, ch = getchar();
}
void print (int x) {
    if (x > 9) print (x / 10);
    putchar (x % 10 + 48);
}
const int N = 1e6 + 10;
int n, cnt, tot, m, pos, ql, qr, s[N << 2], num[N], l[N], r[N], nxt[N << 1], to[N << 1], h[N];
inline void add (int u, int v) {
    to[++tot] = v, nxt[tot] = h[u], h[u] = tot;
}
void dfs (int u, int la) {
    num[u] = l[u] = ++ cnt;
    for (int i = h[u]; i; i = nxt[i]) if (to[i] != la) dfs (to[i], u);
    r[u] = cnt;  //l、r记录每颗子树中最小和最大的dfs序
}
#define ls p << 1
#define rs p << 1 | 1
void build (int p, int l, int r) {
    if (l == r) {s[p] = 1; return;}
    int mid (l + r >> 1);
    build (ls, l, mid), build (rs, mid + 1, r);
    s[p] = s[ls] + s[rs];
}
void change (int p, int l, int r) {
    if (l == r) {s[p] ^= 1; return;} //把0改为1,把1改为0,相当于当前数字^1
    int mid (l + r >> 1);
    pos <= mid ? change (ls, l, mid) : change (rs, mid + 1, r);
    s[p] = s[ls] + s[rs];
}
int query (int p, int l, int r) {
    if (ql <= l && qr >= r) return s[p];
    int mid (l + r >> 1), t (0);
    if (ql <= mid) t += query (ls, l, mid);
    if (qr > mid) t += query (rs, mid + 1, r);
    return t;
}
int main() {
    read (n);
    for (int i = 1, u, v; i < n; ++i)
        read (u), read (v), add (u, v), add (v, u);
    dfs (1, 0); build (1, 1, n);
    read (m);
    for (int i = 1; i <= m; ++i) {
        char ch = getchar(); getchar();
        if (ch == 'C') read (pos), pos = num[pos], change (1, 1, n);
        else read (pos), ql = l[pos], qr = r[pos], print (query (1, 1, n)), puts ("");
    }
    return 0;
}

CF920F SUM and REPLACE

和模板的不同之处在于修改时是改为每个数的约数个数,不难发现,当一个数x<=2时,x的约数个数与本身相等,修改多少次多不会在改变

先预处理出每个数的约数个数,用线段树维护区间最大值,若<=2,则直接结束递归

对于>2的数都要暴力修改,但由于每个数的约数个数下降很快,几次后便降到<=2,所以复杂度优秀(大约是nlogn?)

这题与CF438D The Child and Sequence(区间取模),类似,也可以维护区间最大值

CF920F:

#include <bits/stdc++.h>
using namespace std;
#define rg register
#define ll long long
inline void read (int &x) {
	char ch = getchar(); x = 0;
	while (!isdigit(ch)) ch = getchar();
	while (isdigit(ch))  x = x * 10 + ch - 48, ch = getchar();
}
void print (ll x) {
	if (x > 9) print (x / 10);
	putchar (x % 10 + 48);
}
const int N = 3e5 + 10, M = 1e6;
int n, m, opt, ql, qr, cnt, maxn, p[N], a[N], k[M + 10], num[M + 10], c[N << 2];
ll s[N << 2];
inline void pre_work () {
    num[1] = 1;
    for (rg int i = 2; i <= maxn; ++i) {
        if (!k[i]) p[++cnt] = i, num[i] = 2;
        for (rg int j = 1; j <= cnt && p[j] * i <= maxn; ++j) {
            int tmp (p[j] * i), t (0);
            while (tmp % p[j] == 0) ++t, tmp /= p[j];
            k[p[j] * i] = 1, num[p[j] * i] = num[tmp] * (t + 1);
            if (i % p[j] == 0) continue;
        }
    }
}
#define ls p << 1
#define rs p << 1 | 1
inline int Max (int a, int b) {return a > b ? a : b;}
inline void push_up (int p) {
	s[p] = s[ls] + s[rs], c[p] = Max (c[ls], c[rs]);
}
void build (int p, int l, int r) {
	if (l == r) {c[p] = s[p] = a[l]; return;}
	int mid = l + r >> 1;
	build (ls, l, mid), build (rs, mid + 1, r);
	push_up (p);
}
void update (int p, int l, int r) {
	if (c[p] <= 2) return;
	if (l == r) {c[p] = s[p] = num[c[p]]; return;}
	int mid = l + r >> 1;
	if (ql <= mid) update (ls, l, mid);
	if (qr > mid) update (rs, mid + 1, r);
	push_up (p);
}
ll ask (int p, int l, int r) {
	if (ql <= l and qr >= r) return s[p];
	ll s (0);  int mid = l + r >> 1;
	if (ql <= mid) s += ask (ls, l, mid);
	if (qr > mid) s += ask (rs, mid + 1, r);
	return s;
}
int main() {
	read (n), read (m);
	for (rg int i = 1; i <= n; ++i) read (a[i]), maxn = Max (maxn, a[i]);
	pre_work (), build (1, 1, n);
	for (rg int i = 1; i <= m; ++i) {
		read (opt);
		if (opt == 1) read (ql), read (qr), update (1, 1, n);
		else read (ql), read (qr), print (ask (1, 1, n)), puts ("");
	}
	return 0;
}

CF438D:

#include <bits/stdc++.h>
using namespace std;
#define rg register
#define ll long long
inline void read (int &x) {
	char ch = getchar(); x = 0;
	while (!isdigit(ch)) ch = getchar();
	while (isdigit(ch))  x = x * 10 + ch - 48, ch = getchar();
}
void print (ll x) {
	if (x > 9) print (x / 10);
	putchar (x % 10 + 48);
}
const int N = 1e5 + 10;
int n, m, opt, ql, qr, mod, val, pos, a[N], c[N << 2];
ll s[N << 2];
#define ls p << 1
#define rs p << 1 | 1
inline int Max (int a, int b) {return a > b ? a : b;}
inline void push_up (int p) {
	s[p] = s[ls] + s[rs], c[p] = Max (c[ls], c[rs]);
}
void build (int p, int l, int r) {
	if (l == r) {c[p] = s[p] = a[l]; return;}
	int mid = l + r >> 1;
	build (ls, l, mid), build (rs, mid + 1, r);
	push_up (p);
}
void update (int p, int l, int r) {
	if (c[p] < mod) return;
	if (l == r) {c[p] %= mod, s[p] %= mod; return;}
	int mid = l + r >> 1;
	if (ql <= mid) update (ls, l, mid);
	if (qr > mid) update (rs, mid + 1, r);
	push_up (p);
}
void change (int p, int l, int r) {
	if (l == r) {c[p] = s[p] = val; return;}
	int mid = l + r >> 1;
	(pos <= mid) ? change (ls, l, mid) : change (rs, mid + 1, r);
	push_up (p);
}
ll ask (int p, int l, int r) {
	if (ql <= l and qr >= r) return s[p];
	ll s (0);  int mid = l + r >> 1;
	if (ql <= mid) s += ask (ls, l, mid);
	if (qr > mid) s += ask (rs, mid + 1, r);
	return s;
}
int main() {
	read (n), read (m);
	for (rg int i = 1; i <= n; ++i) read (a[i]);
	build (1, 1, n);
	for (rg int i = 1; i <= m; ++i) {
		read (opt);
		if (opt == 1) {
			read (ql), read (qr);
			print (ask (1, 1, n)), puts ("");
		}
		else if (opt == 2) {
			read (ql), read (qr), read (mod);
			update (1, 1, n);
		}
		else {
			read (pos), read (val);
			change (1, 1, n);
		}
	}
	return 0;
}

HDU2795 Billboard

题意:有一个h行w列的矩形,在里面横放m条大小为1*l[i]的小长方形,不能重叠,如果能放得下,输出能放下的最小行数,放不下输出-1

由于只有m个长方形,最多只需要m行(h范围很大),把h对m取min

然后维护每行剩下的值的区间最大值,查询时若左子树代表区间内的最大值>需要的长度,向左子树递归,否则考虑右子树,若长度都不够,答案为-1

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10;
inline void read (int &x) {
    char ch = getchar(); x = 0;
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch))  x = x * 10 + ch - 48, ch = getchar();
}
inline int print (int x) {
    if (x < 0) putchar ('-'), x = -x;
    if (x > 9) print (x / 10);
    putchar (x % 10 + 48);
}
int h, w, n, val, c[N << 2];
#define ls p << 1
#define rs p << 1 | 1
inline int Max (int a, int b) {return a > b ? a : b;}
int query (int p, int l, int r, int val) {
    if (l == r) {
        if (c[p] >= val) {c[p] -= val; return l;}
        else return -1;
    }
    int mid (l + r >> 1), ans (-1);
    if (c[ls] >= val) ans = query (ls, l, mid, val);
    else if (c[rs] >= val) ans = query (rs, mid + 1, r, val);
    c[p] = Max (c[ls], c[rs]);
    return ans;
}
int main() {
    while (~scanf ("%d %d %d", &h, &w, &n)) {
        if (n < h) h = n;
        for (int i = 1; i <= (h << 2); ++i) c[i] = w;
        for (int i = 1; i <= n; ++i) {
            read (val);
            if (val > w) puts ("-1");
            else print (query (1, 1, h, val)), puts ("");
        }
    }
    return 0;
}

POJ1151 Atlantis

扫描线的模板题,求若干的矩形的覆盖面积

把每个矩形的左右边取出(把上下边取出竖着做也一样,这里讲横着做的方法),得到2n条线段,将线段按横坐标坐标排序

对纵坐标离散化,所有纵坐标的值可以最多只有2n种,只需考虑这些纵坐标组成的线段即可

用线段数维护区间内线段被覆盖的次数和区间中被覆盖的长度,当前被覆盖的长度=被覆盖一次以上的线段之和

所有的修改都成对出现(两条边的纵坐标边界一样),在这个问题中不需要下传懒标记(可以弄组数据手算一下,发现确实不需要下传)、

遍历所有线段,每次统计一次答案,加上 当前被覆盖的总长度*两条线段的纵坐标之差

另外这里的向上统计(push_up函数)比较特别

#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
const int N = 410;
int n, m, T, c[N << 2];
double ans, xa[N], ya[N], xb[N], yb[N], ty[N << 1], len[N << 2];
struct e {
    double x;
    int ya, yb, val;
    bool operator <(const e &t) const {return x < t.x;}
} q[N << 1];
inline int find (double val) {
    int l (1), r (m), mid;
    while (l <= r) {
        mid = l + r >> 1;
        if (ty[mid] == val) return mid;
        if (ty[mid] < val) l = mid + 1;
        else r = mid - 1;
    }
}
#define ls p << 1
#define rs p << 1 | 1
void push_up (int p, int l, int r) {
    if (c[p]) len[p] = ty[r + 1] - ty[l];
    else if (l == r) len[p] = 0;
    else len[p] = len[ls] + len[rs];
}
void update (int p, int l, int r, int ql, int qr, int v) {
    if (ql <= l && qr >= r) {
        c[p] += v, push_up (p, l, r); return;
    }
    int mid (l + r >> 1);
    if (ql <= mid) update (ls, l, mid, ql, qr, v);
    if (qr > mid) update (rs, mid + 1, r, ql, qr, v);
    push_up (p, l, r);
}
int main() {
    while (1) {
        scanf ("%d", &n);
        if (!n) break;
        int t = n << 1;
        ans = 0; m = 1;
        memset (c, 0, sizeof (c));
        memset (len, 0, sizeof (len));
        if (!n) break;
        for (int i = 1; i <= n; ++i) scanf ("%lf %lf %lf %lf", xa + i, ya + i, xb + i, yb + i);
        for (int i = 1; i <= n; ++i) q[i].x = xa[i], q[n + i].x = xb[i], q[i].val = 1, q[n + i].val = -1;
        for (int i = 1; i <= n; ++i) ty[i] = ya[i], ty[n + i] = yb[i];
        sort (ty + 1, ty + t + 1);
        for (int i = 2; i <= t; ++i) if (ty[i] != ty[i - 1]) ty[++m] = ty[i];
        for (int i = 1; i <= n; ++i)
          q[i].ya = q[n + i].ya = find (ya[i]), q[i].yb = q[n + i].yb = find (yb[i]);
        sort (q + 1, q + t + 1);
        for (int i = 1; i < t; ++i) {
            update (1, 1, m, q[i].ya, q[i].yb - 1, q[i].val);
            ans += len[1] * (q[i + 1].x - q[i].x);
        }
        printf("Test case #%d\nTotal explored area: %.2lf\n\n", ++T, ans);
    }
    return 0;
}

HDU1225 覆盖的面积

扫描线的升级版问题,和扫描线模板不同的是,这题只有覆盖2次以上的面积才能记入总和,让问题复杂了一些

但不同的只有统计的步骤和维护的信息需要稍稍改变,这里需要维护被覆盖>=1次的长度和被覆盖>=2次的长度(因为覆盖>=2次的长度一开始由覆盖1次的得来)

这题的push_up函数更为奇特,需要分多类进行讨论

#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
const int N = 2010;
int t, n, m, c[N << 2];
double xl, xr, yl[N], yr[N], res, s[N], len1[N << 2], len2[N << 2];
struct e {
    double x;
    int l, r, val;
    bool operator <(const e &t) const {return x < t.x;}
} q[N];
inline int find (double x) {
    int l (1), r (m), mid;
    while (l <= r) {
        mid = l + r >> 1;
        if (s[mid] == x) return mid;
        if (s[mid] < x) l = mid + 1;
        else r = mid - 1;
    }
}
#define ls p << 1
#define rs p << 1 | 1
inline void push_up (int p, int l, int r) {
    if (c[p]) len1[p] = s[r + 1] - s[l];
    else if (l == r) len1[p] = 0;
    else len1[p] = len1[ls] + len1[rs];
    if (c[p] > 1) len2[p] = s[r + 1] - s[l];
    else if (l == r) len2[p] = 0;
    else if (!c[p]) len2[p] = len2[ls] + len2[rs];
    else len2[p] = len1[ls] + len1[rs];
}
void update (int p, int l, int r, int ql, int qr, int val) {
    if (ql <= l && qr >= r) {
        c[p] += val; push_up (p, l, r); return;
    }
    int mid (l + r >> 1);
    if (ql <= mid) update (ls, l, mid, ql, qr, val);
    if (qr > mid) update (rs, mid + 1, r, ql, qr, val);
    push_up (p, l, r);
}
int main() {
    scanf ("%d", &t);
    while (t--) {
        res = 0, m = 1;
        scanf ("%d", &n);
        memset (c, 0, sizeof (c));
        memset (len1, 0, sizeof (len1));
        memset (len2, 0, sizeof (len2));
        for (int i = 1; i <= n; ++i) {
            scanf ("%lf %lf %lf %lf", &xl, &yl[i], &xr, &yr[i]);
            s[i] = yl[i], s[n + i] = yr[i];
            q[i].x = xl, q[n + i].x = xr, q[i].val = 1, q[n + i].val = -1;
        }
        int t = n << 1;
        sort (s + 1, s + t + 1);
        for (int i = 2; i <= t; ++i)
          if (s[i] != s[i - 1]) s[++m] = s[i];
        for (int i = 1; i <= n; ++i)
          q[i].l = q[n + i].l = find (yl[i]), q[i].r = q[n + i].r = find (yr[i]);
        sort (q + 1, q + t + 1);
        for (int i = 1; i < t; ++i) {
            update (1, 1, m - 1, q[i].l, q[i].r - 1, q[i].val);
            res += len2[1] * (q[i + 1].x - q[i].x);
        }
        printf ("%.2lf\n", res);
    }
    return 0;
}

POJ2892 Tunnel Warfare

题意:有一条相邻节点相连的链(1和2,2和3,...n-1和n相连),有3种操作:1、破坏某一个节点 2、问从某个点可以到的点有多少个,即中间没有被破坏的点(包括自己) 3、重建当前最后一个被破坏的点

建立线段树维护每个点能到达的左右边界(初始值都为1,n),破坏一个点k时(设k能到达的左右边界为l1,r1),将l1到k-1的区间内的点的右边界修改为k,将k+1到r1的区间内的点的左边界修改为k

用一个栈存储点的破坏顺序,重建一个点时,取出栈顶p(设p能到达的左右边界为l2,r2),将l2到k-1的区间内的点的右边界修改为r2,将k+1到r2的区间内的点的左边界修改为l2

查询时直接用左右边界相减再+1即可

(蒟蒻的代码可能不具有很强的可读性,但主体部分还是比较清楚的)

#include <stdio.h>
#include <iostream>
using namespace std;
inline void read (int &x) {
    char ch = getchar(); x = 0;
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) x = x * 10 + ch - 48, ch = getchar();
}
void print (int x) {
    if (x > 9) print (x / 10);
    putchar (x % 10 + 48);
}
const int N = 5e4 + 10;
int n, m, x, top, st[N], ll[N], rr[N], k[N], cl[N << 3], cr[N << 3];
char ch[5];
#define ls p << 1
#define rs p << 1 | 1
inline void push_down (int p) {
    if (cl[p]) cl[ls] = cl[rs] = cl[p], cl[p] = 0;
    if (cr[p]) cr[ls] = cr[rs] = cr[p], cr[p] = 0;
}
void build (int p, int l, int r) {
    if (l == r) {cl[l] = 1, cr[l] = n; return;}
    int mid (l + r >> 1);
    build (ls, l, mid), build (rs, mid + 1, r);
}
void update1 (int p, int l, int r, int ql, int qr, int v) {      //修改右边界
    if (ql <= l && qr >= r) {cr[p] = v; return;}
    if (l == r) {rr[l] = cr[p]; return;}
    push_down (p);
    int mid (l + r >> 1);
    if (ql <= mid) update1 (ls, l, mid, ql, qr, v);
    if (qr > mid) update1 (rs, mid + 1, r, ql, qr, v);
}
void update2  (int p, int l, int r, int ql, int qr, int v) {     //修改左边界
    if (ql <= l && qr >= r) {cl[p] = v; return;}
    if (l == r) {ll[l] = cl[p]; return;}
    push_down (p);
    int mid (l + r >> 1);
    if (ql <= mid) update2 (ls, l, mid, ql, qr, v);
    if (qr > mid) update2 (rs, mid + 1, r, ql, qr, v);
}
void query (int p, int l, int r, int pos) { //使这个点的所有懒标记更新(更新为真实值)
    if (l == r) {ll[l] = cl[p], rr[l] = cr[p]; return;}
    push_down (p);
    int mid (l + r >> 1);
    pos <= mid ? query (ls, l, mid, pos) : query (rs, mid + 1, r, pos);
}
int main() {
    read (n), read (m);
    build (1, 1, n);
    for (int i = 1; i <= n; ++i) ll[i] = 1, rr[i] = n;
    while (m--) {
        scanf ("%s", ch);
        if (ch[0] == 'D') {
            read (x);
            st[++top] = x, k[x] = 1, query (1, 1, n, x);
            update1 (1, 1, n, ll[x] - (ll[x] != 1), x - 1, x - 1);
            update2 (1, 1, n, x + 1, rr[x] + (rr[x] != n), x + 1);
        }
        else if (ch[0] == 'Q') {
            read (x);
            if (k[x]) puts ("0");
            else query (1, 1, n, x), print (rr[x] - ll[x] + 1), puts ("");
        }
        else {
            if (!top) continue;
            x = st[top--], k[x] = 0, query (1, 1, n, x);
            update1 (1, 1, n, ll[x] - (ll[x] != 1), x - 1, rr[x]);
            update2 (1, 1, n, x + 1, rr[x] + (rr[x] != n), ll[x]);
        }
    }
    return 0;
}

HDU3016 Man Down

题意:在平面内有n条横放的线段,每条线段给出高度、左右端点位置和能量(可正可负),最初处于最高的一条线段上且拥有100点能量和当前线段上的能量之和,每次可以从线段左边或右边竖直落下(可能落到其他线段上也可能落到地上),落到其他线段上则获得线段上能量并继续游戏,若落到地上则终止游戏,得分为当前能量之和。但任意一个时刻能量必须为正,当能量<=0时就失败了,终止游戏。若能落到地上,输出最大得分,若不能输出-1

这题我的方法比较奇特(傻)

假设地面为0号线段,每条的线段都会指向另外两条线段(从左右端点落下时到达线段),可以看作线段之间连了一条又向边,边的长度可以用到达的线段上的能量表示,这样就建出了一张有向图,跑Spfa求出最高线段到0号的最长路即为答案(其实这个问题用dp解决就好了,dp写法见下一题)

关键在于求出每条线段从左右两段落下后会到达哪两条线段:上面的线段可以覆盖下面的,所以按高度从低到高排序,依次更新、查询就可以求出两条线段的编号。举个例子,当前线段已经按高度排序,处理到一条左端点为5,右端点为10的线段i时,先在线段树中单点查询5、10,然后将5——10的区间值修改为i

#include <bits/stdc++.h>
using namespace std;
inline void read (int &x) {
    char ch = getchar(); int f = 0; x = 0;
    while (!isdigit(ch)) {if (ch == '-') f = 1; ch = getchar();}
    while (isdigit(ch)) x = x * 10 + ch - 48, ch = getchar();
    if (f) x = -x;
}
const int N = 1E5 + 10;
int n, cnt, M, c[N << 2], d[N], to[N][2], vis[N];
struct e {
    int h, l, r, val;
    bool operator < (const e &x) const {return h < x.h;}
} a[N];
#define ls p << 1
#define rs p << 1 | 1
inline void push_down (int p) {
    if (c[p]) c[ls] = c[rs] = c[p], c[p] = 0;
}
int update (int p, int l, int r, int ql, int qr, int val) {
    if (ql <= l && qr >= r) {c[p] = val; return 0;}
    push_down (p);
    int mid (l + r >> 1);
    if (ql <= mid) update (ls, l, mid, ql, qr, val);
    if (qr > mid) update (rs, mid + 1, r, ql, qr, val);
}
int query (int p, int l, int r, int pos) {
    if (l == r) return c[p];
    push_down (p);
    int mid (l + r >> 1);
    return pos <= mid ? query (ls, l, mid, pos) : query (rs, mid + 1, r, pos);
}
inline void Spfa () {
    queue <int> q;
    memset (d, 0xcf, sizeof (d));
    memset (vis, 0, sizeof (vis));
    d[n] = 100 + a[n].val, vis[n] = 1, q.push (n);
    while (!q.empty()) {
        int u = q.front ();
        vis[u] = 0, q.pop ();
        for (int i = 0; i <= 1; ++i) {
            int v = to[u][i];
            if (d[u] + a[v].val > d[v] && d[u] + a[v].val > 0) {
                d[v] = d[u] + a[v].val;
                if (!vis[v]) q.push (v), vis[v] = 1;
            }
        }
    }
}
int main() {
    while (~scanf ("%d", &n)) {
        cnt = M = 0;
        memset (c, 0, sizeof (c));
        for (int i = 1; i <= n; ++i)
            read (a[i].h), read (a[i].l), read (a[i].r), read (a[i].val), M = max (M, a[i].r);
        sort (a + 1, a + n + 1);
        for (int i = 1; i <= n; ++i) {
            int tl = query (1, 1, M, a[i].l), tr = query (1, 1, M, a[i].r);
            to[i][0] = tl, to[i][1] = tr;
            update (1, 1, M, a[i].l, a[i].r, i);
        }
        Spfa ();
        if (d[0] < 0) puts ("-1");
        else printf ("%d\n", d[0]);
    }
    return 0;
}

P1442 铁球落地

与上题十分类似,只是答案的计算方式有些不同,前面的部分就不说了

计算答案时上题用最短路写的,这题用dp写一遍

fl[i]表示到达每个平台左侧的最短时间,fr[i]同理

转移方程应该挺简单的,从上向下dp就行了,因为每个平台只会更新下面的两块平台,dp时间复杂度O(n),总时间复杂度nlogn

#include <bits/stdc++.h>
using namespace std;
inline void read (int &x) {
    char ch = getchar(); x = 0;
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) x = x * 10 + ch - 48, ch = getchar();
}
const int N = 2e5 + 10;
int n, m, M, sx, sy, cnt, s[N << 1], l[N], r[N], f[N], c[N << 2], tag[N << 2], fl[N], fr[N];
struct e {
    int h, l, r;
    bool operator < (const e &t) {return h < t.h;}
} p[N];
inline int find (int x) {
    int l (1), r (m), mid;
    while (l <= r) {
        mid = l + r >> 1;
        if (s[mid] == x) return mid;
        if (s[mid] < x) l = mid + 1;
        else r = mid - 1;
    }
}
#define ls p << 1
#define rs p << 1 | 1
void push_down (int p) {
    if (tag[p])
        c[ls] = c[rs] = tag[ls] = tag[rs] = tag[p], tag[p] = 0;
}
void update (int p, int l, int r, int ql, int qr, int val) {
    if (ql <= l && qr >= r) {c[p] = tag[p] = val; return;}
    push_down (p);
    int mid (l + r >> 1);
    if (ql <= mid) update (ls, l, mid, ql, qr, val);
    if (qr > mid) update (rs, mid + 1, r, ql, qr, val);
}
int query (int p, int l, int r, int pos) {
    if (l == r) return c[p];
    push_down (p);
    int mid (l + r >> 1);
    return pos <= mid ? query (ls, l, mid, pos) : query (rs, mid + 1, r, pos);
}
inline int Min (int a, int b) {return a > b ? b : a;}
signed main() {
    read (n), read (M), read (sx), read (sy);
    for (int i = 1; i <= n; ++i) read (p[i].h), read (p[i].l), read (p[i].r);
    sort (p + 1, p + n + 1);
    while (p[n].h >= sy) --n;
    p[++n] = (e) {sy, sx, sx};
    for (int i = 1; i <= n; ++i) s[++cnt] = p[i].l, s[++cnt] = p[i].r;
    sort (s + 1, s + cnt + 1); m = 1;
    for (int i = 2; i <= cnt; ++i) if (s[i] != s[i - 1]) s[++m] = s[i];
    for (int i = 1; i <= n; ++i) p[i].l = find (p[i].l), p[i].r = find (p[i].r);
//    for (int i = 1; i <= n; ++i) printf ("%d %d\n", p[i].l, p[i].r);
    for (int i = 1; i <= n; ++i) {
        l[i] = query (1, 1, m, p[i].l), r[i] = query (1, 1, m, p[i].r);
        update (1, 1, m, p[i].l, p[i].r, i);
    }
//    for (int i = 1; i <= n; ++i) printf ("%d %d %d %d %d\n", p[i].h, p[i].l, p[i].r, l[i], r[i]);
    memset (fl, 0x3f, sizeof (fl));
    memset (fr, 0x3f, sizeof (fr));
    fl[n] = fr[n] = 0;
    for (int i = n; i >= 1; --i) {
        if (p[i].h - p[l[i]].h <= M) {
            int t1 (0), t2 (0);
            if (l[i]) t1 = s[p[i].l] - s[p[l[i]].l], t2 = s[p[l[i]].r] - s[p[i].l];
            fl[l[i]] = Min (fl[l[i]], 1ll * fl[i] + p[i].h - p[l[i]].h + t1);
            fr[l[i]] = Min (fr[l[i]], 1ll * fl[i] + p[i].h - p[l[i]].h + t2);
        }
        if (p[i].h - p[r[i]].h <= M) {
            int t1 (0), t2 (0);
            if (r[i]) t1 = s[p[i].r] - s[p[r[i]].l], t2 = s[p[r[i]].r] - s[p[i].r];
            fl[r[i]] = Min (fl[r[i]], 1ll * fr[i] + p[i].h - p[r[i]].h + t1);
            fr[r[i]] = Min (fr[r[i]], 1ll * fr[i] + p[i].h - p[r[i]].h + t2);
        }
    }
//    for (int i = n; i >= 0; --i) printf ("%d %d\n", fl[i], fr[i]);
    printf ("%d", Min (fl[0], fr[0]));
    return 0;
}
posted @ 2019-12-13 23:35  -敲键盘的猫-  阅读(361)  评论(0编辑  收藏  举报