P10235 [yLCPC2024] C. 舞萌基本练习 题解

题目传送门

大致题意:

多组测试数据,每组数据给定一个长度为 \(n\) 的序列和一个参数 \(k\),要求将此区间划分成不超过 \(k\) 段,使这些区间中的逆序对数量的最大值最小。

思路:

对于求“最大值最小”这类问题,很容易想到二分。显然,本题的答案是具有单调性的,即当划分的段数减少时,区间中逆序对数量的最大值只会增大不会减小

所以我们考虑将二分答案转化为二分判定,具体为:我们二分一个 \(limit\) 值,表示当前的解,然后将当前这个 \(limit\) 代入计算。

若当前解合法,则说明区间 \([limit,r]\) 的解肯定都是合法的,因为此时 \(limit\) 也可能作为最后答案,所以令 \(r = mid\),向左扩展答案即可。

若当前解不合法,则说明区间 \([l,limit]\) 的解肯定都不合法,所以此时令 \(l = mid + 1\) 向右扩展答案即可。

二分的问题解决了,那么怎么判断这个解是否合法呢?

一个很简单的想法:扫描整个序列,一开始整个序列只有一段,将序列中的数一个一个地加入这个段中,当此段逆序对数量大于 \(limit\) 时,就重新开辟新的一段并使段数 \(cnt\) 增加 \(1\)。将整个序列划分完后,若 \(cnt \le k\),则此解合法,否则不合法。

对于求逆序对数量,用树状数组可以很好解决。

建立一个权值树状数组,每次在某段末尾加入一个数时,只需计算该段中大于它的数的个数,这就是新增的逆序对数。

然而这里要注意两个点:

  1. \(-10^9 \le a_i \le 10^9\),所以需要离散化;
  2. 在重新开辟一段时,之前那段的数要全部从树状数组中抹去,这样才能让它正确地求出后面段的逆序对数。

\(\texttt{Some Tips}\)

由于长度为 \(n\) 的序列最大逆序对数量为 \(\frac{n(n-1)}{2}\),所以我的二分上界取了 \(10^{10}\)

在每次二分中,序列中的所有数只会进入一次、出一次树状数组,所以 \(\operatorname{check()}\) 的时间复杂度为 \(O(n\log n)\)

整个程序时间复杂度 \(O(n\log n\log 10^{10})\),空间复杂度 \(O(n)\)

\(\texttt{Code}:\)

#include <iostream>
#include <vector>
#include <algorithm>
#include <cstdio>

#define lowbit(x) x & -x
using namespace std;

const int N = 100010;

int T, n, k;
int a[N], c[N];
int nums[N];
int tt;
int fnd[N];

int find(int x) {
	return lower_bound(nums + 1, nums + tt + 1, x) - nums;
}

int ask(int x) {
	int res = 0;
	for(; x; x -= lowbit(x)) res += c[x];
	return res;
}

void add(int x, int y) {
	for(; x <= n; x += lowbit(x)) c[x] += y;
}

bool check(long long limit) {
	int cnt = 1; //段数 
	long long f = 0; //目前处理的段的逆序对数 
	int L = 1; //目前处理的段的左端点 
	for(int i = 1; i <= n; i++) {
		int tmp = ask(tt) - ask(fnd[i]); //计算新增的逆序对数 
		if(f + tmp > limit) {
			cnt++; //段数 + 1
			f = 0; //重置逆序对数
			for(int j = L; j <= i - 1; j++) 
				add(fnd[j], -1); //清除上一区间的贡献 
			L = i; //更新左端点 
		}
		else f += tmp;
		add(fnd[i], 1); //加入树状数组
	}
	for(int i = L; i <= n; i++) add(fnd[i], -1); //不要忘了最后一段也要抹去
	return cnt > k;
}

int main() {
	scanf("%d", &T);
	while(T--) {
		scanf("%d%d", &n, &k);
		for(int i = 1; i <= n; i++) {
			scanf("%d", &a[i]);
			nums[++tt] = a[i];
		}
		sort(nums + 1, nums + tt + 1);
		tt = unique(nums + 1, nums + tt + 1) - nums - 1;
		for(int i = 1; i <= n; i++) fnd[i] = find(a[i]);
		long long l = 0, r = 1e10;
		while(l < r) {
			long long mid = l + r >> 1;
			if(check(mid)) l = mid + 1;
			else r = mid;
		}
		printf("%lld\n", l);
		for(int i = 1; i <= tt; i++) nums[i] = 0;
		tt = 0;
	}
	return 0;
}
posted @ 2024-07-23 14:37  Brilliant11001  阅读(4)  评论(0编辑  收藏  举报