E. MEXimize the Score

首先观察到性质是贡献肯定是cnt[0]+min(cnt[0],cnt[1])+min(cnt[0],cnt[1],cnt[2])....这样的
设状态为f[i][j],表示在选到数i时有j个数i产生了贡献的方法数
转移分为两种情况,一种是数i的数量限制,一种是f[i-1][j]限制(分类非常重要
那么f[i][j]对ans的贡献是包含这种选法的子序列数量,就相当于把序列前一部分选好了,后面任取,所以是
(j是该方案贡献的分数)
\(f[i][j]\times 2^{(i+1)+...+(n-1)}\times j\)

总结一下这题分为三个步骤,第一步观察,第二步把贡献拆分成两部分去求,先dp求出f[i][j]的贡献,再求子序列种类贡献,分类非常重要
然后用vector写这种题很容易越界,要判断好

#include <bits/stdc++.h>
using namespace std;
#define lowbit(x) (x & (-x))
#define pii pair<int, int>
#define mkp make_pair
#define LL long long
#define int long long
#define endl '\n'

const int N = 2e5 + 10, mod = 998244353;

int n, cnt[N], pinv[N], pre[N], fcnt[N];

int qpow(int x, int y)
{
    int res = 1;
    while (y)
    {
        if (y & 1)
            res = 1ll * x * res % mod;
        x = 1ll * x * x % mod;
        y >>= 1;
    }
    return res;
}
int inv(int x)
{
    return qpow(x, mod - 2);
}
int C(int x, int y)
{
    if(y < x || x < 0) return 0;
    return 1ll * pre[y] * pinv[x] % mod * pinv[y - x] % mod;
}
int add(int x, int y)
{
    return (x + y >= mod) ? (x + y - mod) : (x + y);
}
int sub(int x, int y)
{
    return (x - y < 0) ? (x - y + mod) : (x - y);
}
int mul(int x, int y)
{
    return 1ll * x * y % mod;
}
void solve()
{
    cin >> n;
    vector<int> f(n + 10, 0);
    for (int i = 1; i <= n; i++)
    {
        int x;
        cin >> x;
        cnt[x]++;
    }
    fcnt[n + 1] = fcnt[n] = 0;
    for (int i = n - 1; i >= 1; i--)
        fcnt[i] = fcnt[i + 1] + cnt[i];
    LL ans = 0;
    for (int i = 1; i <= cnt[0]; i++)
    {
        f[i] = C(i, cnt[0]);
        ans = add(ans, mul(f[i], mul(i, qpow(2, fcnt[1]))));
    }
    for (int i = 1; i < n; i++)
    {
        vector<int> fx(cnt[i] + 5, 0);
        vector<int> suf(cnt[i] + 5, 0);
        vector<int> fsuf(cnt[i - 1] + 5, 0);
        for (int j = cnt[i]; j >= 0; j--)
            suf[j] = add(suf[j + 1], C(j, cnt[i]));
        for (int j = cnt[i - 1]; j >= 1; j--)
            fsuf[j] = add(fsuf[j + 1], f[j]);
        for (int j = 1; j <= cnt[i]; j++)
        {
            if(j <= cnt[i - 1]) fx[j] = mul(f[j], suf[j]);
            if(j + 1 <= cnt[i - 1]) fx[j] = add (fx[j], mul(fsuf[j + 1], C(j, cnt[i])));
            ans = add(ans, mul(fx[j], mul(j, qpow(2, fcnt[i + 1]))));
            // cout << suf[j] << ' ' << mul(fsuf[j + 1], C(j, cnt[i])) << ' ' << fx[j] << ' ' << endl;
        }
        // cout << i <<endl;
        swap(fx, f);
    }
    cout << ans << endl;
    // clear
    for (int i = 0; i <= n; i++)
        cnt[i] = 0;
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    pre[0] = pinv[0] = 1;
    for (int i = 1; i <= N - 5; i++)
        pre[i] = 1ll * pre[i - 1] * i % mod;
    pinv[N - 6] = inv(pre[N - 6]);
    for (int i = N - 7; i >= 0; i--)
        pinv[i] = 1ll * pinv[i + 1] * (i + 1) % mod;
    // cout<<pinv[2]*2%mod<<endl;

    int T = 1;
    cin >> T;
    while (T--)
        solve();
}
/*
4
3
0 0 1
4
0 0 1 1
5
0 0 1 2 2
4
1 1 1 1
*/
posted @ 2024-11-17 14:41  lyrrr  阅读(5)  评论(0编辑  收藏  举报