题意:给你一个序列,让你求出对于所有区间<i, j>的mex和,mex表示该区间没有出现过的最小的整数。
思路:从时限和点数就可以看出是线段树,并且我们可以枚举左端点i, 然后求出所有左端点为i的区间内mex值的和。
先把数插满,然后先询问后删除当前最左边的断点i。而且显然线段树里面保存的是mex值,而且这个序列是非递减的。
分析:我们先预处理出对于右端点为i的所有<1,i>的mex,分别插入线段树的i位置。然后每次删除最左边的左端点i
,假如当前我们要删除a[i] ,我们找到它之后第一个位置j满足a[i] == a[j], 那么区间i------j-1里面的所有mex都要更新,取线段树内的值和a[i]的最小值。 实际操作我们只要找到第一个比a[i]的位置l, r = j-1, 更新<l,r>之间的mex为a[i]即可。
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; typedef long long ll; #define lson l, m, rt<<1 #define rson m+1, r, rt<<1|1 #define ls rt<<1 #define rs rt<<1|1 #define Mid int m = l+r>>1 const int maxn = 2000006; int next[maxn], pre[maxn], n; int a[maxn], mex; bool vis[maxn]; ll sum[maxn<<2]; int mx[maxn<<2], col[maxn<<2]; void build(int l=1, int r=n, int rt=1) { col[rt] = -1; sum[rt] = 0; mx[rt] = 0; if(l == r) return; Mid; build(lson); build(rson); } inline void down(int l, int r, int rt) { if(~col[rt]) { col[ls] = col[rs] = col[rt]; Mid; sum[ls] = (ll)(m-l+1)*col[rt]; mx[ls] = mx[rs] = col[rt]; sum[rs] = (ll)(r-m)*col[rt]; col[rt] = -1; } } inline void up(int rt) { sum[rt] = sum[ls] + sum[rs]; mx[rt] = max(mx[ls], mx[rs]); } void update(int L, int R, int v, int l=1, int r=n, int rt=1) { if(L <= l && r <= R) { col[rt] = mx[rt] = v; sum[rt] = (ll)(r-l+1)*v; return; } Mid; down(l, r, rt); if(L <= m) update(L, R, v, lson); if(R > m) update(L, R, v, rson); up(rt); } ll query(int L, int R, int l=1, int r=n, int rt=1) { if(L <= l && r <= R) return sum[rt]; Mid; down(l, r, rt); ll ret = 0; if(L <= m) ret += query(L, R, lson); if(R > m) ret += query(L, R, rson); up(rt); return ret; } int find(int v, int l=1, int r=n, int rt=1) { if(mx[rt] <= v) return n+1; if(l == r) return l; Mid; down(l, r, rt); if(mx[ls] > v) return find(v, lson); else return find(v, rson); } int main() { int i, j; while(~scanf("%d", &n) && n) { for(i = 1; i <= n; i++) { scanf("%d", &a[i]); pre[i] = vis[i] = 0; next[i] = n+1; } pre[0] = vis[0] = 0; for(i = 1; i <= n; i++) if(a[i] <= n) { if(pre[a[i]]) next[pre[a[i]]] = i; pre[a[i]] = i; } build(); mex = 0; for(i = 1; i <= n; i++) { if(a[i] <= n){ vis[a[i]] = 1; while(vis[mex]) mex++; } update(i, i, mex); } ll ans = 0; for(i = 1; i <= n; i++) { ans += query(i, n); if(a[i] <= mex) { int l = max(find(a[i]), i); int r = next[i]-1; if(l <= r) update(l, r, a[i]); } } printf("%I64d\n", ans); } return 0; } /* 3 0 10000 20000 */