解题报告 『[HNOI2002]营业额统计(splay)』

原题地址

直接套用 splay 模板即可通过，但要注意：当 x 没有前驱或没有后继时，对应的查询要返回 inf，这样在取 min 时缺失的一侧就不会被选中。

 

代码实现如下:

#include <bits/stdc++.h>
using namespace std;
// Inclusive counting loop: declares int i and runs it from a up to b.
// The loop variable is a plain int; the `register` storage class was removed in C++17.
#define rep(i, a, b) for (int i = (a); i <= (b); i++)

// inf:  sentinel returned by pre()/nxt() when the root has no subtree on the requested side.
// maxn: capacity of the node arrays below.
const int inf = 0x3f3f3f3f, maxn = 4e4 + 5;

// n:   number of values read by main()
// rt:  index of the current root node (0 means the tree is empty)
// ans: running total printed at the end
// tot: number of nodes allocated so far; index 0 is never allocated and acts as "no node"
int n, rt, ans = 0, tot = 0;
// Per-node data, indexed by node id:
//   fa[x]    parent of x (0 when x has no parent)
//   cnt[x]   how many times val[x] has been inserted
//   val[x]   the value stored in node x
//   size[x]  cnt summed over the subtree rooted at x (see maintain())
//   ch[x][0] left child, ch[x][1] right child (the third column is not used by the code shown)
// NOTE(review): the name `size` can clash with std::size under `using namespace std` in C++17.
int fa[maxn], cnt[maxn], val[maxn], size[maxn], ch[maxn][3];

// Absolute value of x.
int ABS(int x) {
    if (x < 0) return -x;
    return x;
}

// Returns 1 when x is the right child of its parent, 0 otherwise.
int get(int x) {
    const int parent = fa[x];
    return ch[parent][1] == x;
}

// Smaller of the two arguments.
int MIN(int a, int b) {
    if (a < b) return a;
    return b;
}

void maintain(int x) {size[x] = size[ch[x][0]] + size[ch[x][1]] + cnt[x];}

int pre() {
    int cnr = ch[rt][0];
    if (!cnr) return inf;
    while (ch[cnr][1]) cnr = ch[cnr][1];
    return val[cnr];
}

int nxt() {
    int cnr = ch[rt][1];
    if (!cnr) return inf;
    while (ch[cnr][0]) cnr = ch[cnr][0];
    return val[cnr];
}

// Rotates node x one level up so that its former parent y becomes a child of x.
// chk records which side of y x was attached to (0 = left, 1 = right), so the same
// code handles both rotation directions.
void rotate(int x) {
    int y = fa[x], z = fa[y], chk = get(x);
    // x's subtree on the side opposite to chk takes over x's old slot under y.
    ch[y][chk] = ch[x][chk ^ 1];
    // If that subtree is empty this assigns fa[0], i.e. the parent entry of index 0.
    fa[ch[x][chk ^ 1]] = y;
    // y becomes x's child on the side opposite to chk.
    ch[x][chk ^ 1] = y;
    fa[y] = x;
    // x takes y's old parent; z == 0 means y had no parent.
    fa[x] = z;
    // Replace y by x in whichever child slot of z held y.
    if (z) ch[z][y == ch[z][1]] = x;
    // y is now below x, so its size must be refreshed first.
    maintain(y);
    maintain(x);
}

void splay(int x) {
    for (register int f = fa[x]; f = fa[x], f; rotate(x))
        if (fa[f]) rotate(get(x) == get(f) ? f : x);
    rt = x;
}

void insert(int k) {
    if (!rt) {
        val[++tot] = k;
        cnt[tot]++;
        rt = tot;
        maintain(rt);
        return;
    }
    int f = 0, cnr = rt;
    while (1) {
        if (k == val[cnr]) {
            cnt[cnr]++;
            maintain(cnr);
            maintain(f);
            splay(cnr);
            break;
        }
        f = cnr;
        cnr = ch[cnr][val[cnr] < k];
        if (!cnr) {
            val[++tot] = k;
            cnt[tot]++;
            fa[tot] = f;
            ch[f][val[f] < k] = tot;
            maintain(tot);
            maintain(f);
            splay(tot);
            break;
        }
    }
}

// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Characters before the number that are neither '-' nor a digit are skipped, and the
// first character after the number is consumed.
// Returns 0 if the input ends before a number starts; the character is kept in an int
// and compared with EOF so that end of input terminates the skip loop instead of
// spinning forever.
int read() {
    int c = getchar();
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return 0;

    bool negative = false;
    if (c == '-') {
        negative = true;
        c = getchar();
    }

    int x = 0;
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return negative ? -x : x;
}

// Prints x in decimal to stdout, with a leading '-' for negative values and no
// trailing newline.
// The magnitude is computed in unsigned arithmetic so that the most negative int can
// be printed without signed overflow.
void write(int x) {
    unsigned int magnitude = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        magnitude = 0u - magnitude;
    }

    // Collect digits least-significant first, then emit them in reverse.
    char digits[3 * sizeof(unsigned int) + 1];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + magnitude % 10);
        magnitude /= 10;
    } while (magnitude != 0);

    while (len > 0) putchar(digits[--len]);
}

int main() {
    n = read();
    rep(i, 1, n) {
        int x;
        x = read();
        insert(x);
        if (i == 1) ans += x;
        else {
            if (cnt[rt] <= 1)
                ans += MIN(ABS(x - pre()), ABS(x - nxt()));
        }  
    }
    write(ans);
    return 0;
}
View Code
posted @ 2019-08-03 10:42  雲裏霧裏沙  阅读(184)  评论(0编辑  收藏  举报