[JXOI2018]排序问题

嘟嘟嘟


这是今天做的第二道九条可怜的题,现在对他的题的印象是:表面清真可做,实则毒瘤坑人。


首先要感谢吉司机,我期望学的特烂,好在样例直接告诉我们期望怎么求了。
\(b_i\)表示第\(i\)个不同的数的出现次数,那么期望就是

\[\frac{(n + m)!}{b_1! * b_2! * \ldots b_{tot}!} \]

所以我们只要让分母尽量小就行了。


这个问题也不难,想想就知道,只要让所有\(b_i\)尽可能接近就行了。
证明很简单,以两个数为例:如果\(a + b = x\),问\(a, b\)取什么值的时候\(a! * b!\)最小。
\(a_1 + b_1 = x, a_2 + b_2 = x\),其中\(a_1 \geqslant b_1, a_2 \geqslant b_2, a_1 > a_2, b_1 < b_2\),那么\(a_1! * b_1! = a_2! * b_2! * \frac{(a_2 + 1) * (a_2 + 2) * \ldots * a_1}{(b_1 + 1) * (b_1 + 2) * \ldots * b_2}\)。因为\(a_1 = a_2 + \Delta t, b_1 = b_2 - \Delta t\),所以后面的分数的分子分母项数相同,而分子的每一项都大于分母,所以整个分数必定大于1.


尽量接近,翻译过来就是最大的最小,那么自然想到二分。然后似乎就没了。
但为什么裸二分不对呢,因为有的\(b_i\)可能在原数列中存在,且非常大,那么二分的时候就不能考虑他了,必须在后面单独算贡献。
令二分的答案是\(x\),我们再列一个二元一次方程组解出来哪些\(b_i\)等于\(x\),哪些\(b_i\)\(x - 1\)。(手算一下就行)


这题最后一个坑点就是用map会超时,秉着死也不写哈希表的精神,离散化后直接\(O(nlogn)\)预处理加上吸氧,总算给过了(loj就可以过)。虽然都带个\(log\),但map常数似乎真不小……


代码奉上

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
#include<map>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
//#define int long long
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 2e6 + 5;
const int maxN = 1.2e7 + 5;
const ll mod = 998244353;
const int base = 999979;
inline ll read()
{
  ll ans = 0;
  char ch = getchar(), last = ' ';
  while(!isdigit(ch)) last = ch, ch = getchar();
  while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
inline void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}

int n, _n, m, l, r;
ll a[maxn], b[maxn], c[maxn], cnt = 0, tot = 0;
ll bnum[maxn], cnum[maxn];

In ll quickpow(ll a, ll b)
{
  ll ret = 1;
  for(; b; b >>= 1, a = a * a % mod)
    if(b & 1) ret = ret * a % mod;
  return ret;
}

ll fac[maxN], inv[maxN];
In void init()
{
  fac[0] = inv[0] = 1;
  for(int i = 1; i < maxN; ++i) fac[i] = fac[i - 1] * i % mod;
  inv[maxN - 1] = quickpow(fac[maxN - 1], mod - 2);
  for(int i = maxN - 2; i; --i) inv[i] = inv[i + 1] * (i + 1) % mod;
}

In bool judge(int x)
{
  ll sum = cnt * x;
  for(int i = 1; i <= tot; ++i) if(cnum[i] < x) sum += x - cnum[i];
  return sum >= m;
}

In ll solve()
{
  n = read(), m = read(), l = read(), r = read();
  cnt = r - l + 1, tot = 0;
  for(int i = 1; i <= n; ++i)
    {
      a[i] = b[i] = read(); 
      bnum[i] = 0;
    }
  sort(b + 1, b + n + 1);
  _n = unique(b + 1, b + n + 1) - b - 1;
  for(int i = 1; i <= n; ++i)
    ++bnum[lower_bound(b + 1, b + _n + 1, a[i]) - b];
  for(int i = 1; i <= _n; ++i)
      if(b[i] >= l && b[i] <= r) c[++tot] = b[i], cnum[tot] = bnum[i], --cnt;
  ll L = 0, R = m + n;
  while(L < R)
    {
      int mid = (L + R) >> 1;
      if(judge(mid)) R = mid;
      else L = mid + 1;
    }
  ll cnt2 = cnt, sum = 0;
  for(int i = 1; i <= tot; ++i) if(cnum[i] < L) ++cnt2, sum += cnum[i];
  ll ans = fac[n + m];
  if(cnt2 * L == m + sum) ans = ans * quickpow(inv[L], cnt2) % mod;
  else
    {
      ll x = (1 - L) * cnt2 + m + sum, y = cnt2 * L - m - sum;
      ll tp1 = quickpow(inv[L], x), tp2 = quickpow(inv[L - 1], y);
      ans = ans * tp1 % mod * tp2 % mod;
    }
  for(int i = 1; i <= _n; ++i) if(b[i] < l || b[i] > r) ans = ans * inv[bnum[i]] % mod;
  for(int i = 1; i <= tot; ++i) if(cnum[i] >= L) ans = ans * inv[cnum[i]] % mod;
  return ans;
}

int main()
{
  init();
  int T = read();
  while(T--) write(solve()), enter;
  return 0;
}
posted @ 2019-03-20 16:27  mrclr  阅读(221)  评论(0编辑  收藏  举报