HDU #4747 MEX (线段树的应用)

题目描述:

定义$mex(i,j)$为序列中第$i$项到第$j$项中没有出现的最小自然数。给定序列,求$\sum^{n}_{i=1}\sum^{n}_{j=i}mex(i,j)$。

解题思路:

首先我们可以$O(n)$预处理出$mex(1,1\sim n)$,因为显然的是$mex$是递增的。然后我们考虑怎么从$mex(i,i\sim n)$推出$mex(i+1,i+1\sim n)$,我们删掉$a_i$这个数后,哪些区间的$mex$会改变呢?其实就是到下一个与$a_i$相等的数出现前$mex$大于$a_i$的区间,因为这段区间没有了$a_i$这个数,而他们原本的mex却大于$a_i$,所以可以变小。所以要区间查询、修改、求和,用线段树就可以了。

代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#define i64 long long
using namespace std;

const int N = 2e5 + 10;
int n, a[N], mex[N], nxt[N];
i64 ans;

map<int, int> mp;

struct node {
    int s, mx, tag;
} tr[N * 8];

void init() {
    int now = 0;
    for (int i = 1; i <= n; i ++) {
        mp[a[i]] = 1;
        while (mp.count(now)) now ++;
        mex[i] = now;
    }
    mp.clear();
    for (int i = n; i; i --) {
        if (mp.count(a[i])) nxt[i] = mp[a[i]];
        else nxt[i] = n + 1;
        mp[a[i]] = i;
    }
}

void build(int o, int l, int r) {
    if (l == r) {
        tr[o].s = tr[o].mx = mex[l];
        tr[o].tag = -1;
        return;
    }
    tr[o].tag = -1;
    int m = l + r >> 1;
    build(o << 1, l, m);
    build(o << 1 | 1, m + 1, r);
    tr[o].s = tr[o << 1].s + tr[o << 1 | 1].s;
    tr[o].mx = max(tr[o << 1].mx, tr[o << 1 | 1].mx);
}

void pushdown(int o, int l, int r) {
    if (tr[o].tag == -1) return;
    tr[o << 1].tag = tr[o << 1 | 1].tag = tr[o].tag;
    tr[o].s = tr[o].tag * (r - l + 1);
    tr[o].mx = tr[o].tag;
    tr[o].tag = -1;
}

int find(int o, int l, int r, int v) {
    if (l == r) return l;
    int m = l + r >> 1;
    pushdown(o << 1, l, m);
    pushdown(o << 1 | 1, m + 1, r);
    if (tr[o << 1].mx > v) return find(o << 1, l, m, v);
    else return find(o << 1 | 1, m + 1, r, v);    
}

void updata(int o, int l, int r) {
    int m = l + r >> 1, x, y;
    if (tr[o << 1].tag != -1) x = tr[o << 1].tag; else x = tr[o << 1].mx;
    if (tr[o << 1 | 1].tag != -1) y = tr[o << 1 | 1].tag; else y = tr[o << 1 | 1].mx;
    tr[o].mx = max(x, y);
    if (tr[o << 1].tag != -1) x = tr[o << 1].tag * (m - l + 1); else x = tr[o << 1].s;
    if (tr[o << 1 | 1].tag != -1) y = tr[o << 1 | 1].tag * (r - m); else y = tr[o << 1 | 1].s;
    tr[o].s = x + y;
}

void modify(int o, int l, int r, int x, int y, int v) {
    if (x <= l && r <= y) {
        tr[o].tag = v;
        return;
    }
    pushdown(o, l, r);
    int m = l + r >> 1;
    if (x <= m) modify(o << 1, l, m, x, y, v);
    if (y > m) modify(o << 1 | 1, m + 1, r, x, y, v);
    updata(o, l, r);
}

int query(int o, int l, int r, int x, int y) {
    pushdown(o, l, r);
    if (x <= l && r <= y) return tr[o].s;
    int m = l + r >> 1, t = 0;
    if (x <= m) t = query(o << 1, l, m, x, y);
    if (y > m) t += query(o << 1 | 1, m + 1, r, x, y);
    return t;
}

void work() {
    ans += (i64)query(1, 1, n, 1, n - 1);
    for (int i = 1; i < n - 1; i ++) {
        pushdown(1, 1, n);
        int k = find(1, 1, n, a[i]);
        if (k < nxt[i]) modify(1, 1, n, k, nxt[i] - 1, a[i]);
        ans += (i64)query(1, 1, n, i + 1, n - 1);
    }
    printf("%lld", ans);
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i ++) scanf("%d", &a[i]);
    init();
    mex[++ n] = N;
    build(1, 1, n);
    work();
    return 0;
}

 

posted @ 2016-08-16 21:53  Awner  阅读(243)  评论(0编辑  收藏  举报