[JXOI2018]排序问题
大模拟
显然这个期望次数是\(\frac{(n+m)!}{\prod a_i!}\),\(a_i\)表示第\(i\)个数出现的次数,我们要最大化这个值只需要最小化\(\prod a_i!\)就好了
要加入\(m\)个范围在\([l,r]\)的数,肯定不会影响在原序列里出现过的且不属于\([l,r]\)的数的出现次数,这个直接算就好了
我们有\(r-l+1\)个数可以加入,一个贪心是我们先加入出现次数最少的数,正确性显然
于是用小根堆维护一下就有\(50pts\)了
我们可以把每一种数的出现次数搞出来,再用一个桶表示出现次数为\(i\)的有多少种数,这样可以转化成一个类似区间覆盖的东西,之后就没有了
离散化写跪调了一上午,std::unique竟然会改变原数列
代码
#include <bits/stdc++.h>
#include <tr1/unordered_map>
using namespace std::tr1;
#define re register
#define max(a, b) ((a) > (b) ? (a) : (b))
const int mod = 998244353;
const int maxn = 2e5 + 5;
const int maxM = 1e7 + 2e5 + 5;
inline int read() {
char c = getchar();
int x = 0;
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = (x << 3) + (x << 1) + c - 48, c = getchar();
return x;
}
inline int ksm(int a, int b) {
int S = 1;
for (; b; b >>= 1, a = 1ll * a * a % mod)
if (b & 1) S = 1ll * S * a % mod;
return S;
}
int n, m, l, r, sz, tot, T, U;
unordered_map<int, int> ma;
std::vector<int> v[maxn >> 1];
int tax[maxn], a[maxn], b[maxn], c[maxn];
int ifac[maxM], fac[maxM];
int N[maxn >> 1], M[maxn >> 1], L[maxn >> 1], R[maxn >> 1];
inline int find(int x) {
int lx = 1, ry = sz;
while (lx <= ry) {
int mid = lx + ry >> 1;
if (c[mid] == x)
return mid;
if (c[mid] < x)
lx = mid + 1;
else
ry = mid - 1;
}
return 0;
}
int main() {
T = read();
fac[0] = ifac[0] = 1;
for (re int i = 1; i <= T; i++) {
N[i] = read(), M[i] = read(), L[i] = read(), R[i] = read(), U = max(U, N[i] + M[i]);
for (re int j = 0; j < N[i]; j++) v[i].push_back(read());
}
for (re int i = 1; i <= U; i++) fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[U] = ksm(fac[U], mod - 2);
for (re int i = U - 1; i; --i) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
for (re int t = 1; t <= T; t++) {
tot = 0, n = N[t], m = M[t], l = L[t], r = R[t];
int ans = fac[n + m];
for (re int i = 0; i < n; i++)
if (v[t][i] < l || v[t][i] > r)
ma[v[t][i]]++;
else c[++tot] = v[t][i];
for (re int i = 0; i < n; i++) {
if (v[t][i] >= l && v[t][i] <= r) continue;
ans = 1ll * ans * ifac[ma[v[t][i]]] % mod;
ma[v[t][i]] = 0;
}
std::sort(c + 1, c + tot + 1);
for (re int i = 1; i <= tot; i++) a[i] = c[i];
sz = std::unique(c + 1, c + tot + 1) - c - 1;
for (re int i = 1; i <= tot; i++) tax[a[i] = find(a[i])]++;
for (re int i = 1; i <= sz; i++) b[tax[i]]++, tax[i] = 0;
int res = r - l + 1 - sz, now = 0, g = 1;
for (re int i = 1; i <= tot; i++) g = 1ll * g * ksm(fac[i], b[i]) % mod;
for (re int i = 1; i <= tot; i++) {
if (!b[i]) continue;
if (1ll * (i - now) * res > m) {
int k = m / res;
g = 1ll * g * ksm(1ll * fac[now + k] * ifac[now] % mod, res) % mod;
g = 1ll * g * ksm(k + now + 1, m % res) % mod;
m = 0; break;
}
m -= (i - now) * res;
g = 1ll * g * ksm(1ll * fac[i] * ifac[now] % mod, res) % mod;
res += b[i]; now = i;
}
if (m) {
int k = m / res;
g = 1ll * g * ksm(1ll * fac[now + k] * ifac[now] % mod, res) % mod;
g = 1ll * g * ksm(k + now + 1, m % res) % mod;
}
for (re int i = 1; i <= tot; i++) b[i] = 0;
printf("%d\n", 1ll * ans * ksm(g, mod - 2) % mod);
}
}