POJ2976 Dropping tests 题解 01分数规划
题目链接:http://poj.org/problem?id=2976
题目大意
在某门课程中,你要进行 \(n\) 个测试。其中第 \(i\) 个测试一共有 \(b_i\) 道问题,你答对的有 \(a_i\) 道。你在这门课程中的成绩定义为
现在给你一个权力,你 最多 可以选择这 \(n\) 个测试中的 \(k\) 个测试(\(k \lt n\)),并将选择的测试作废(也就是说如果你将某一个测试 \(j\) 作废,则 \(a_j\) 和 \(b_j\) 将不计入上述公式)。
举个例子,假设你有 \(3\) 个测试,对应的答题情况为 \(5/5, 0/1, 2/6\)(即 \(a_1=b_1=5,a_2=0,b_2=1,a_3=2,b_3=6\))。如果你不选择任何测试作废,则你的成绩为 \(100 \cdot \frac{5+0+2}{5+1+6}=50\);然而,如果你将第 \(3\) 个测试作废,则你的成绩为 \(100 \cdot \frac{5+0}{5+1} \approx 83.33 \approx 83\)。
输入格式
输入包含多组测试数据,每组测试数据包含三行。
每组数据的第一行包含三个整数 \(n\) 和 \(k\)(\(0 \le k \lt n \le 1000\)),第二行包含 \(n\) 个整数,两两之间以一个空格分隔,表示 \(a_i\),第三行包含 \(n\) 个整数,两两之间以一个空格分隔,表示 \(b_i\)(\(0 \le a_i \le b_i \le 10^9\))。
输入的最后一行包含两个整数 \(0\),表示输入数据的结束。
输出格式
对于每组测试数据,输出一个整数,表示在最多将 \(k\) 个测试作废的情况下你能够得到的最高成绩。因为成绩可能不是一个整数,所以你需要输出答案四舍五入的整数。
样例输入
3 1
5 0 2
5 1 6
4 2
1 2 7 9
5 6 7 9
0 0
样例输出
83
100
问题分析
对于这个问题,首先我想到的是贪心去除掉比例(\(\frac{a_i}{b_i}\))最小的 \(k\) 个,然后计算剩下的 \(n-k\) 个对应的结果。但是很快发现了一个反例:
假设 \(n=3,k=1\),\(3\) 件物品分别为 \(10/10,3/10,1/5\),则:
- 保留三件物品,成绩为 \(100 \cdot \frac{10+3+1}{10+10+5} = 56\)
- 去除第 \(1\) 件物品,成绩为 \(100 \cdot \frac{3+1}{10+5} \approx 26.67 \approx 27\)
- 去除第 \(2\) 件物品,成绩为 \(100 \cdot \frac{10+1}{10+5} \approx 73.33 \approx 73\)
- 去除第 \(3\) 件物品,成绩为 \(100 \cdot \frac{10+3}{10+10} = 65\)
可以发现,第 \(3\) 个测试的 \(\frac{a_i}{b_i}\) 是最小的,但是去除第 \(3\) 个测试后的成绩却不是最高的(而是去除第 \(2\) 个测试后的成绩最高),这是因为大家不是在同一个角度(\(b_i\))上分析的,如果分母都一样,那么是没有问题的,但是分母不一样就会导致这种贪心策略出现问题。
由此推导出我们接下来要讨论的 01分数规划 的解法。
对于本题,我们可以将其看成如下问题:
求一组 \(w_i \in \{ 0, 1\}\) 最大化
这里还有一个限制是 \(\sum\limits_{i=1}^n w_i \ge n-k\)
我们如果当前的 \(mid\) 满足条件则必然
为了让上述条件满足,我会选 \(a_i - mid \times b_i\) 最大的 \(n-k\) 个对应的 \(w_i=1\),其它 \(w_i\) 为 \(0\)。
这样,我就能通过二分答案找到最终的结果了。
示例代码:
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstdlib>
const int maxn = 1010;
int n, k;
double a[maxn], b[maxn], c[maxn];
bool check(double mid) {
for (int i = 0; i < n; i ++) c[i] = a[i] - mid * b[i];
std::sort(c, c+n);
double ans = 0;
for (int i = 1; i <= n-k; i ++) ans += c[n-i];
return ans >= 0;
}
int main() {
while (~scanf("%d%d", &n, &k) && n) {
for (int i = 0; i < n; i ++) scanf("%lf", a+i);
for (int i = 0; i < n; i ++) scanf("%lf", b+i);
double L = 0, R = 1;
while (fabs(R - L) > 1e-4) {
double mid = (L + R) / 2.;
if (check(mid)) L = mid;
else R = mid;
}
printf("%d\n", (int) (100 * L + 0.5));
}
return 0;
}