ABC306 E - Best Performances 题解 离散化+线段树/splay tree
题目链接:https://atcoder.jp/contests/abc306/tasks/abc306_e
题目大意:
有一个长度为 \(N\) 的序列 \(A = (A_1, A_2, \ldots, A_N)\),以及一个整数 \(K\)。
初始时序列 \(A\) 的所有元素的数值均为 \(0\)。
有 \(Q\) 次操作,每次操作给你两个整数 \(X_i\) 和 \(Y_i\),你需要将序列 \(A\) 中第 \(X_i\) 个元素的数值修改为 \(Y_i\)(即 \(A_{X_i} \leftarrow Y_i\)),然后输出一个整数,这个整数的数值为序列 \(A\) 中最小的 \(K\) 个数之和。
解题思路:
线段树离散化之后,或者 splay tree 都是基本操作。
示例程序1(线段树 + 离散化):
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 5;
int n, M, K, Q, x[maxn], y[maxn], a[maxn];
vector<int> vec;
int lsh(int val) {
return lower_bound(vec.begin(), vec.end(), val) - vec.begin() + 1;
}
// 线段树
int tr_cnt[maxn<<2];
long long tr_sum[maxn<<2];
#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1
void push_up(int rt) {
tr_cnt[rt] = tr_cnt[rt<<1] + tr_cnt[rt<<1|1];
tr_sum[rt] = tr_sum[rt<<1] + tr_sum[rt<<1|1];
}
// 离散化之后的数字是p,增加了 c 个(+1 或者 -1)
void add(int p, int c, int l, int r, int rt) {
// printf("add [%d, %d] [%d, %d] %d\n", p, c, l, r, rt);
if (l == r) {
tr_cnt[rt] += c;
tr_sum[rt] += (long long) c * vec[p-1];
return;
}
int mid = (l + r) / 2;
(p <= mid) ? add(p, c, lson) : add(p, c, rson);
push_up(rt);
}
// 前k个数
long long query(int k, int l, int r, int rt) {
// printf("query %d [%d , %d] %d\n", k, l, r, rt);
if (l == r) {
assert(tr_cnt[rt] >= k);
return (long long) k * vec[l-1];
}
if (tr_cnt[rt] == k)
return tr_sum[rt];
int mid = (l + r) / 2;
if (tr_cnt[rt<<1|1] >= k)
return query(k, rson);
return tr_sum[rt<<1|1] + query(k - tr_cnt[rt<<1|1], lson);
}
int main()
{
scanf("%d%d%d", &n, &K, &Q);
vec.push_back(0);
for (int i = 0; i < Q; i++) {
scanf("%d%d", x+i, y+i);
vec.push_back(y[i]);
}
sort(vec.begin(), vec.end());
vec.erase(unique(vec.begin(), vec.end()), vec.end());
M = vec.size();
add(1, n, 1, M, 1);
for (int i = 0; i < Q; i++) {
int p = x[i], q = y[i]; // a[p] = q
add(lsh(a[p]), -1, 1, M, 1);
a[p] = q;
add(lsh(a[p]), 1, 1, M, 1);
printf("%lld\n", query(K, 1, M, 1));
}
return 0;
}
示例程序2(splay tree):
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 5;
int n, M, K, Q, a[maxn];
struct Node {
int s[2], p, v; // s[0]左儿子编号,s[1]右儿子编号,p父节点编号,v数值
int sz, cnt; // 子树大小
long long sum; // 子树数值之和
Node() {}
Node(int _v, int _p) {
v = _v;
p = _p;
s[0] = s[1] = 0;
sz = cnt = 1;
sum = _v;
}
} tr[maxn];
int root, idx;
void push_up(int x) {
int ls = tr[x].s[0], rs = tr[x].s[1];
tr[x].sz = tr[ls].sz + tr[rs].sz + tr[x].cnt;
tr[x].sum = tr[ls].sum + tr[rs].sum + (long long) tr[x].cnt * tr[x].v;
}
void f_s(int p, int u, bool k) {
tr[p].s[k] = u;
tr[u].p = p;
}
void rot(int x) {
int y = tr[x].p, z = tr[y].p;
bool k = tr[y].s[1] == x;
f_s(z, x, tr[z].s[1]==y);
f_s(y, tr[x].s[k^1], k);
f_s(x, y, k^1);
push_up(y), push_up(x);
}
// 旋转到 x 的父节点为 k 为止(若k为0,则 x 将旋转到根节点)
void splay(int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
(tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
rot(x);
}
if (!k) root = x;
}
// 插入一个数值为 v 的节点
void ins(int v) {
int u = root, p = 0;
while (u) {
if (tr[u].v == v)
break;
p = u, u = tr[u].s[v > tr[u].v];
}
if (u) {
tr[u].cnt++;
push_up(u);
}
else {
tr[u = ++idx] = Node(v, p);
if (p) tr[p].s[v > tr[p].v] = u;
}
splay(u, 0);
}
// 找前驱:找数值 < v 的最大的那个数
int get_pre(int v) {
int u = root, res = -1;
while (u) {
if (tr[u].v < v) res = u, u = tr[u].s[1];
else u = tr[u].s[0];
}
return res;
}
// 找数值等于 v 的最前面(中序遍历序号最小)那个点
int get_point(int v) {
int u = root, res = -1;
while (u) {
if (tr[u].v >= v) res = u, u = tr[u].s[0];
else u = tr[u].s[1];
}
return res;
}
// 删除一个数值为 v 的节点
void del(int v) {
int u1 = get_pre(v); // 找前驱
splay(u1, 0);
int u2 = get_point(v); // 查找一个数值为 v 的节点
splay(u2, u1);
if (tr[u2].cnt > 1) {
tr[u2].cnt--;
push_up(u2);
}
else
f_s(u1, tr[u2].s[1], 1);
push_up(u1);
}
long long query(int k) {
int u = root;
long long res = 0;
if (tr[u].sz <= k) return tr[u].sum;
while (u) {
int ls = tr[u].s[0], rs = tr[u].s[1];
if (tr[rs].sz >= k)
u = rs;
else {
res += tr[rs].sum;
k -= tr[rs].sz;
if (k <= tr[u].cnt) {
res += (long long) tr[u].v * k;
break;
}
else {
res += (long long) tr[u].v * tr[u].cnt;
k -= tr[u].cnt;
}
u = ls;
}
}
return res;
}
int main()
{
scanf("%d%d%d", &n, &K, &Q);
ins(0);
for (int i = 0; i < Q; i++) {
int x, y;
scanf("%d%d", &x, &y);
if (a[x])
del(a[x]); // 删除 delete
a[x] = y;
if (a[x])
ins(a[x]); // 插入 insert
printf("%lld\n", query(K));
}
return 0;
}