E. MEXimize the Score
首先观察到性质是贡献肯定是cnt[0]+min(cnt[0],cnt[1])+min(cnt[0],cnt[1],cnt[2])....这样的
设状态为f[i][j],表示在选到数i时有j个数i产生了贡献的方法数
转移分为两种情况,一种是数i的数量限制,一种是f[i-1][j]限制(分类非常重要
那么f[i][j]对ans的贡献是包含这种选法的子序列数量,就相当于把序列前一部分选好了,后面任取,所以是
(j是该方案贡献的分数)
\(f[i][j]\times 2^{(i+1)+...+(n-1)}\times j\)
总结一下这题分为三个步骤,第一步观察,第二步把贡献拆分成两部分去求,先dp求出f[i][j]的贡献,再求子序列种类贡献,分类非常重要
然后用vector写这种题很容易越界,要判断好
#include <bits/stdc++.h>
using namespace std;
#define lowbit(x) (x & (-x))
#define pii pair<int, int>
#define mkp make_pair
#define LL long long
#define int long long
#define endl '\n'
const int N = 2e5 + 10, mod = 998244353;
int n, cnt[N], pinv[N], pre[N], fcnt[N];
int qpow(int x, int y)
{
int res = 1;
while (y)
{
if (y & 1)
res = 1ll * x * res % mod;
x = 1ll * x * x % mod;
y >>= 1;
}
return res;
}
int inv(int x)
{
return qpow(x, mod - 2);
}
int C(int x, int y)
{
if(y < x || x < 0) return 0;
return 1ll * pre[y] * pinv[x] % mod * pinv[y - x] % mod;
}
int add(int x, int y)
{
return (x + y >= mod) ? (x + y - mod) : (x + y);
}
int sub(int x, int y)
{
return (x - y < 0) ? (x - y + mod) : (x - y);
}
int mul(int x, int y)
{
return 1ll * x * y % mod;
}
void solve()
{
cin >> n;
vector<int> f(n + 10, 0);
for (int i = 1; i <= n; i++)
{
int x;
cin >> x;
cnt[x]++;
}
fcnt[n + 1] = fcnt[n] = 0;
for (int i = n - 1; i >= 1; i--)
fcnt[i] = fcnt[i + 1] + cnt[i];
LL ans = 0;
for (int i = 1; i <= cnt[0]; i++)
{
f[i] = C(i, cnt[0]);
ans = add(ans, mul(f[i], mul(i, qpow(2, fcnt[1]))));
}
for (int i = 1; i < n; i++)
{
vector<int> fx(cnt[i] + 5, 0);
vector<int> suf(cnt[i] + 5, 0);
vector<int> fsuf(cnt[i - 1] + 5, 0);
for (int j = cnt[i]; j >= 0; j--)
suf[j] = add(suf[j + 1], C(j, cnt[i]));
for (int j = cnt[i - 1]; j >= 1; j--)
fsuf[j] = add(fsuf[j + 1], f[j]);
for (int j = 1; j <= cnt[i]; j++)
{
if(j <= cnt[i - 1]) fx[j] = mul(f[j], suf[j]);
if(j + 1 <= cnt[i - 1]) fx[j] = add (fx[j], mul(fsuf[j + 1], C(j, cnt[i])));
ans = add(ans, mul(fx[j], mul(j, qpow(2, fcnt[i + 1]))));
// cout << suf[j] << ' ' << mul(fsuf[j + 1], C(j, cnt[i])) << ' ' << fx[j] << ' ' << endl;
}
// cout << i <<endl;
swap(fx, f);
}
cout << ans << endl;
// clear
for (int i = 0; i <= n; i++)
cnt[i] = 0;
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
pre[0] = pinv[0] = 1;
for (int i = 1; i <= N - 5; i++)
pre[i] = 1ll * pre[i - 1] * i % mod;
pinv[N - 6] = inv(pre[N - 6]);
for (int i = N - 7; i >= 0; i--)
pinv[i] = 1ll * pinv[i + 1] * (i + 1) % mod;
// cout<<pinv[2]*2%mod<<endl;
int T = 1;
cin >> T;
while (T--)
solve();
}
/*
4
3
0 0 1
4
0 0 1 1
5
0 0 1 2 2
4
1 1 1 1
*/