[IOI2020]装饼干
这题我的方法比较奇怪。
题意:
有\(k\)种物品,第\(i\)个物品有\(a_i\)个,权值为\(2^i\)。
求有多少个\(y\),使得可以选出\(x\)组物品,每组的和都为\(y\)。
先考虑如何判定一个\(y\)是否可行:
从最高位开始,依次求出第i位需要的数目\(b_i\)。若\(y\)的第\(i\)位为1,则\(b\leftarrow b+x\)。
如果\(b_i \leq a_i\),那么说明\(a_i\)够用,进入下一位。
如果\(b_i>a_i\),则选上所有\(a_i\)后,还剩\(b_i-a_i\)个。那么这些\(2^i\)只能用两倍的\(2^{i-1}\)来凑。因此把\(b_{i-1}\)加上\(2 \times (b_i-a_i)\)。
这样,只要\(b_0<a_0\),则说明这个\(y\)可行。
可以发现,若\(b_i \leq a_i\),那么对于所有\(j<i\),\(j\)不会受之前的影响。
因此,可以把\(b_i \leq a_i\)作为分界点,进行DP。
设\(dp_i\)表示使得\(b_i \leq a_i\)的方案数目。只考虑大于等于\(i\)的位。那么,\(dp_0\)就是答案。
枚举\(j>i\)作为前一个分界点。那么,对于所有\(i<c<j\),都要求\(b_c>a_c\)。
再枚举所有的\(c\),那么\(b_c\)就可以很容易地用这个\(y\)在\([c,j-1]\)这些数位上的值来表示。
于是,对于每个\(c\),可以得出一条不等式。把这些不等式联立,就能得到\(y\)在\([i,j)\)这些数位上的范围\(d_{i,j}\)。
因此,\(dp_i\leftarrow dp_i+d_{i,j}\times dp_j\)。
这样做的复杂度是\(O(k^3q)\)的,可以过。
代码:
#include <stdio.h>
#include <vector>
#include "biscuits.h"
using namespace std;
#define ll long long
ll dp[70];
ll count_tastiness(ll x, vector<ll> sz) {
int k = sz.size();
for (int i = 0; i < 62 - k; i++) sz.push_back(0);
k = 62;
for (int i = k; i >= 0; i--) {
if (i == k) {
dp[i] = 1;
continue;
}
dp[i] = 0;
for (int j = i + 1; j <= k; j++) {
ll zx = 0, zd = (1ll << (j - i)) - 1, h = 0;
for (int a = j - 1; a >= i; a--) {
h = h * 2 + sz[a];
ll z = (h / x + 1) << (a - i);
if (a > i) {
if (z > zx)
zx = z;
} else {
if (z - 1 < zd)
zd = z - 1;
}
}
if (zx <= zd)
dp[i] += dp[j] * (zd - zx + 1);
}
}
return dp[0];
}
不难发现,求\(d_{i,j}\)的过程可以优化。提前预处理出\(d_{i,j}\),就可以\(O(k^2)\)了。
#include <stdio.h>
#include <vector>
#include "biscuits.h"
using namespace std;
#define ll long long
ll dp[70], zz[70][70], dd[70][70];
ll count_tastiness(ll x, vector<ll> sz) {
int k = sz.size();
for (int i = 0; i < 62 - k; i++) sz.push_back(0);
for (int j = 1; j <= 62; j++) {
ll zx = 0, h = 0;
for (int a = j - 1; a >= 0; a--) {
zz[a][j] = zx;
h = h * 2 + sz[a];
ll z = (h / x + 1) << a;
if (z > zx)
zx = z;
dd[a][j] = z;
}
}
for (int i = 62; i >= 0; i--) {
if (i == 62) {
dp[i] = 1;
continue;
}
dp[i] = 0;
for (int j = i + 1; j <= 62; j++) {
ll zx = (zz[i][j] >> i), zd = (1ll << (j - i));
if ((dd[i][j] >> i) < zd)
zd = (dd[i][j] >> i);
if (zx <= zd)
dp[i] += dp[j] * (zd - zx);
}
}
return dp[0];
}