「NOI2020」制作菜品(bitset+背包DP)
Address
Solution
今天上计程设,真的讲了 Hello world 和 a+b。
至于我为什么要上这个课,因为它在化学专业的培养方案里面。
言归正传。
不会做这题的建议先看看下发文件中的 dish2.in
和 dish2.ans
。
打开这个样例,发现只有一个 \(m=n-1\) 的,而 \(m=n-1\) 正好有部分分。
dish2.ans
给出的做法是:每次找 \(\max\) 和 \(\min\),把 \(\min\) 全拿走,\(\max\) 拿走 \(k-\min\)。
对着这题想了好久,一点思路都没有,看到一个 \(m=n-1\) 的构造方案,还不赶紧写下来,说不定能骗到分。
发现这个样例里面还有一个 \(m=n\) 的,而 \(m\ge n\) 也有部分分。
dish2.ans
里面有唯一一行输出两个整数的,看看它拿走的是哪个 \(d_i\),显然是 \(\max\)。
如果一开始就把 \(\max\) 减去 \(k\),那就相当于 \(m\) 减 \(1\),\(n\) 不变,就转化成了 \(m=n-1\) 的情况。
那 \(m\ge n\) 是不是都可以用这个做法转化为 \(m=n-1\)?
\(m\ge n\) 的部分分让你更加坚定地相信这个做法的正确性。
要想证明它正确,最好要证明 \(m=n-1\) 必定有解。回顾 \(m=n-1\) 的做法:\(\min\) 拿掉,\(\max\) 拿 \(k-\min\)。
什么时候做不下去了?\(\max+\min<k\)。会有这种可能出现吗?
假设 \(\max+\min<k\),由于总和是 \((n-1)k\),那剩下的 \(n-2\) 个加起来必须大于 \((n-2)k\),那么这 \(n-2\) 个里面肯定有 \(d_i>k\) 的,跟 \(\max<k\) 矛盾。
那 \(m\ge n-1\) 都做完了。观察数据范围,就剩 \(m=n-2\) 了。
还是考虑沿用 \(m=n-1\) 的做法,把这 \(n\) 个分成两组,每组都是 \(m=n-1\)。说形象点就是,一棵树断掉一条边,变成边数之和为 \(n-2\) 的两棵树。
什么样的分组是合法的?设分出来的其中一组 \(d_i\) 之和为 \(s\),个数为 \(c\),则 \(s=(c-1)k\)。也就是 \(\sum k-d_i=k\)。
设 \(f_{i,j}\) 表示前 \(i\) 个中选若干个,是否能让 \(k-d\) 之和为 \(j\),用 bitset 可以算。
如果 \(f_{n,k}=1\),那么有合法方案,否则没有。
构造方案(\(bo=1\) 的全部放到一组,剩下的放到另一组):
int sum = L + k; // DP数组下标可能是负数,所以整体移位
if (!f[x][sum])
{
puts("-1");
return;
}
memset(bo, 0, sizeof(bo));
for (;;)
{
bo[x] = 1;
int i, v = k - d[x];
sum -= v;
for (i = 0; i < x; i++) // 注意要找最小的i
if (d[i] && f[i][sum])
{
x = i;
break;
}
if (x == 0) break;
}
然后就过了。
所以这题的考点就是打开 dish2.in
和 dish2.ans
,没别的。
至于为什么找不到分组方案就无解,我也不知道怎么严谨证明。我只知道明天早八有机,今晚要早点睡,溜了。
Code
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
template <class t>
inline void read(t & res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + (ch ^ 48);
}
const int e = 1e5 + 15, P = 5e6 + 15;
struct point
{
int p1, d1, p2, d2;
point(){}
point(int _p1, int _d1, int _p2, int _d2) :
p1(_p1), d1(_d1), p2(_p2), d2(_d2) {}
}a[e];
int d[e], T, n, m, k, tot, c[e];
bool bo[e];
bitset<3655555>f[505];
vector<int>g1, g2;
inline int plu(int x, int y)
{
(x += y) >= k && (x -= k);
return x;
}
inline int sub(int x, int y)
{
(x -= y) < 0 && (x += k);
return x;
}
inline void work1()
{
int i, j;
while (m--)
{
int p1 = -1, p2 = n + 1;
for (i = 1; i <= n; i++)
if (d[i])
{
if (d[i] > d[p1]) p1 = i;
if (d[i] < d[p2]) p2 = i;
}
if (p1 == p2)
{
a[++tot] = point(p1, d[p1], 0, 0);
d[p1] = 0;
}
else
{
if (d[p1] + d[p2] >= k)
{
int x = d[p2];
d[p2] -= x;
d[p1] -= k - x;
a[++tot] = point(p2, x, p1, k - x);
}
else
{
puts("-1");
return;
}
}
}
for (i = 1; i <= tot; i++)
if (a[i].p2) printf("%d %d %d %d\n", a[i].p1, a[i].d1, a[i].p2, a[i].d2);
else printf("%d %d\n", a[i].p1, a[i].d1);
}
inline void work2()
{
int i, j, rn = n, rm = m;
for (;;)
{
for (i = 1; i <= n; i++)
if (d[i] == k)
{
d[i] = 0;
rn--;
rm--;
a[++tot] = point(i, k, 0, 0);
break;
}
int p = 0;
for (i = 1; i <= n; i++)
if (d[i] && d[i] > d[p]) p = i;
if (p)
{
d[p] -= k;
a[++tot] = point(p, k, 0, 0);
rm--;
}
if (!rm || rm == rn - 1) break;
}
m = rm;
work1();
}
inline void solve(vector<int>g)
{
int i, len = g.size(), rm = len - 1;
while (rm--)
{
int p1 = 0, p2 = n + 1;
for (i = 0; i < len; i++)
{
int pos = g[i];
if (d[pos])
{
if (d[pos] > d[p1]) p1 = pos;
if (d[pos] <= d[p2]) p2 = pos;
}
}
if (p1 == p2)
{
a[++tot] = point(p1, d[p1], 0, 0);
d[p1] = 0;
}
else
{
int x = d[p2];
d[p2] -= x;
d[p1] -= k - x;
a[++tot] = point(p2, x, p1, k - x);
}
}
}
inline void work3()
{
int i, j, x = 0, rn = n, rm = m;
for (i = 1; i <= n; i++)
if (d[i] % k == 0)
{
while (d[i])
{
d[i] -= k;
a[++tot] = point(i, k, 0, 0);
rn--; rm--;
}
}
g1.clear();
g2.clear();
int L = 0;
for (i = 0; i <= n; i++) f[i].reset();
for (i = 1; i <= n; i++)
if (d[i])
{
int v = k - d[i];
if (v < 0) L += v;
}
L = -L;
f[0][L] = 1;
for (i = 1; i <= n; i++)
if (d[i])
{
int v = k - d[i];
if (v >= 0) f[i] = f[x] | (f[x] << v);
else f[i] = f[x] | (f[x] >> (-v));
x = i;
}
int sum = L + k;
if (!f[x][sum])
{
puts("-1");
return;
}
memset(bo, 0, sizeof(bo));
for (;;)
{
bo[x] = 1;
int i, v = k - d[x];
sum -= v;
for (i = 0; i < x; i++)
if (d[i] && f[i][sum])
{
x = i;
break;
}
if (x == 0) break;
}
for (i = 1; i <= n; i++)
if (d[i])
{
if (bo[i]) g1.pb(i);
else g2.pb(i);
}
solve(g1);
solve(g2);
for (i = 1; i <= tot; i++)
if (a[i].p2) printf("%d %d %d %d\n", a[i].p1, a[i].d1, a[i].p2, a[i].d2);
else printf("%d %d\n", a[i].p1, a[i].d1);
}
int main()
{
freopen("dish.in", "r", stdin);
freopen("dish.out", "w", stdout);
read(T);
int i;
while (T--)
{
tot = 0;
read(n); read(m); read(k);
d[0] = -1; d[n + 1] = 1e9;
for (i = 1; i <= n; i++) read(d[i]);
if (m == n - 1) work1();
else if (m >= n) work2();
else work3();
}
fclose(stdin);
fclose(stdout);
return 0;
}