[ABC257Ex] Dice Sum 2 题解
[ABC257Ex] Dice Sum 2 Solution
更好的阅读体验戳此进入
题面
存在 $ n $ 个正六面体骰子,第 $ i $ 个骰子六个面的数值分别为 $ A_{i, 1}, A_{i, 2}, \cdots, A_{i, 6} $,购买第 $ i $ 个骰子的花费为 $ C_i $。你要在其中购买 $ k $ 个骰子,以最大化收益的期望。定义收益为将购买的 $ k $ 个骰子各扔一遍,其朝上的数的和的平方减去买 $ k $ 个骰子花费的总费用。输出模 $ 998244353 $ 意义下的收益期望最大值。
Solution
果然难题的尽头是推式子。
首先令买的 $ k $ 个骰子为 $ A_1, A_2, \cdots, A_k $,不难根据期望定义列出式子:
然后考虑将中间的 $ \sum_{i = 1}^kA_{i, x_i} $ 转化一下,并去掉外面的一堆求和(关于这步,可以考虑钦定一个骰子为 $ A $ 时,剩下 $ k - 1 $ 个骰子可以任选,即 $ 6^{k - 1} $,而钦定两个骰子并钦定顺序之后则为 $ 6^{k - 2} $),即:
然后将前面的分式带进去并进一步化简:
发现 $ i \neq j $ 难以处理,尝试通过类似容斥的思想,从平方中去除 $ i = j $ 的部分,化简为:
发现式子中出现了大量的 $ \sum_{x_i = 1}^6A_{i, x_i} $,考虑令:
然后发现第二项可以转化为:
然后考虑一三项,同样尽量用 $ X_i $ 转化:
此时若我们令:
则原式转化为:
写的更明显一点,我们就是要求:
显然购买方案一共有 $ n \choose k $ 种,我们可以考虑将每一种购买方案抽象成二维平面中坐标为 $ (\sum X, \sum Y) $ 的点,记作 $ (x, y) $,那么我们也就是要在若干个点最大化 $ x^2 + y $。
显然我们是需要尽量使 $ \lvert x \rvert $ 和 $ y $ 都大一些,所以最终可能成为答案的点一定在点集组成的上凸包中。所以理论上我们可以通过枚举每个实数斜率,然后以该斜率的直线去切这个凸包,截距最大的即为凸包上的点。这里理论上截距应该是 $ -kx + y $,但是为了便于计算我们可以认为是一条斜率为 $ -k $ 的直线,那么截距即为 $ kx + y $,同时注意这里我们只求截距最大值是因为只需要维护上凸包。具体来说就是对于每个 $ k $ 找到最大的 $ k \sum X + \sum Y $,显然可以转化为对所有点求一次 $ kX + Y $,然后从中选择前 $ k $ 大(注意这里的 $ k $ 不是斜率的 $ k $,而是买的 $ k $ 个骰子)求和即可。
然后这里我们显然不能枚举实数,于是考虑什么时候骰子之间的大小关系会发生变化,考虑存在以下情况,当 $ k \rightarrow k' $ 使得 $ i, j $ 之间顺序变化时显然有以下式子(举个例子,大小关系相反亦可):
显然对于这两种情况的转折点存在于:
并且此时仅有 $ i, j $ 之间的大小关系会改变。
所以我们 $ O(n^2) $ 枚举并预处理,然后遍历一下 $ k $,第一次排个序,后面的 $ O(1) $ 修改即可,最终复杂度卡在排序上,为 $ O(n^2 \log n^2) $,可以通过。
Tips:实现时为了便于排序,且 $ X, Y $ 远小于 long long
的范围,我们可以考虑按照 $ (36X, 36Y) $ 排序,这样就不可能有分数了,然后最后乘一个 $ 36 $ 的逆元即可。
Tips:还有一个我调了很久的细节,就是过程中不能取模,否则大小关系会变化,导致答案错误。
然后按照这个思路实现之后大概是这样:
#define _USE_MATH_DEFINES
#include <bits/stdc++.h>
#define PI M_PI
#define E M_E
#define npt nullptr
#define SON i->to
#define OPNEW void* operator new(size_t)
#define ROPNEW(arr) void* Edge::operator new(size_t){static Edge* P = arr; return P++;}
using namespace std;
mt19937 rnd(random_device{}());
int rndd(int l, int r){return rnd() % (r - l + 1) + l;}
bool rnddd(int x){return rndd(1, 100) <= x;}
typedef unsigned int uint;
typedef unsigned long long unll;
typedef long long ll;
typedef long double ld;
#define MOD (998244353ll)
template < typename T = int >
inline T read(void);
ll inv36;
int N, K;
ll C[1100];
ll A[1100][10];
ll X[1100], Y[1100];
struct MyPair{int first, second;};
map < ld, basic_string < MyPair > > mp;
bitset < 1100 > inAns;
ll ansX(0), ansY(0);
ll ans(LONG_LONG_MIN);
int idx[1100];
ll qpow(ll a, ll b){
ll ret(1), mul(a);
while(b){
if(b & 1)ret = ret * mul % MOD;
b >>= 1;
mul = mul * mul % MOD;
}return ret;
}
int main(){
inv36 = qpow(36ll, MOD - 2);
N = read(), K = read();
for(int i = 1; i <= N; ++i)C[i] = read(), idx[i] = i;
for(int i = 1; i <= N; ++i){
for(int j = 1; j <= 6; ++j)
A[i][j] = read(),
X[i] += A[i][j], //6Xi
Y[i] += 6 * A[i][j] * A[i][j]; //36Yi
Y[i] -= X[i] * X[i] + 36 * C[i];
}
for(int i = 1; i <= N; ++i)for(int j = i + 1; j <= N; ++j)
mp[(ld)(Y[j] - Y[i]) / (ld)(X[i] - X[j])] += MyPair{i, j};
ld k = mp.begin()->first - 0.1;
sort(idx + 1, idx + N + 1, [=](const int &a, const int &b)->bool{return k * (ld)X[a] + (ld)Y[a] > k * (ld)X[b] + (ld)Y[b];});
for(int i = 1; i <= K; ++i)
ansX += X[idx[i]], ansY += Y[idx[i]], inAns[idx[i]] = true;
ans = max(ans, ansX * ansX + ansY);
for(auto mdf : mp){
for(auto Pair : mdf.second){
int l = Pair.first, r = Pair.second;
if(inAns[l] ^ inAns[r]){
if(!inAns[l] && inAns[r])swap(l, r);
inAns[l] = false, inAns[r] = true;
ansX += -X[l] + X[r], ansY += -Y[l] + Y[r];
}
}ans = max(ans, ansX * ansX + ansY);
}printf("%lld\n", (ans % MOD * inv36 % MOD + MOD) % MOD);
fprintf(stderr, "Time: %.6lf\n", (double)clock() / CLOCKS_PER_SEC);
return 0;
}
template < typename T >
inline T read(void){
T ret(0);
int flag(1);
char c = getchar();
while(c != '-' && !isdigit(c))c = getchar();
if(c == '-')flag = -1, c = getchar();
while(isdigit(c)){
ret *= 10;
ret += int(c - '0');
c = getchar();
}
ret *= flag;
return ret;
}
看起来没什么问题,但是交上去就会发现错了很多点,感觉可能是精度之类的问题,于是我们考虑去掉所有浮点数相关运算。
首先对于初始的排序,不难想到若将 $ k $ 升序,那么我们可以认为初始 $ k = -\infty $,则 $ Y $ 对大小关系影响忽略不计,仅比较 $ X $ 即可。
然后对于修改之间的排序把式子推一下换成乘法即可。还有个细节就是我们在修改的时候应仅保留一些满足 $ X $ 偏序关系的,可以大概理解为保证我们变的这个 $ k $ 是单调的,当然这个点我也没有完全弄明白,如果有大佬知道的话欢迎解惑。
Code
#define _USE_MATH_DEFINES
#include <bits/stdc++.h>
#define PI M_PI
#define E M_E
#define npt nullptr
#define SON i->to
#define OPNEW void* operator new(size_t)
#define ROPNEW(arr) void* Edge::operator new(size_t){static Edge* P = arr; return P++;}
using namespace std;
mt19937 rnd(random_device{}());
int rndd(int l, int r){return rnd() % (r - l + 1) + l;}
bool rnddd(int x){return rndd(1, 100) <= x;}
typedef unsigned int uint;
typedef unsigned long long unll;
typedef long long ll;
typedef long double ld;
#define MOD (998244353ll)
template < typename T = int >
inline T read(void);
ll inv36;
int N, K;
ll C[1100];
ll A[1100][10];
ll X[1100], Y[1100];
struct Node{
int l, r;
friend const bool operator < (const Node &a, const Node &b){
return (Y[a.r] - Y[a.l]) * (X[b.l] - X[b.r]) < (Y[b.r] - Y[b.l]) * (X[a.l] - X[a.r]);
}
};
basic_string < Node > mdfs;
multimap < ld, pair < int, int > > mp;
bitset < 1100 > inAns;
ll ansX(0), ansY(0);
ll ans(LONG_LONG_MIN);
int idx[1100];
ll qpow(ll a, ll b){
ll ret(1), mul(a);
while(b){
if(b & 1)ret = ret * mul % MOD;
b >>= 1;
mul = mul * mul % MOD;
}return ret;
}
int main(){
// freopen("in.txt", "r", stdin);
inv36 = qpow(36ll, MOD - 2);
N = read(), K = read();
for(int i = 1; i <= N; ++i)C[i] = read(), idx[i] = i;
for(int i = 1; i <= N; ++i){
for(int j = 1; j <= 6; ++j)
A[i][j] = read(),
X[i] += A[i][j], //6Xi
Y[i] += 6 * A[i][j] * A[i][j]; //36Yi
Y[i] -= X[i] * X[i] + 36 * C[i];
}
for(int i = 1; i <= N; ++i)for(int j = 1; j <= N; ++j)if(X[i] < X[j])mdfs += Node{i, j};
sort(mdfs.begin(), mdfs.end());
sort(idx + 1, idx + N + 1, [](const int &a, const int &b)->bool{return X[a] < X[b];});
for(int i = 1; i <= K; ++i)
ansX += X[idx[i]], ansY += Y[idx[i]], inAns[idx[i]] = true;
ans = max(ans, ansX * ansX + ansY);
for(auto mdf : mdfs)
if(inAns[mdf.l] && !inAns[mdf.r])
inAns[mdf.l] = false, inAns[mdf.r] = true,
ansX += -X[mdf.l] + X[mdf.r], ansY += -Y[mdf.l] + Y[mdf.r],
ans = max(ans, ansX * ansX + ansY);
printf("%lld\n", (ans % MOD * inv36 % MOD + MOD) % MOD);
fprintf(stderr, "Time: %.6lf\n", (double)clock() / CLOCKS_PER_SEC);
return 0;
}
template < typename T >
inline T read(void){
T ret(0);
int flag(1);
char c = getchar();
while(c != '-' && !isdigit(c))c = getchar();
if(c == '-')flag = -1, c = getchar();
while(isdigit(c)){
ret *= 10;
ret += int(c - '0');
c = getchar();
}
ret *= flag;
return ret;
}
UPD
update-2022_12_19 初稿