[ABC257Ex] Dice Sum 2 题解

[ABC257Ex] Dice Sum 2 Solution

更好的阅读体验戳此进入

题面

存在 $ n $ 个正六面体骰子,第 $ i $ 个骰子六个面的数值分别为 $ A_{i, 1}, A_{i, 2}, \cdots, A_{i, 6} $,购买第 $ i $ 个骰子的花费为 $ C_i $。你要在其中购买 $ k $ 个骰子,以最大化收益的期望。定义收益为将购买的 $ k $ 个骰子各扔一遍,其朝上的数的和的平方减去买 $ k $ 个骰子花费的总费用。输出模 $ 998244353 $ 意义下的收益期望最大值。

Solution

果然难题的尽头是推式子。

首先令买的 $ k $ 个骰子为 $ A_1, A_2, \cdots, A_k $,不难根据期望定义列出式子:

\[E = \sum_{x_1 = 1}^6\sum_{x_2 = 1}^6\cdots\sum_{x_k = 1}^6 \dfrac{1}{6^k}\sum_{i = 1}^kA_{i, x_i} - \sum_{i = 1}^kC_i \]

然后考虑将中间的 $ \sum_{i = 1}^kA_{i, x_i} $ 转化一下,并去掉外面的一堆求和(关于这步,可以考虑钦定一个骰子为 $ A $ 时,剩下 $ k - 1 $ 个骰子可以任选,即 $ 6^{k - 1} $,而钦定两个骰子并钦定顺序之后则为 $ 6^{k - 2} $),即:

\[E = \dfrac{1}{6^k}(\sum_{i = 1}^k\sum_{x_i = 1}^66^{k - 1}A_{i, x_i}^2 + \sum_{i = 1}^k\sum_{j = 1}^k\sum_{x_i = 1}^6\sum_{x_j = 1}^66^{k - 2}A_{i, x_i}A_{j, x_j} \times [i \neq j]) - \sum_{i = 1}^kC_i \]

然后将前面的分式带进去并进一步化简:

\[E = \dfrac{1}{6}\sum_{i = 1}^k\sum_{x_i = 1}^6A_{i, x_i}^2 + \dfrac{1}{36}\sum_{i = 1}^k\sum_{j = 1}^k\sum_{x_i = 1}^6\sum_{x_j = 1}^6A_{i, x_i}A_{j, x_j} \times [i \neq j] - \sum_{i = 1}^kC_i \]

发现 $ i \neq j $ 难以处理,尝试通过类似容斥的思想,从平方中去除 $ i = j $ 的部分,化简为:

\[E = \dfrac{1}{6}\sum_{i = 1}^k\sum_{x_i = 1}^6A_{i, x_i}^2 + \dfrac{1}{36}(\sum_{i = 1}^k\sum_{x_i = 1}^6A_{i, x_i})^2 - \dfrac{1}{36}\sum_{i = 1}^k(\sum_{x_i = 1}^6A_{i, x_i})^2 - \sum_{i = 1}^kC_i \]

发现式子中出现了大量的 $ \sum_{x_i = 1}^6A_{i, x_i} $,考虑令:

\[X_i = \dfrac{1}{6}\sum_{x_i = 1}^6A_{i, x_i} \]

然后发现第二项可以转化为:

\[\dfrac{1}{36}(\sum_{i = 1}^k\sum_{x_i = 1}^6A_{i, x_i})^2 = (\sum_{i = 1}^kX_i)^2 \]

然后考虑一三项,同样尽量用 $ X_i $ 转化:

\[\begin{aligned} & \dfrac{1}{6}\sum_{i = 1}^k\sum_{x_i = 1}^6A_{i, x_i}^2 - \dfrac{1}{36}\sum_{i = 1}^k(\sum_{x_i = 1}^6A_{i, x_i})^2 - \sum_{i = 1}^kC_i \\ =& \sum_{i = 1}^k(\dfrac{1}{6}\sum_{x_i = 1}^6A_{i, x_i}^2 - X_i^2 - C_i) \end{aligned} \]

此时若我们令:

\[Y_i = \dfrac{1}{6}\sum_{x_i = 1}^6A_{i, x_i}^2 - X_i^2 - C_i \]

则原式转化为:

\[(\sum_{i = 1}^kX_i)^2 + \sum_{i = 1}^kY_i \]

写的更明显一点,我们就是要求:

\[\max\left\{ \left(\sum X\right)^2 + \sum Y \right\} \]

显然购买方案一共有 $ n \choose k $ 种,我们可以考虑将每一种购买方案抽象成二维平面中坐标为 $ (\sum X, \sum Y) $ 的点,记作 $ (x, y) $,那么我们也就是要在若干个点最大化 $ x^2 + y $。

显然我们是需要尽量使 $ \lvert x \rvert $ 和 $ y $ 都大一些,所以最终可能成为答案的点一定在点集组成的上凸包中。所以理论上我们可以通过枚举每个实数斜率,然后以该斜率的直线去切这个凸包,截距最大的即为凸包上的点。这里理论上截距应该是 $ -kx + y $,但是为了便于计算我们可以认为是一条斜率为 $ -k $ 的直线,那么截距即为 $ kx + y $,同时注意这里我们只求截距最大值是因为只需要维护上凸包。具体来说就是对于每个 $ k $ 找到最大的 $ k \sum X + \sum Y $,显然可以转化为对所有点求一次 $ kX + Y $,然后从中选择前 $ k $ 大(注意这里的 $ k $ 不是斜率的 $ k $,而是买的 $ k $ 个骰子)求和即可。

然后这里我们显然不能枚举实数,于是考虑什么时候骰子之间的大小关系会发生变化,考虑存在以下情况,当 $ k \rightarrow k' $ 使得 $ i, j $ 之间顺序变化时显然有以下式子(举个例子,大小关系相反亦可):

\[\begin{aligned} kX_i + Y_i \lt kX_j + Y_j \\ k'X_i + Y_i \gt k'X_j + Y_j \end{aligned} \]

显然对于这两种情况的转折点存在于:

\[k = \dfrac{Y_j - Y_i}{X_i - X_j} \]

并且此时仅有 $ i, j $ 之间的大小关系会改变。

所以我们 $ O(n^2) $ 枚举并预处理,然后遍历一下 $ k $,第一次排个序,后面的 $ O(1) $ 修改即可,最终复杂度卡在排序上,为 $ O(n^2 \log n^2) $,可以通过。

Tips:实现时为了便于排序,且 $ X, Y $ 远小于 long long 的范围,我们可以考虑按照 $ (36X, 36Y) $ 排序,这样就不可能有分数了,然后最后乘一个 $ 36 $ 的逆元即可。

Tips:还有一个我调了很久的细节,就是过程中不能取模,否则大小关系会变化,导致答案错误。

然后按照这个思路实现之后大概是这样:

#define _USE_MATH_DEFINES
#include <bits/stdc++.h>

#define PI M_PI
#define E M_E
#define npt nullptr
#define SON i->to
#define OPNEW void* operator new(size_t)
#define ROPNEW(arr) void* Edge::operator new(size_t){static Edge* P = arr; return P++;}

using namespace std;

mt19937 rnd(random_device{}());
int rndd(int l, int r){return rnd() % (r - l + 1) + l;}
bool rnddd(int x){return rndd(1, 100) <= x;}

typedef unsigned int uint;
typedef unsigned long long unll;
typedef long long ll;
typedef long double ld;

#define MOD (998244353ll)

template < typename T = int >
inline T read(void);

ll inv36;
int N, K;
ll C[1100];
ll A[1100][10];
ll X[1100], Y[1100];
struct MyPair{int first, second;};
map < ld, basic_string < MyPair > > mp;
bitset < 1100 > inAns;
ll ansX(0), ansY(0);
ll ans(LONG_LONG_MIN);
int idx[1100];

ll qpow(ll a, ll b){
    ll ret(1), mul(a);
    while(b){
        if(b & 1)ret = ret * mul % MOD;
        b >>= 1;
        mul = mul * mul % MOD;
    }return ret;
}

int main(){
    inv36 = qpow(36ll, MOD - 2);
    N = read(), K = read();
    for(int i = 1; i <= N; ++i)C[i] = read(), idx[i] = i;
    for(int i = 1; i <= N; ++i){
        for(int j = 1; j <= 6; ++j)
            A[i][j] = read(),
            X[i] += A[i][j], //6Xi
            Y[i] += 6 * A[i][j] * A[i][j]; //36Yi
        Y[i] -= X[i] * X[i] + 36 * C[i];
    }
    for(int i = 1; i <= N; ++i)for(int j = i + 1; j <= N; ++j)
        mp[(ld)(Y[j] - Y[i]) / (ld)(X[i] - X[j])] += MyPair{i, j};
    ld k = mp.begin()->first - 0.1;
    sort(idx + 1, idx + N + 1, [=](const int &a, const int &b)->bool{return k * (ld)X[a] + (ld)Y[a] > k * (ld)X[b] + (ld)Y[b];});
    for(int i = 1; i <= K; ++i)
        ansX += X[idx[i]], ansY += Y[idx[i]], inAns[idx[i]] = true;
    ans = max(ans, ansX * ansX + ansY);
    for(auto mdf : mp){
        for(auto Pair : mdf.second){
            int l = Pair.first, r = Pair.second;
            if(inAns[l] ^ inAns[r]){
                if(!inAns[l] && inAns[r])swap(l, r);
                inAns[l] = false, inAns[r] = true;
                ansX += -X[l] + X[r], ansY += -Y[l] + Y[r];
            }
        }ans = max(ans, ansX * ansX + ansY);
    }printf("%lld\n", (ans % MOD * inv36 % MOD + MOD) % MOD);
    fprintf(stderr, "Time: %.6lf\n", (double)clock() / CLOCKS_PER_SEC);
    return 0;
}

template < typename T >
inline T read(void){
    T ret(0);
    int flag(1);
    char c = getchar();
    while(c != '-' && !isdigit(c))c = getchar();
    if(c == '-')flag = -1, c = getchar();
    while(isdigit(c)){
        ret *= 10;
        ret += int(c - '0');
        c = getchar();
    }
    ret *= flag;
    return ret;
}

看起来没什么问题,但是交上去就会发现错了很多点,感觉可能是精度之类的问题,于是我们考虑去掉所有浮点数相关运算。

首先对于初始的排序,不难想到若将 $ k $ 升序,那么我们可以认为初始 $ k = -\infty $,则 $ Y $ 对大小关系影响忽略不计,仅比较 $ X $ 即可。

然后对于修改之间的排序把式子推一下换成乘法即可。还有个细节就是我们在修改的时候应仅保留一些满足 $ X $ 偏序关系的,可以大概理解为保证我们变的这个 $ k $ 是单调的,当然这个点我也没有完全弄明白,如果有大佬知道的话欢迎解惑。

Code

#define _USE_MATH_DEFINES
#include <bits/stdc++.h>

#define PI M_PI
#define E M_E
#define npt nullptr
#define SON i->to
#define OPNEW void* operator new(size_t)
#define ROPNEW(arr) void* Edge::operator new(size_t){static Edge* P = arr; return P++;}

using namespace std;

mt19937 rnd(random_device{}());
int rndd(int l, int r){return rnd() % (r - l + 1) + l;}
bool rnddd(int x){return rndd(1, 100) <= x;}

typedef unsigned int uint;
typedef unsigned long long unll;
typedef long long ll;
typedef long double ld;

#define MOD (998244353ll)

template < typename T = int >
inline T read(void);

ll inv36;
int N, K;
ll C[1100];
ll A[1100][10];
ll X[1100], Y[1100];
struct Node{
    int l, r;
    friend const bool operator < (const Node &a, const Node &b){
        return (Y[a.r] - Y[a.l]) * (X[b.l] - X[b.r]) < (Y[b.r] - Y[b.l]) * (X[a.l] - X[a.r]);
    }
};
basic_string < Node > mdfs;
multimap < ld, pair < int, int > > mp;
bitset < 1100 > inAns;
ll ansX(0), ansY(0);
ll ans(LONG_LONG_MIN);
int idx[1100];

ll qpow(ll a, ll b){
    ll ret(1), mul(a);
    while(b){
        if(b & 1)ret = ret * mul % MOD;
        b >>= 1;
        mul = mul * mul % MOD;
    }return ret;
}

int main(){
    // freopen("in.txt", "r", stdin);
    inv36 = qpow(36ll, MOD - 2);
    N = read(), K = read();
    for(int i = 1; i <= N; ++i)C[i] = read(), idx[i] = i;
    for(int i = 1; i <= N; ++i){
        for(int j = 1; j <= 6; ++j)
            A[i][j] = read(),
            X[i] += A[i][j], //6Xi
            Y[i] += 6 * A[i][j] * A[i][j]; //36Yi
        Y[i] -= X[i] * X[i] + 36 * C[i];
    }
    for(int i = 1; i <= N; ++i)for(int j = 1; j <= N; ++j)if(X[i] < X[j])mdfs += Node{i, j};
    sort(mdfs.begin(), mdfs.end());
    sort(idx + 1, idx + N + 1, [](const int &a, const int &b)->bool{return X[a] < X[b];});
    for(int i = 1; i <= K; ++i)
        ansX += X[idx[i]], ansY += Y[idx[i]], inAns[idx[i]] = true;
    ans = max(ans, ansX * ansX + ansY);
    for(auto mdf : mdfs)
        if(inAns[mdf.l] && !inAns[mdf.r])
            inAns[mdf.l] = false, inAns[mdf.r] = true,
            ansX += -X[mdf.l] + X[mdf.r], ansY += -Y[mdf.l] + Y[mdf.r],
            ans = max(ans, ansX * ansX + ansY);
    printf("%lld\n", (ans % MOD * inv36 % MOD + MOD) % MOD);
    fprintf(stderr, "Time: %.6lf\n", (double)clock() / CLOCKS_PER_SEC);
    return 0;
}

template < typename T >
inline T read(void){
    T ret(0);
    int flag(1);
    char c = getchar();
    while(c != '-' && !isdigit(c))c = getchar();
    if(c == '-')flag = -1, c = getchar();
    while(isdigit(c)){
        ret *= 10;
        ret += int(c - '0');
        c = getchar();
    }
    ret *= flag;
    return ret;
}

UPD

update-2022_12_19 初稿

posted @ 2023-01-07 15:52  Tsawke  阅读(43)  评论(0编辑  收藏  举报